##// END OF EJS Templates
Merge pull request #8142 from SylvainCorlay/slider_validation...
Min RK -
r20989:e99202ca merge
parent child Browse files
Show More
@@ -1,396 +1,396 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 An application for IPython.
3 An application for IPython.
4
4
5 All top-level applications should use the classes in this module for
5 All top-level applications should use the classes in this module for
6 handling configuration and creating configurables.
6 handling configuration and creating configurables.
7
7
8 The job of an :class:`Application` is to create the master configuration
8 The job of an :class:`Application` is to create the master configuration
9 object and then create the configurable objects, passing the config to them.
9 object and then create the configurable objects, passing the config to them.
10 """
10 """
11
11
12 # Copyright (c) IPython Development Team.
12 # Copyright (c) IPython Development Team.
13 # Distributed under the terms of the Modified BSD License.
13 # Distributed under the terms of the Modified BSD License.
14
14
15 import atexit
15 import atexit
16 import glob
16 import glob
17 import logging
17 import logging
18 import os
18 import os
19 import shutil
19 import shutil
20 import sys
20 import sys
21
21
22 from IPython.config.application import Application, catch_config_error
22 from IPython.config.application import Application, catch_config_error
23 from IPython.config.loader import ConfigFileNotFound, PyFileConfigLoader
23 from IPython.config.loader import ConfigFileNotFound, PyFileConfigLoader
24 from IPython.core import release, crashhandler
24 from IPython.core import release, crashhandler
25 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 from IPython.core.profiledir import ProfileDir, ProfileDirError
26 from IPython.utils.path import get_ipython_dir, get_ipython_package_dir, ensure_dir_exists
26 from IPython.utils.path import get_ipython_dir, get_ipython_package_dir, ensure_dir_exists
27 from IPython.utils import py3compat
27 from IPython.utils import py3compat
28 from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict, Set, Instance
28 from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict, Set, Instance, Undefined
29
29
30 if os.name == 'nt':
30 if os.name == 'nt':
31 programdata = os.environ.get('PROGRAMDATA', None)
31 programdata = os.environ.get('PROGRAMDATA', None)
32 if programdata:
32 if programdata:
33 SYSTEM_CONFIG_DIRS = [os.path.join(programdata, 'ipython')]
33 SYSTEM_CONFIG_DIRS = [os.path.join(programdata, 'ipython')]
34 else: # PROGRAMDATA is not defined by default on XP.
34 else: # PROGRAMDATA is not defined by default on XP.
35 SYSTEM_CONFIG_DIRS = []
35 SYSTEM_CONFIG_DIRS = []
36 else:
36 else:
37 SYSTEM_CONFIG_DIRS = [
37 SYSTEM_CONFIG_DIRS = [
38 "/usr/local/etc/ipython",
38 "/usr/local/etc/ipython",
39 "/etc/ipython",
39 "/etc/ipython",
40 ]
40 ]
41
41
42
42
43 # aliases and flags
43 # aliases and flags
44
44
45 base_aliases = {
45 base_aliases = {
46 'profile-dir' : 'ProfileDir.location',
46 'profile-dir' : 'ProfileDir.location',
47 'profile' : 'BaseIPythonApplication.profile',
47 'profile' : 'BaseIPythonApplication.profile',
48 'ipython-dir' : 'BaseIPythonApplication.ipython_dir',
48 'ipython-dir' : 'BaseIPythonApplication.ipython_dir',
49 'log-level' : 'Application.log_level',
49 'log-level' : 'Application.log_level',
50 'config' : 'BaseIPythonApplication.extra_config_file',
50 'config' : 'BaseIPythonApplication.extra_config_file',
51 }
51 }
52
52
53 base_flags = dict(
53 base_flags = dict(
54 debug = ({'Application' : {'log_level' : logging.DEBUG}},
54 debug = ({'Application' : {'log_level' : logging.DEBUG}},
55 "set log level to logging.DEBUG (maximize logging output)"),
55 "set log level to logging.DEBUG (maximize logging output)"),
56 quiet = ({'Application' : {'log_level' : logging.CRITICAL}},
56 quiet = ({'Application' : {'log_level' : logging.CRITICAL}},
57 "set log level to logging.CRITICAL (minimize logging output)"),
57 "set log level to logging.CRITICAL (minimize logging output)"),
58 init = ({'BaseIPythonApplication' : {
58 init = ({'BaseIPythonApplication' : {
59 'copy_config_files' : True,
59 'copy_config_files' : True,
60 'auto_create' : True}
60 'auto_create' : True}
61 }, """Initialize profile with default config files. This is equivalent
61 }, """Initialize profile with default config files. This is equivalent
62 to running `ipython profile create <profile>` prior to startup.
62 to running `ipython profile create <profile>` prior to startup.
63 """)
63 """)
64 )
64 )
65
65
66 class ProfileAwareConfigLoader(PyFileConfigLoader):
66 class ProfileAwareConfigLoader(PyFileConfigLoader):
67 """A Python file config loader that is aware of IPython profiles."""
67 """A Python file config loader that is aware of IPython profiles."""
68 def load_subconfig(self, fname, path=None, profile=None):
68 def load_subconfig(self, fname, path=None, profile=None):
69 if profile is not None:
69 if profile is not None:
70 try:
70 try:
71 profile_dir = ProfileDir.find_profile_dir_by_name(
71 profile_dir = ProfileDir.find_profile_dir_by_name(
72 get_ipython_dir(),
72 get_ipython_dir(),
73 profile,
73 profile,
74 )
74 )
75 except ProfileDirError:
75 except ProfileDirError:
76 return
76 return
77 path = profile_dir.location
77 path = profile_dir.location
78 return super(ProfileAwareConfigLoader, self).load_subconfig(fname, path=path)
78 return super(ProfileAwareConfigLoader, self).load_subconfig(fname, path=path)
79
79
80 class BaseIPythonApplication(Application):
80 class BaseIPythonApplication(Application):
81
81
82 name = Unicode(u'ipython')
82 name = Unicode(u'ipython')
83 description = Unicode(u'IPython: an enhanced interactive Python shell.')
83 description = Unicode(u'IPython: an enhanced interactive Python shell.')
84 version = Unicode(release.version)
84 version = Unicode(release.version)
85
85
86 aliases = Dict(base_aliases)
86 aliases = Dict(base_aliases)
87 flags = Dict(base_flags)
87 flags = Dict(base_flags)
88 classes = List([ProfileDir])
88 classes = List([ProfileDir])
89
89
90 # enable `load_subconfig('cfg.py', profile='name')`
90 # enable `load_subconfig('cfg.py', profile='name')`
91 python_config_loader_class = ProfileAwareConfigLoader
91 python_config_loader_class = ProfileAwareConfigLoader
92
92
93 # Track whether the config_file has changed,
93 # Track whether the config_file has changed,
94 # because some logic happens only if we aren't using the default.
94 # because some logic happens only if we aren't using the default.
95 config_file_specified = Set()
95 config_file_specified = Set()
96
96
97 config_file_name = Unicode()
97 config_file_name = Unicode()
98 def _config_file_name_default(self):
98 def _config_file_name_default(self):
99 return self.name.replace('-','_') + u'_config.py'
99 return self.name.replace('-','_') + u'_config.py'
100 def _config_file_name_changed(self, name, old, new):
100 def _config_file_name_changed(self, name, old, new):
101 if new != old:
101 if new != old:
102 self.config_file_specified.add(new)
102 self.config_file_specified.add(new)
103
103
104 # The directory that contains IPython's builtin profiles.
104 # The directory that contains IPython's builtin profiles.
105 builtin_profile_dir = Unicode(
105 builtin_profile_dir = Unicode(
106 os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
106 os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
107 )
107 )
108
108
109 config_file_paths = List(Unicode)
109 config_file_paths = List(Unicode)
110 def _config_file_paths_default(self):
110 def _config_file_paths_default(self):
111 return [py3compat.getcwd()]
111 return [py3compat.getcwd()]
112
112
113 extra_config_file = Unicode(config=True,
113 extra_config_file = Unicode(config=True,
114 help="""Path to an extra config file to load.
114 help="""Path to an extra config file to load.
115
115
116 If specified, load this config file in addition to any other IPython config.
116 If specified, load this config file in addition to any other IPython config.
117 """)
117 """)
118 def _extra_config_file_changed(self, name, old, new):
118 def _extra_config_file_changed(self, name, old, new):
119 try:
119 try:
120 self.config_files.remove(old)
120 self.config_files.remove(old)
121 except ValueError:
121 except ValueError:
122 pass
122 pass
123 self.config_file_specified.add(new)
123 self.config_file_specified.add(new)
124 self.config_files.append(new)
124 self.config_files.append(new)
125
125
126 profile = Unicode(u'default', config=True,
126 profile = Unicode(u'default', config=True,
127 help="""The IPython profile to use."""
127 help="""The IPython profile to use."""
128 )
128 )
129
129
130 def _profile_changed(self, name, old, new):
130 def _profile_changed(self, name, old, new):
131 self.builtin_profile_dir = os.path.join(
131 self.builtin_profile_dir = os.path.join(
132 get_ipython_package_dir(), u'config', u'profile', new
132 get_ipython_package_dir(), u'config', u'profile', new
133 )
133 )
134
134
135 ipython_dir = Unicode(config=True,
135 ipython_dir = Unicode(config=True,
136 help="""
136 help="""
137 The name of the IPython directory. This directory is used for logging
137 The name of the IPython directory. This directory is used for logging
138 configuration (through profiles), history storage, etc. The default
138 configuration (through profiles), history storage, etc. The default
139 is usually $HOME/.ipython. This option can also be specified through
139 is usually $HOME/.ipython. This option can also be specified through
140 the environment variable IPYTHONDIR.
140 the environment variable IPYTHONDIR.
141 """
141 """
142 )
142 )
143 def _ipython_dir_default(self):
143 def _ipython_dir_default(self):
144 d = get_ipython_dir()
144 d = get_ipython_dir()
145 self._ipython_dir_changed('ipython_dir', d, d)
145 self._ipython_dir_changed('ipython_dir', d, d)
146 return d
146 return d
147
147
148 _in_init_profile_dir = False
148 _in_init_profile_dir = False
149 profile_dir = Instance(ProfileDir, allow_none=True)
149 profile_dir = Instance(ProfileDir, allow_none=True)
150 def _profile_dir_default(self):
150 def _profile_dir_default(self):
151 # avoid recursion
151 # avoid recursion
152 if self._in_init_profile_dir:
152 if self._in_init_profile_dir:
153 return
153 return
154 # profile_dir requested early, force initialization
154 # profile_dir requested early, force initialization
155 self.init_profile_dir()
155 self.init_profile_dir()
156 return self.profile_dir
156 return self.profile_dir
157
157
158 overwrite = Bool(False, config=True,
158 overwrite = Bool(False, config=True,
159 help="""Whether to overwrite existing config files when copying""")
159 help="""Whether to overwrite existing config files when copying""")
160 auto_create = Bool(False, config=True,
160 auto_create = Bool(False, config=True,
161 help="""Whether to create profile dir if it doesn't exist""")
161 help="""Whether to create profile dir if it doesn't exist""")
162
162
163 config_files = List(Unicode)
163 config_files = List(Unicode)
164 def _config_files_default(self):
164 def _config_files_default(self):
165 return [self.config_file_name]
165 return [self.config_file_name]
166
166
167 copy_config_files = Bool(False, config=True,
167 copy_config_files = Bool(False, config=True,
168 help="""Whether to install the default config files into the profile dir.
168 help="""Whether to install the default config files into the profile dir.
169 If a new profile is being created, and IPython contains config files for that
169 If a new profile is being created, and IPython contains config files for that
170 profile, then they will be staged into the new directory. Otherwise,
170 profile, then they will be staged into the new directory. Otherwise,
171 default config files will be automatically generated.
171 default config files will be automatically generated.
172 """)
172 """)
173
173
174 verbose_crash = Bool(False, config=True,
174 verbose_crash = Bool(False, config=True,
175 help="""Create a massive crash report when IPython encounters what may be an
175 help="""Create a massive crash report when IPython encounters what may be an
176 internal error. The default is to append a short message to the
176 internal error. The default is to append a short message to the
177 usual traceback""")
177 usual traceback""")
178
178
179 # The class to use as the crash handler.
179 # The class to use as the crash handler.
180 crash_handler_class = Type(crashhandler.CrashHandler)
180 crash_handler_class = Type(crashhandler.CrashHandler)
181
181
182 @catch_config_error
182 @catch_config_error
183 def __init__(self, **kwargs):
183 def __init__(self, **kwargs):
184 super(BaseIPythonApplication, self).__init__(**kwargs)
184 super(BaseIPythonApplication, self).__init__(**kwargs)
185 # ensure current working directory exists
185 # ensure current working directory exists
186 try:
186 try:
187 directory = py3compat.getcwd()
187 directory = py3compat.getcwd()
188 except:
188 except:
189 # exit if cwd doesn't exist
189 # exit if cwd doesn't exist
190 self.log.error("Current working directory doesn't exist.")
190 self.log.error("Current working directory doesn't exist.")
191 self.exit(1)
191 self.exit(1)
192
192
193 #-------------------------------------------------------------------------
193 #-------------------------------------------------------------------------
194 # Various stages of Application creation
194 # Various stages of Application creation
195 #-------------------------------------------------------------------------
195 #-------------------------------------------------------------------------
196
196
197 def init_crash_handler(self):
197 def init_crash_handler(self):
198 """Create a crash handler, typically setting sys.excepthook to it."""
198 """Create a crash handler, typically setting sys.excepthook to it."""
199 self.crash_handler = self.crash_handler_class(self)
199 self.crash_handler = self.crash_handler_class(self)
200 sys.excepthook = self.excepthook
200 sys.excepthook = self.excepthook
201 def unset_crashhandler():
201 def unset_crashhandler():
202 sys.excepthook = sys.__excepthook__
202 sys.excepthook = sys.__excepthook__
203 atexit.register(unset_crashhandler)
203 atexit.register(unset_crashhandler)
204
204
205 def excepthook(self, etype, evalue, tb):
205 def excepthook(self, etype, evalue, tb):
206 """this is sys.excepthook after init_crashhandler
206 """this is sys.excepthook after init_crashhandler
207
207
208 set self.verbose_crash=True to use our full crashhandler, instead of
208 set self.verbose_crash=True to use our full crashhandler, instead of
209 a regular traceback with a short message (crash_handler_lite)
209 a regular traceback with a short message (crash_handler_lite)
210 """
210 """
211
211
212 if self.verbose_crash:
212 if self.verbose_crash:
213 return self.crash_handler(etype, evalue, tb)
213 return self.crash_handler(etype, evalue, tb)
214 else:
214 else:
215 return crashhandler.crash_handler_lite(etype, evalue, tb)
215 return crashhandler.crash_handler_lite(etype, evalue, tb)
216
216
217 def _ipython_dir_changed(self, name, old, new):
217 def _ipython_dir_changed(self, name, old, new):
218 if old is not None:
218 if old is not None and old is not Undefined:
219 str_old = py3compat.cast_bytes_py2(os.path.abspath(old),
219 str_old = py3compat.cast_bytes_py2(os.path.abspath(old),
220 sys.getfilesystemencoding()
220 sys.getfilesystemencoding()
221 )
221 )
222 if str_old in sys.path:
222 if str_old in sys.path:
223 sys.path.remove(str_old)
223 sys.path.remove(str_old)
224 str_path = py3compat.cast_bytes_py2(os.path.abspath(new),
224 str_path = py3compat.cast_bytes_py2(os.path.abspath(new),
225 sys.getfilesystemencoding()
225 sys.getfilesystemencoding()
226 )
226 )
227 sys.path.append(str_path)
227 sys.path.append(str_path)
228 ensure_dir_exists(new)
228 ensure_dir_exists(new)
229 readme = os.path.join(new, 'README')
229 readme = os.path.join(new, 'README')
230 readme_src = os.path.join(get_ipython_package_dir(), u'config', u'profile', 'README')
230 readme_src = os.path.join(get_ipython_package_dir(), u'config', u'profile', 'README')
231 if not os.path.exists(readme) and os.path.exists(readme_src):
231 if not os.path.exists(readme) and os.path.exists(readme_src):
232 shutil.copy(readme_src, readme)
232 shutil.copy(readme_src, readme)
233 for d in ('extensions', 'nbextensions'):
233 for d in ('extensions', 'nbextensions'):
234 path = os.path.join(new, d)
234 path = os.path.join(new, d)
235 try:
235 try:
236 ensure_dir_exists(path)
236 ensure_dir_exists(path)
237 except OSError:
237 except OSError:
238 # this will not be EEXIST
238 # this will not be EEXIST
239 self.log.error("couldn't create path %s: %s", path, e)
239 self.log.error("couldn't create path %s: %s", path, e)
240 self.log.debug("IPYTHONDIR set to: %s" % new)
240 self.log.debug("IPYTHONDIR set to: %s" % new)
241
241
242 def load_config_file(self, suppress_errors=True):
242 def load_config_file(self, suppress_errors=True):
243 """Load the config file.
243 """Load the config file.
244
244
245 By default, errors in loading config are handled, and a warning
245 By default, errors in loading config are handled, and a warning
246 printed on screen. For testing, the suppress_errors option is set
246 printed on screen. For testing, the suppress_errors option is set
247 to False, so errors will make tests fail.
247 to False, so errors will make tests fail.
248 """
248 """
249 self.log.debug("Searching path %s for config files", self.config_file_paths)
249 self.log.debug("Searching path %s for config files", self.config_file_paths)
250 base_config = 'ipython_config.py'
250 base_config = 'ipython_config.py'
251 self.log.debug("Attempting to load config file: %s" %
251 self.log.debug("Attempting to load config file: %s" %
252 base_config)
252 base_config)
253 try:
253 try:
254 Application.load_config_file(
254 Application.load_config_file(
255 self,
255 self,
256 base_config,
256 base_config,
257 path=self.config_file_paths
257 path=self.config_file_paths
258 )
258 )
259 except ConfigFileNotFound:
259 except ConfigFileNotFound:
260 # ignore errors loading parent
260 # ignore errors loading parent
261 self.log.debug("Config file %s not found", base_config)
261 self.log.debug("Config file %s not found", base_config)
262 pass
262 pass
263
263
264 for config_file_name in self.config_files:
264 for config_file_name in self.config_files:
265 if not config_file_name or config_file_name == base_config:
265 if not config_file_name or config_file_name == base_config:
266 continue
266 continue
267 self.log.debug("Attempting to load config file: %s" %
267 self.log.debug("Attempting to load config file: %s" %
268 self.config_file_name)
268 self.config_file_name)
269 try:
269 try:
270 Application.load_config_file(
270 Application.load_config_file(
271 self,
271 self,
272 config_file_name,
272 config_file_name,
273 path=self.config_file_paths
273 path=self.config_file_paths
274 )
274 )
275 except ConfigFileNotFound:
275 except ConfigFileNotFound:
276 # Only warn if the default config file was NOT being used.
276 # Only warn if the default config file was NOT being used.
277 if config_file_name in self.config_file_specified:
277 if config_file_name in self.config_file_specified:
278 msg = self.log.warn
278 msg = self.log.warn
279 else:
279 else:
280 msg = self.log.debug
280 msg = self.log.debug
281 msg("Config file not found, skipping: %s", config_file_name)
281 msg("Config file not found, skipping: %s", config_file_name)
282 except:
282 except:
283 # For testing purposes.
283 # For testing purposes.
284 if not suppress_errors:
284 if not suppress_errors:
285 raise
285 raise
286 self.log.warn("Error loading config file: %s" %
286 self.log.warn("Error loading config file: %s" %
287 self.config_file_name, exc_info=True)
287 self.config_file_name, exc_info=True)
288
288
289 def init_profile_dir(self):
289 def init_profile_dir(self):
290 """initialize the profile dir"""
290 """initialize the profile dir"""
291 self._in_init_profile_dir = True
291 self._in_init_profile_dir = True
292 if self.profile_dir is not None:
292 if self.profile_dir is not None:
293 # already ran
293 # already ran
294 return
294 return
295 if 'ProfileDir.location' not in self.config:
295 if 'ProfileDir.location' not in self.config:
296 # location not specified, find by profile name
296 # location not specified, find by profile name
297 try:
297 try:
298 p = ProfileDir.find_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
298 p = ProfileDir.find_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
299 except ProfileDirError:
299 except ProfileDirError:
300 # not found, maybe create it (always create default profile)
300 # not found, maybe create it (always create default profile)
301 if self.auto_create or self.profile == 'default':
301 if self.auto_create or self.profile == 'default':
302 try:
302 try:
303 p = ProfileDir.create_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
303 p = ProfileDir.create_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
304 except ProfileDirError:
304 except ProfileDirError:
305 self.log.fatal("Could not create profile: %r"%self.profile)
305 self.log.fatal("Could not create profile: %r"%self.profile)
306 self.exit(1)
306 self.exit(1)
307 else:
307 else:
308 self.log.info("Created profile dir: %r"%p.location)
308 self.log.info("Created profile dir: %r"%p.location)
309 else:
309 else:
310 self.log.fatal("Profile %r not found."%self.profile)
310 self.log.fatal("Profile %r not found."%self.profile)
311 self.exit(1)
311 self.exit(1)
312 else:
312 else:
313 self.log.debug("Using existing profile dir: %r"%p.location)
313 self.log.debug("Using existing profile dir: %r"%p.location)
314 else:
314 else:
315 location = self.config.ProfileDir.location
315 location = self.config.ProfileDir.location
316 # location is fully specified
316 # location is fully specified
317 try:
317 try:
318 p = ProfileDir.find_profile_dir(location, self.config)
318 p = ProfileDir.find_profile_dir(location, self.config)
319 except ProfileDirError:
319 except ProfileDirError:
320 # not found, maybe create it
320 # not found, maybe create it
321 if self.auto_create:
321 if self.auto_create:
322 try:
322 try:
323 p = ProfileDir.create_profile_dir(location, self.config)
323 p = ProfileDir.create_profile_dir(location, self.config)
324 except ProfileDirError:
324 except ProfileDirError:
325 self.log.fatal("Could not create profile directory: %r"%location)
325 self.log.fatal("Could not create profile directory: %r"%location)
326 self.exit(1)
326 self.exit(1)
327 else:
327 else:
328 self.log.debug("Creating new profile dir: %r"%location)
328 self.log.debug("Creating new profile dir: %r"%location)
329 else:
329 else:
330 self.log.fatal("Profile directory %r not found."%location)
330 self.log.fatal("Profile directory %r not found."%location)
331 self.exit(1)
331 self.exit(1)
332 else:
332 else:
333 self.log.info("Using existing profile dir: %r"%location)
333 self.log.info("Using existing profile dir: %r"%location)
334 # if profile_dir is specified explicitly, set profile name
334 # if profile_dir is specified explicitly, set profile name
335 dir_name = os.path.basename(p.location)
335 dir_name = os.path.basename(p.location)
336 if dir_name.startswith('profile_'):
336 if dir_name.startswith('profile_'):
337 self.profile = dir_name[8:]
337 self.profile = dir_name[8:]
338
338
339 self.profile_dir = p
339 self.profile_dir = p
340 self.config_file_paths.append(p.location)
340 self.config_file_paths.append(p.location)
341 self._in_init_profile_dir = False
341 self._in_init_profile_dir = False
342
342
343 def init_config_files(self):
343 def init_config_files(self):
344 """[optionally] copy default config files into profile dir."""
344 """[optionally] copy default config files into profile dir."""
345 self.config_file_paths.extend(SYSTEM_CONFIG_DIRS)
345 self.config_file_paths.extend(SYSTEM_CONFIG_DIRS)
346 # copy config files
346 # copy config files
347 path = self.builtin_profile_dir
347 path = self.builtin_profile_dir
348 if self.copy_config_files:
348 if self.copy_config_files:
349 src = self.profile
349 src = self.profile
350
350
351 cfg = self.config_file_name
351 cfg = self.config_file_name
352 if path and os.path.exists(os.path.join(path, cfg)):
352 if path and os.path.exists(os.path.join(path, cfg)):
353 self.log.warn("Staging %r from %s into %r [overwrite=%s]"%(
353 self.log.warn("Staging %r from %s into %r [overwrite=%s]"%(
354 cfg, src, self.profile_dir.location, self.overwrite)
354 cfg, src, self.profile_dir.location, self.overwrite)
355 )
355 )
356 self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite)
356 self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite)
357 else:
357 else:
358 self.stage_default_config_file()
358 self.stage_default_config_file()
359 else:
359 else:
360 # Still stage *bundled* config files, but not generated ones
360 # Still stage *bundled* config files, but not generated ones
361 # This is necessary for `ipython profile=sympy` to load the profile
361 # This is necessary for `ipython profile=sympy` to load the profile
362 # on the first go
362 # on the first go
363 files = glob.glob(os.path.join(path, '*.py'))
363 files = glob.glob(os.path.join(path, '*.py'))
364 for fullpath in files:
364 for fullpath in files:
365 cfg = os.path.basename(fullpath)
365 cfg = os.path.basename(fullpath)
366 if self.profile_dir.copy_config_file(cfg, path=path, overwrite=False):
366 if self.profile_dir.copy_config_file(cfg, path=path, overwrite=False):
367 # file was copied
367 # file was copied
368 self.log.warn("Staging bundled %s from %s into %r"%(
368 self.log.warn("Staging bundled %s from %s into %r"%(
369 cfg, self.profile, self.profile_dir.location)
369 cfg, self.profile, self.profile_dir.location)
370 )
370 )
371
371
372
372
373 def stage_default_config_file(self):
373 def stage_default_config_file(self):
374 """auto generate default config file, and stage it into the profile."""
374 """auto generate default config file, and stage it into the profile."""
375 s = self.generate_config_file()
375 s = self.generate_config_file()
376 fname = os.path.join(self.profile_dir.location, self.config_file_name)
376 fname = os.path.join(self.profile_dir.location, self.config_file_name)
377 if self.overwrite or not os.path.exists(fname):
377 if self.overwrite or not os.path.exists(fname):
378 self.log.warn("Generating default config file: %r"%(fname))
378 self.log.warn("Generating default config file: %r"%(fname))
379 with open(fname, 'w') as f:
379 with open(fname, 'w') as f:
380 f.write(s)
380 f.write(s)
381
381
382 @catch_config_error
382 @catch_config_error
383 def initialize(self, argv=None):
383 def initialize(self, argv=None):
384 # don't hook up crash handler before parsing command-line
384 # don't hook up crash handler before parsing command-line
385 self.parse_command_line(argv)
385 self.parse_command_line(argv)
386 self.init_crash_handler()
386 self.init_crash_handler()
387 if self.subapp is not None:
387 if self.subapp is not None:
388 # stop here if subapp is taking over
388 # stop here if subapp is taking over
389 return
389 return
390 cl_config = self.config
390 cl_config = self.config
391 self.init_profile_dir()
391 self.init_profile_dir()
392 self.init_config_files()
392 self.init_config_files()
393 self.load_config_file()
393 self.load_config_file()
394 # enforce cl-opts override configfile opts:
394 # enforce cl-opts override configfile opts:
395 self.update_config(cl_config)
395 self.update_config(cl_config)
396
396
@@ -1,297 +1,296 b''
1 """Float class.
1 """Float class.
2
2
3 Represents an unbounded float using a widget.
3 Represents an unbounded float using a widget.
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (c) 2013, the IPython Development Team.
6 # Copyright (c) 2013, the IPython Development Team.
7 #
7 #
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9 #
9 #
10 # The full license is in the file COPYING.txt, distributed with this software.
10 # The full license is in the file COPYING.txt, distributed with this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from .widget import DOMWidget, register
16 from .widget import DOMWidget, register
17 from .trait_types import Color
17 from .trait_types import Color
18 from IPython.utils.traitlets import Unicode, CFloat, Bool, CaselessStrEnum, Tuple
18 from IPython.utils.traitlets import (Unicode, CFloat, Bool, CaselessStrEnum,
19 Tuple, TraitError)
19 from IPython.utils.warn import DeprecatedClass
20 from IPython.utils.warn import DeprecatedClass
20
21
21 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
22 # Classes
23 # Classes
23 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
24 class _Float(DOMWidget):
25 class _Float(DOMWidget):
25 value = CFloat(0.0, help="Float value", sync=True)
26 value = CFloat(0.0, help="Float value", sync=True)
26 disabled = Bool(False, help="Enable or disable user changes", sync=True)
27 disabled = Bool(False, help="Enable or disable user changes", sync=True)
27 description = Unicode(help="Description of the value this widget represents", sync=True)
28 description = Unicode(help="Description of the value this widget represents", sync=True)
28
29
29 def __init__(self, value=None, **kwargs):
30 def __init__(self, value=None, **kwargs):
30 if value is not None:
31 if value is not None:
31 kwargs['value'] = value
32 kwargs['value'] = value
32 super(_Float, self).__init__(**kwargs)
33 super(_Float, self).__init__(**kwargs)
33
34
35
34 class _BoundedFloat(_Float):
36 class _BoundedFloat(_Float):
35 max = CFloat(100.0, help="Max value", sync=True)
37 max = CFloat(100.0, help="Max value", sync=True)
36 min = CFloat(0.0, help="Min value", sync=True)
38 min = CFloat(0.0, help="Min value", sync=True)
37 step = CFloat(0.1, help="Minimum step that the value can take (ignored by some views)", sync=True)
39 step = CFloat(0.1, help="Minimum step to increment the value (ignored by some views)", sync=True)
38
40
39 def __init__(self, *pargs, **kwargs):
41 def __init__(self, *pargs, **kwargs):
40 """Constructor"""
42 """Constructor"""
41 super(_BoundedFloat, self).__init__(*pargs, **kwargs)
43 super(_BoundedFloat, self).__init__(*pargs, **kwargs)
42 self._handle_value_changed('value', None, self.value)
43 self._handle_max_changed('max', None, self.max)
44 self._handle_min_changed('min', None, self.min)
45 self.on_trait_change(self._handle_value_changed, 'value')
46 self.on_trait_change(self._handle_max_changed, 'max')
47 self.on_trait_change(self._handle_min_changed, 'min')
48
44
49 def _handle_value_changed(self, name, old, new):
45 def _value_validate(self, value, trait):
50 """Validate value."""
46 """Cap and floor value"""
51 if self.min > new or new > self.max:
47 if self.min > value or self.max < value:
52 self.value = min(max(new, self.min), self.max)
48 value = min(max(value, self.min), self.max)
49 return value
53
50
54 def _handle_max_changed(self, name, old, new):
51 def _min_validate(self, min, trait):
55 """Make sure the min is always <= the max."""
52 """Enforce min <= value <= max"""
56 if new < self.min:
53 if min > self.max:
57 raise ValueError("setting max < min")
54 raise TraitError("Setting min > max")
58 if new < self.value:
55 if min > self.value:
59 self.value = new
56 self.value = min
57 return min
60
58
61 def _handle_min_changed(self, name, old, new):
59 def _max_validate(self, max, trait):
62 """Make sure the max is always >= the min."""
60 """Enforce min <= value <= max"""
63 if new > self.max:
61 if max < self.min:
64 raise ValueError("setting min > max")
62 raise TraitError("setting max < min")
65 if new > self.value:
63 if max < self.value:
66 self.value = new
64 self.value = max
65 return max
67
66
68
67
69 @register('IPython.FloatText')
68 @register('IPython.FloatText')
70 class FloatText(_Float):
69 class FloatText(_Float):
71 """ Displays a float value within a textbox. For a textbox in
70 """ Displays a float value within a textbox. For a textbox in
72 which the value must be within a specific range, use BoundedFloatText.
71 which the value must be within a specific range, use BoundedFloatText.
73
72
74 Parameters
73 Parameters
75 ----------
74 ----------
76 value : float
75 value : float
77 value displayed
76 value displayed
78 description : str
77 description : str
79 description displayed next to the textbox
78 description displayed next to the text box
80 color : str Unicode color code (eg. '#C13535'), optional
79 color : str Unicode color code (eg. '#C13535'), optional
81 color of the value displayed
80 color of the value displayed
82 """
81 """
83 _view_name = Unicode('FloatTextView', sync=True)
82 _view_name = Unicode('FloatTextView', sync=True)
84
83
85
84
86 @register('IPython.BoundedFloatText')
85 @register('IPython.BoundedFloatText')
87 class BoundedFloatText(_BoundedFloat):
86 class BoundedFloatText(_BoundedFloat):
88 """ Displays a float value within a textbox. Value must be within the range specified.
87 """ Displays a float value within a textbox. Value must be within the range specified.
89 For a textbox in which the value doesn't need to be within a specific range, use FloatText.
88 For a textbox in which the value doesn't need to be within a specific range, use FloatText.
90
89
91 Parameters
90 Parameters
92 ----------
91 ----------
93 value : float
92 value : float
94 value displayed
93 value displayed
95 min : float
94 min : float
96 minimal value of the range of possible values displayed
95 minimal value of the range of possible values displayed
97 max : float
96 max : float
98 maximal value of the range of possible values displayed
97 maximal value of the range of possible values displayed
99 description : str
98 description : str
100 description displayed next to the textbox
99 description displayed next to the textbox
101 color : str Unicode color code (eg. '#C13535'), optional
100 color : str Unicode color code (eg. '#C13535'), optional
102 color of the value displayed
101 color of the value displayed
103 """
102 """
104 _view_name = Unicode('FloatTextView', sync=True)
103 _view_name = Unicode('FloatTextView', sync=True)
105
104
106
105
107 @register('IPython.FloatSlider')
106 @register('IPython.FloatSlider')
108 class FloatSlider(_BoundedFloat):
107 class FloatSlider(_BoundedFloat):
109 """ Slider/trackbar of floating values with the specified range.
108 """ Slider/trackbar of floating values with the specified range.
110
109
111 Parameters
110 Parameters
112 ----------
111 ----------
113 value : float
112 value : float
114 position of the slider
113 position of the slider
115 min : float
114 min : float
116 minimal position of the slider
115 minimal position of the slider
117 max : float
116 max : float
118 maximal position of the slider
117 maximal position of the slider
119 step : float
118 step : float
120 step of the trackbar
119 step of the trackbar
121 description : str
120 description : str
122 name of the slider
121 name of the slider
123 orientation : {'vertical', 'horizontal}, optional
122 orientation : {'vertical', 'horizontal}, optional
124 default is horizontal
123 default is horizontal
125 readout : {True, False}, optional
124 readout : {True, False}, optional
126 default is True, display the current value of the slider next to it
125 default is True, display the current value of the slider next to it
127 slider_color : str Unicode color code (eg. '#C13535'), optional
126 slider_color : str Unicode color code (eg. '#C13535'), optional
128 color of the slider
127 color of the slider
129 color : str Unicode color code (eg. '#C13535'), optional
128 color : str Unicode color code (eg. '#C13535'), optional
130 color of the value displayed (if readout == True)
129 color of the value displayed (if readout == True)
131 """
130 """
132 _view_name = Unicode('FloatSliderView', sync=True)
131 _view_name = Unicode('FloatSliderView', sync=True)
133 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
132 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
134 default_value='horizontal', help="Vertical or horizontal.", sync=True)
133 default_value='horizontal', help="Vertical or horizontal.", sync=True)
135 _range = Bool(False, help="Display a range selector", sync=True)
134 _range = Bool(False, help="Display a range selector", sync=True)
136 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
135 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
137 slider_color = Color(None, allow_none=True, sync=True)
136 slider_color = Color(None, allow_none=True, sync=True)
138
137
139
138
140 @register('IPython.FloatProgress')
139 @register('IPython.FloatProgress')
141 class FloatProgress(_BoundedFloat):
140 class FloatProgress(_BoundedFloat):
142 """ Displays a progress bar.
141 """ Displays a progress bar.
143
142
144 Parameters
143 Parameters
145 -----------
144 -----------
146 value : float
145 value : float
147 position within the range of the progress bar
146 position within the range of the progress bar
148 min : float
147 min : float
149 minimal position of the slider
148 minimal position of the slider
150 max : float
149 max : float
151 maximal position of the slider
150 maximal position of the slider
152 step : float
151 step : float
153 step of the progress bar
152 step of the progress bar
154 description : str
153 description : str
155 name of the progress bar
154 name of the progress bar
156 bar_style: {'success', 'info', 'warning', 'danger', ''}, optional
155 bar_style: {'success', 'info', 'warning', 'danger', ''}, optional
157 color of the progress bar, default is '' (blue)
156 color of the progress bar, default is '' (blue)
158 colors are: 'success'-green, 'info'-light blue, 'warning'-orange, 'danger'-red
157 colors are: 'success'-green, 'info'-light blue, 'warning'-orange, 'danger'-red
159 """
158 """
160 _view_name = Unicode('ProgressView', sync=True)
159 _view_name = Unicode('ProgressView', sync=True)
161
160
162 bar_style = CaselessStrEnum(
161 bar_style = CaselessStrEnum(
163 values=['success', 'info', 'warning', 'danger', ''],
162 values=['success', 'info', 'warning', 'danger', ''],
164 default_value='', allow_none=True, sync=True, help="""Use a
163 default_value='', allow_none=True, sync=True, help="""Use a
165 predefined styling for the progess bar.""")
164 predefined styling for the progess bar.""")
166
165
167 class _FloatRange(_Float):
166 class _FloatRange(_Float):
168 value = Tuple(CFloat, CFloat, default_value=(0.0, 1.0), help="Tuple of (lower, upper) bounds", sync=True)
167 value = Tuple(CFloat, CFloat, default_value=(0.0, 1.0), help="Tuple of (lower, upper) bounds", sync=True)
169 lower = CFloat(0.0, help="Lower bound", sync=False)
168 lower = CFloat(0.0, help="Lower bound", sync=False)
170 upper = CFloat(1.0, help="Upper bound", sync=False)
169 upper = CFloat(1.0, help="Upper bound", sync=False)
171
170
172 def __init__(self, *pargs, **kwargs):
171 def __init__(self, *pargs, **kwargs):
173 value_given = 'value' in kwargs
172 value_given = 'value' in kwargs
174 lower_given = 'lower' in kwargs
173 lower_given = 'lower' in kwargs
175 upper_given = 'upper' in kwargs
174 upper_given = 'upper' in kwargs
176 if value_given and (lower_given or upper_given):
175 if value_given and (lower_given or upper_given):
177 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
176 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
178 if lower_given != upper_given:
177 if lower_given != upper_given:
179 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
178 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
180
179
181 DOMWidget.__init__(self, *pargs, **kwargs)
180 DOMWidget.__init__(self, *pargs, **kwargs)
182
181
183 # ensure the traits match, preferring whichever (if any) was given in kwargs
182 # ensure the traits match, preferring whichever (if any) was given in kwargs
184 if value_given:
183 if value_given:
185 self.lower, self.upper = self.value
184 self.lower, self.upper = self.value
186 else:
185 else:
187 self.value = (self.lower, self.upper)
186 self.value = (self.lower, self.upper)
188
187
189 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
188 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
190
189
191 def _validate(self, name, old, new):
190 def _validate(self, name, old, new):
192 if name == 'value':
191 if name == 'value':
193 self.lower, self.upper = min(new), max(new)
192 self.lower, self.upper = min(new), max(new)
194 elif name == 'lower':
193 elif name == 'lower':
195 self.value = (new, self.value[1])
194 self.value = (new, self.value[1])
196 elif name == 'upper':
195 elif name == 'upper':
197 self.value = (self.value[0], new)
196 self.value = (self.value[0], new)
198
197
199 class _BoundedFloatRange(_FloatRange):
198 class _BoundedFloatRange(_FloatRange):
200 step = CFloat(1.0, help="Minimum step that the value can take (ignored by some views)", sync=True)
199 step = CFloat(1.0, help="Minimum step that the value can take (ignored by some views)", sync=True)
201 max = CFloat(100.0, help="Max value", sync=True)
200 max = CFloat(100.0, help="Max value", sync=True)
202 min = CFloat(0.0, help="Min value", sync=True)
201 min = CFloat(0.0, help="Min value", sync=True)
203
202
204 def __init__(self, *pargs, **kwargs):
203 def __init__(self, *pargs, **kwargs):
205 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
204 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
206 _FloatRange.__init__(self, *pargs, **kwargs)
205 _FloatRange.__init__(self, *pargs, **kwargs)
207
206
208 # ensure a minimal amount of sanity
207 # ensure a minimal amount of sanity
209 if self.min > self.max:
208 if self.min > self.max:
210 raise ValueError("min must be <= max")
209 raise ValueError("min must be <= max")
211
210
212 if any_value_given:
211 if any_value_given:
213 # if a value was given, clamp it within (min, max)
212 # if a value was given, clamp it within (min, max)
214 self._validate("value", None, self.value)
213 self._validate("value", None, self.value)
215 else:
214 else:
216 # otherwise, set it to 25-75% to avoid the handles overlapping
215 # otherwise, set it to 25-75% to avoid the handles overlapping
217 self.value = (0.75*self.min + 0.25*self.max,
216 self.value = (0.75*self.min + 0.25*self.max,
218 0.25*self.min + 0.75*self.max)
217 0.25*self.min + 0.75*self.max)
219 # callback already set for 'value', 'lower', 'upper'
218 # callback already set for 'value', 'lower', 'upper'
220 self.on_trait_change(self._validate, ['min', 'max'])
219 self.on_trait_change(self._validate, ['min', 'max'])
221
220
222
221
223 def _validate(self, name, old, new):
222 def _validate(self, name, old, new):
224 if name == "min":
223 if name == "min":
225 if new > self.max:
224 if new > self.max:
226 raise ValueError("setting min > max")
225 raise ValueError("setting min > max")
227 self.min = new
226 self.min = new
228 elif name == "max":
227 elif name == "max":
229 if new < self.min:
228 if new < self.min:
230 raise ValueError("setting max < min")
229 raise ValueError("setting max < min")
231 self.max = new
230 self.max = new
232
231
233 low, high = self.value
232 low, high = self.value
234 if name == "value":
233 if name == "value":
235 low, high = min(new), max(new)
234 low, high = min(new), max(new)
236 elif name == "upper":
235 elif name == "upper":
237 if new < self.lower:
236 if new < self.lower:
238 raise ValueError("setting upper < lower")
237 raise ValueError("setting upper < lower")
239 high = new
238 high = new
240 elif name == "lower":
239 elif name == "lower":
241 if new > self.upper:
240 if new > self.upper:
242 raise ValueError("setting lower > upper")
241 raise ValueError("setting lower > upper")
243 low = new
242 low = new
244
243
245 low = max(self.min, min(low, self.max))
244 low = max(self.min, min(low, self.max))
246 high = min(self.max, max(high, self.min))
245 high = min(self.max, max(high, self.min))
247
246
248 # determine the order in which we should update the
247 # determine the order in which we should update the
249 # lower, upper traits to avoid a temporary inverted overlap
248 # lower, upper traits to avoid a temporary inverted overlap
250 lower_first = high < self.lower
249 lower_first = high < self.lower
251
250
252 self.value = (low, high)
251 self.value = (low, high)
253 if lower_first:
252 if lower_first:
254 self.lower = low
253 self.lower = low
255 self.upper = high
254 self.upper = high
256 else:
255 else:
257 self.upper = high
256 self.upper = high
258 self.lower = low
257 self.lower = low
259
258
260
259
261 @register('IPython.FloatRangeSlider')
260 @register('IPython.FloatRangeSlider')
262 class FloatRangeSlider(_BoundedFloatRange):
261 class FloatRangeSlider(_BoundedFloatRange):
263 """ Slider/trackbar for displaying a floating value range (within the specified range of values).
262 """ Slider/trackbar for displaying a floating value range (within the specified range of values).
264
263
265 Parameters
264 Parameters
266 ----------
265 ----------
267 value : float tuple
266 value : float tuple
268 range of the slider displayed
267 range of the slider displayed
269 min : float
268 min : float
270 minimal position of the slider
269 minimal position of the slider
271 max : float
270 max : float
272 maximal position of the slider
271 maximal position of the slider
273 step : float
272 step : float
274 step of the trackbar
273 step of the trackbar
275 description : str
274 description : str
276 name of the slider
275 name of the slider
277 orientation : {'vertical', 'horizontal}, optional
276 orientation : {'vertical', 'horizontal}, optional
278 default is horizontal
277 default is horizontal
279 readout : {True, False}, optional
278 readout : {True, False}, optional
280 default is True, display the current value of the slider next to it
279 default is True, display the current value of the slider next to it
281 slider_color : str Unicode color code (eg. '#C13535'), optional
280 slider_color : str Unicode color code (eg. '#C13535'), optional
282 color of the slider
281 color of the slider
283 color : str Unicode color code (eg. '#C13535'), optional
282 color : str Unicode color code (eg. '#C13535'), optional
284 color of the value displayed (if readout == True)
283 color of the value displayed (if readout == True)
285 """
284 """
286 _view_name = Unicode('FloatSliderView', sync=True)
285 _view_name = Unicode('FloatSliderView', sync=True)
287 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
286 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
288 default_value='horizontal', help="Vertical or horizontal.", sync=True)
287 default_value='horizontal', help="Vertical or horizontal.", sync=True)
289 _range = Bool(True, help="Display a range selector", sync=True)
288 _range = Bool(True, help="Display a range selector", sync=True)
290 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
289 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
291 slider_color = Color(None, allow_none=True, sync=True)
290 slider_color = Color(None, allow_none=True, sync=True)
292
291
293 # Remove in IPython 4.0
292 # Remove in IPython 4.0
294 FloatTextWidget = DeprecatedClass(FloatText, 'FloatTextWidget')
293 FloatTextWidget = DeprecatedClass(FloatText, 'FloatTextWidget')
295 BoundedFloatTextWidget = DeprecatedClass(BoundedFloatText, 'BoundedFloatTextWidget')
294 BoundedFloatTextWidget = DeprecatedClass(BoundedFloatText, 'BoundedFloatTextWidget')
296 FloatSliderWidget = DeprecatedClass(FloatSlider, 'FloatSliderWidget')
295 FloatSliderWidget = DeprecatedClass(FloatSlider, 'FloatSliderWidget')
297 FloatProgressWidget = DeprecatedClass(FloatProgress, 'FloatProgressWidget')
296 FloatProgressWidget = DeprecatedClass(FloatProgress, 'FloatProgressWidget')
@@ -1,208 +1,207 b''
1 """Int class.
1 """Int class.
2
2
3 Represents an unbounded int using a widget.
3 Represents an unbounded int using a widget.
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (c) 2013, the IPython Development Team.
6 # Copyright (c) 2013, the IPython Development Team.
7 #
7 #
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9 #
9 #
10 # The full license is in the file COPYING.txt, distributed with this software.
10 # The full license is in the file COPYING.txt, distributed with this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from .widget import DOMWidget, register
16 from .widget import DOMWidget, register
17 from .trait_types import Color
17 from .trait_types import Color
18 from IPython.utils.traitlets import Unicode, CInt, Bool, CaselessStrEnum, Tuple
18 from IPython.utils.traitlets import (Unicode, CInt, Bool, CaselessStrEnum,
19 Tuple, TraitError)
19 from IPython.utils.warn import DeprecatedClass
20 from IPython.utils.warn import DeprecatedClass
20
21
21 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
22 # Classes
23 # Classes
23 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
24 class _Int(DOMWidget):
25 class _Int(DOMWidget):
25 """Base class used to create widgets that represent an int."""
26 """Base class used to create widgets that represent an int."""
26 value = CInt(0, help="Int value", sync=True)
27 value = CInt(0, help="Int value", sync=True)
27 disabled = Bool(False, help="Enable or disable user changes", sync=True)
28 disabled = Bool(False, help="Enable or disable user changes", sync=True)
28 description = Unicode(help="Description of the value this widget represents", sync=True)
29 description = Unicode(help="Description of the value this widget represents", sync=True)
29
30
30 def __init__(self, value=None, **kwargs):
31 def __init__(self, value=None, **kwargs):
31 if value is not None:
32 if value is not None:
32 kwargs['value'] = value
33 kwargs['value'] = value
33 super(_Int, self).__init__(**kwargs)
34 super(_Int, self).__init__(**kwargs)
34
35
36
35 class _BoundedInt(_Int):
37 class _BoundedInt(_Int):
36 """Base class used to create widgets that represent a int that is bounded
38 """Base class used to create widgets that represent a int that is bounded
37 by a minium and maximum."""
39 by a minium and maximum."""
38 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
40 step = CInt(1, help="Minimum step to increment the value (ignored by some views)", sync=True)
39 max = CInt(100, help="Max value", sync=True)
41 max = CInt(100, help="Max value", sync=True)
40 min = CInt(0, help="Min value", sync=True)
42 min = CInt(0, help="Min value", sync=True)
41
43
42 def __init__(self, *pargs, **kwargs):
44 def __init__(self, *pargs, **kwargs):
43 """Constructor"""
45 """Constructor"""
44 super(_BoundedInt, self).__init__(*pargs, **kwargs)
46 super(_BoundedInt, self).__init__(*pargs, **kwargs)
45 self._handle_value_changed('value', None, self.value)
47
46 self._handle_max_changed('max', None, self.max)
48 def _value_validate(self, value, trait):
47 self._handle_min_changed('min', None, self.min)
49 """Cap and floor value"""
48 self.on_trait_change(self._handle_value_changed, 'value')
50 if self.min > value or self.max < value:
49 self.on_trait_change(self._handle_max_changed, 'max')
51 value = min(max(value, self.min), self.max)
50 self.on_trait_change(self._handle_min_changed, 'min')
52 return value
51
53
52 def _handle_value_changed(self, name, old, new):
54 def _min_validate(self, min, trait):
53 """Validate value."""
55 """Enforce min <= value <= max"""
54 if self.min > new or new > self.max:
56 if min > self.max:
55 self.value = min(max(new, self.min), self.max)
57 raise TraitError("Setting min > max")
56
58 if min > self.value:
57 def _handle_max_changed(self, name, old, new):
59 self.value = min
58 """Make sure the min is always <= the max."""
60 return min
59 if new < self.min:
61
60 raise ValueError("setting max < min")
62 def _max_validate(self, max, trait):
61 if new < self.value:
63 """Enforce min <= value <= max"""
62 self.value = new
64 if max < self.min:
63
65 raise TraitError("setting max < min")
64 def _handle_min_changed(self, name, old, new):
66 if max < self.value:
65 """Make sure the max is always >= the min."""
67 self.value = max
66 if new > self.max:
68 return max
67 raise ValueError("setting min > max")
68 if new > self.value:
69 self.value = new
70
69
71 @register('IPython.IntText')
70 @register('IPython.IntText')
72 class IntText(_Int):
71 class IntText(_Int):
73 """Textbox widget that represents a int."""
72 """Textbox widget that represents a int."""
74 _view_name = Unicode('IntTextView', sync=True)
73 _view_name = Unicode('IntTextView', sync=True)
75
74
76
75
77 @register('IPython.BoundedIntText')
76 @register('IPython.BoundedIntText')
78 class BoundedIntText(_BoundedInt):
77 class BoundedIntText(_BoundedInt):
79 """Textbox widget that represents a int bounded by a minimum and maximum value."""
78 """Textbox widget that represents a int bounded by a minimum and maximum value."""
80 _view_name = Unicode('IntTextView', sync=True)
79 _view_name = Unicode('IntTextView', sync=True)
81
80
82
81
83 @register('IPython.IntSlider')
82 @register('IPython.IntSlider')
84 class IntSlider(_BoundedInt):
83 class IntSlider(_BoundedInt):
85 """Slider widget that represents a int bounded by a minimum and maximum value."""
84 """Slider widget that represents a int bounded by a minimum and maximum value."""
86 _view_name = Unicode('IntSliderView', sync=True)
85 _view_name = Unicode('IntSliderView', sync=True)
87 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
86 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
88 default_value='horizontal', help="Vertical or horizontal.", sync=True)
87 default_value='horizontal', help="Vertical or horizontal.", sync=True)
89 _range = Bool(False, help="Display a range selector", sync=True)
88 _range = Bool(False, help="Display a range selector", sync=True)
90 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
89 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
91 slider_color = Color(None, allow_none=True, sync=True)
90 slider_color = Color(None, allow_none=True, sync=True)
92
91
93
92
94 @register('IPython.IntProgress')
93 @register('IPython.IntProgress')
95 class IntProgress(_BoundedInt):
94 class IntProgress(_BoundedInt):
96 """Progress bar that represents a int bounded by a minimum and maximum value."""
95 """Progress bar that represents a int bounded by a minimum and maximum value."""
97 _view_name = Unicode('ProgressView', sync=True)
96 _view_name = Unicode('ProgressView', sync=True)
98
97
99 bar_style = CaselessStrEnum(
98 bar_style = CaselessStrEnum(
100 values=['success', 'info', 'warning', 'danger', ''],
99 values=['success', 'info', 'warning', 'danger', ''],
101 default_value='', allow_none=True, sync=True, help="""Use a
100 default_value='', allow_none=True, sync=True, help="""Use a
102 predefined styling for the progess bar.""")
101 predefined styling for the progess bar.""")
103
102
104 class _IntRange(_Int):
103 class _IntRange(_Int):
105 value = Tuple(CInt, CInt, default_value=(0, 1), help="Tuple of (lower, upper) bounds", sync=True)
104 value = Tuple(CInt, CInt, default_value=(0, 1), help="Tuple of (lower, upper) bounds", sync=True)
106 lower = CInt(0, help="Lower bound", sync=False)
105 lower = CInt(0, help="Lower bound", sync=False)
107 upper = CInt(1, help="Upper bound", sync=False)
106 upper = CInt(1, help="Upper bound", sync=False)
108
107
109 def __init__(self, *pargs, **kwargs):
108 def __init__(self, *pargs, **kwargs):
110 value_given = 'value' in kwargs
109 value_given = 'value' in kwargs
111 lower_given = 'lower' in kwargs
110 lower_given = 'lower' in kwargs
112 upper_given = 'upper' in kwargs
111 upper_given = 'upper' in kwargs
113 if value_given and (lower_given or upper_given):
112 if value_given and (lower_given or upper_given):
114 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
113 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
115 if lower_given != upper_given:
114 if lower_given != upper_given:
116 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
115 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
117
116
118 super(_IntRange, self).__init__(*pargs, **kwargs)
117 super(_IntRange, self).__init__(*pargs, **kwargs)
119
118
120 # ensure the traits match, preferring whichever (if any) was given in kwargs
119 # ensure the traits match, preferring whichever (if any) was given in kwargs
121 if value_given:
120 if value_given:
122 self.lower, self.upper = self.value
121 self.lower, self.upper = self.value
123 else:
122 else:
124 self.value = (self.lower, self.upper)
123 self.value = (self.lower, self.upper)
125
124
126 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
125 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
127
126
128 def _validate(self, name, old, new):
127 def _validate(self, name, old, new):
129 if name == 'value':
128 if name == 'value':
130 self.lower, self.upper = min(new), max(new)
129 self.lower, self.upper = min(new), max(new)
131 elif name == 'lower':
130 elif name == 'lower':
132 self.value = (new, self.value[1])
131 self.value = (new, self.value[1])
133 elif name == 'upper':
132 elif name == 'upper':
134 self.value = (self.value[0], new)
133 self.value = (self.value[0], new)
135
134
136 class _BoundedIntRange(_IntRange):
135 class _BoundedIntRange(_IntRange):
137 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
136 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
138 max = CInt(100, help="Max value", sync=True)
137 max = CInt(100, help="Max value", sync=True)
139 min = CInt(0, help="Min value", sync=True)
138 min = CInt(0, help="Min value", sync=True)
140
139
141 def __init__(self, *pargs, **kwargs):
140 def __init__(self, *pargs, **kwargs):
142 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
141 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
143 _IntRange.__init__(self, *pargs, **kwargs)
142 _IntRange.__init__(self, *pargs, **kwargs)
144
143
145 # ensure a minimal amount of sanity
144 # ensure a minimal amount of sanity
146 if self.min > self.max:
145 if self.min > self.max:
147 raise ValueError("min must be <= max")
146 raise ValueError("min must be <= max")
148
147
149 if any_value_given:
148 if any_value_given:
150 # if a value was given, clamp it within (min, max)
149 # if a value was given, clamp it within (min, max)
151 self._validate("value", None, self.value)
150 self._validate("value", None, self.value)
152 else:
151 else:
153 # otherwise, set it to 25-75% to avoid the handles overlapping
152 # otherwise, set it to 25-75% to avoid the handles overlapping
154 self.value = (0.75*self.min + 0.25*self.max,
153 self.value = (0.75*self.min + 0.25*self.max,
155 0.25*self.min + 0.75*self.max)
154 0.25*self.min + 0.75*self.max)
156 # callback already set for 'value', 'lower', 'upper'
155 # callback already set for 'value', 'lower', 'upper'
157 self.on_trait_change(self._validate, ['min', 'max'])
156 self.on_trait_change(self._validate, ['min', 'max'])
158
157
159 def _validate(self, name, old, new):
158 def _validate(self, name, old, new):
160 if name == "min":
159 if name == "min":
161 if new > self.max:
160 if new > self.max:
162 raise ValueError("setting min > max")
161 raise ValueError("setting min > max")
163 elif name == "max":
162 elif name == "max":
164 if new < self.min:
163 if new < self.min:
165 raise ValueError("setting max < min")
164 raise ValueError("setting max < min")
166
165
167 low, high = self.value
166 low, high = self.value
168 if name == "value":
167 if name == "value":
169 low, high = min(new), max(new)
168 low, high = min(new), max(new)
170 elif name == "upper":
169 elif name == "upper":
171 if new < self.lower:
170 if new < self.lower:
172 raise ValueError("setting upper < lower")
171 raise ValueError("setting upper < lower")
173 high = new
172 high = new
174 elif name == "lower":
173 elif name == "lower":
175 if new > self.upper:
174 if new > self.upper:
176 raise ValueError("setting lower > upper")
175 raise ValueError("setting lower > upper")
177 low = new
176 low = new
178
177
179 low = max(self.min, min(low, self.max))
178 low = max(self.min, min(low, self.max))
180 high = min(self.max, max(high, self.min))
179 high = min(self.max, max(high, self.min))
181
180
182 # determine the order in which we should update the
181 # determine the order in which we should update the
183 # lower, upper traits to avoid a temporary inverted overlap
182 # lower, upper traits to avoid a temporary inverted overlap
184 lower_first = high < self.lower
183 lower_first = high < self.lower
185
184
186 self.value = (low, high)
185 self.value = (low, high)
187 if lower_first:
186 if lower_first:
188 self.lower = low
187 self.lower = low
189 self.upper = high
188 self.upper = high
190 else:
189 else:
191 self.upper = high
190 self.upper = high
192 self.lower = low
191 self.lower = low
193
192
194 @register('IPython.IntRangeSlider')
193 @register('IPython.IntRangeSlider')
195 class IntRangeSlider(_BoundedIntRange):
194 class IntRangeSlider(_BoundedIntRange):
196 """Slider widget that represents a pair of ints between a minimum and maximum value."""
195 """Slider widget that represents a pair of ints between a minimum and maximum value."""
197 _view_name = Unicode('IntSliderView', sync=True)
196 _view_name = Unicode('IntSliderView', sync=True)
198 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
197 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
199 default_value='horizontal', help="Vertical or horizontal.", sync=True)
198 default_value='horizontal', help="Vertical or horizontal.", sync=True)
200 _range = Bool(True, help="Display a range selector", sync=True)
199 _range = Bool(True, help="Display a range selector", sync=True)
201 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
200 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
202 slider_color = Color(None, allow_none=True, sync=True)
201 slider_color = Color(None, allow_none=True, sync=True)
203
202
204 # Remove in IPython 4.0
203 # Remove in IPython 4.0
205 IntTextWidget = DeprecatedClass(IntText, 'IntTextWidget')
204 IntTextWidget = DeprecatedClass(IntText, 'IntTextWidget')
206 BoundedIntTextWidget = DeprecatedClass(BoundedIntText, 'BoundedIntTextWidget')
205 BoundedIntTextWidget = DeprecatedClass(BoundedIntText, 'BoundedIntTextWidget')
207 IntSliderWidget = DeprecatedClass(IntSlider, 'IntSliderWidget')
206 IntSliderWidget = DeprecatedClass(IntSlider, 'IntSliderWidget')
208 IntProgressWidget = DeprecatedClass(IntProgress, 'IntProgressWidget')
207 IntProgressWidget = DeprecatedClass(IntProgress, 'IntProgressWidget')
@@ -1,1617 +1,1634 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.traitlets."""
2 """Tests for IPython.utils.traitlets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6 #
6 #
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
8 # also under the terms of the Modified BSD License.
8 # also under the terms of the Modified BSD License.
9
9
10 import pickle
10 import pickle
11 import re
11 import re
12 import sys
12 import sys
13 from unittest import TestCase
13 from unittest import TestCase
14
14
15 import nose.tools as nt
15 import nose.tools as nt
16 from nose import SkipTest
16 from nose import SkipTest
17
17
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
19 HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
21 Union, Undefined, Type, This, Instance, TCPAddress, List, Tuple,
21 Union, Undefined, Type, This, Instance, TCPAddress, List, Tuple,
22 ObjectName, DottedObjectName, CRegExp, link, directional_link,
22 ObjectName, DottedObjectName, CRegExp, link, directional_link,
23 EventfulList, EventfulDict, ForwardDeclaredType, ForwardDeclaredInstance,
23 EventfulList, EventfulDict, ForwardDeclaredType, ForwardDeclaredInstance,
24 )
24 )
25 from IPython.utils import py3compat
25 from IPython.utils import py3compat
26 from IPython.testing.decorators import skipif
26 from IPython.testing.decorators import skipif
27
27
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29 # Helper classes for testing
29 # Helper classes for testing
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31
31
32
32
33 class HasTraitsStub(HasTraits):
33 class HasTraitsStub(HasTraits):
34
34
35 def _notify_trait(self, name, old, new):
35 def _notify_trait(self, name, old, new):
36 self._notify_name = name
36 self._notify_name = name
37 self._notify_old = old
37 self._notify_old = old
38 self._notify_new = new
38 self._notify_new = new
39
39
40
40
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42 # Test classes
42 # Test classes
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45
45
46 class TestTraitType(TestCase):
46 class TestTraitType(TestCase):
47
47
48 def test_get_undefined(self):
48 def test_get_undefined(self):
49 class A(HasTraits):
49 class A(HasTraits):
50 a = TraitType
50 a = TraitType
51 a = A()
51 a = A()
52 self.assertEqual(a.a, Undefined)
52 self.assertEqual(a.a, Undefined)
53
53
54 def test_set(self):
54 def test_set(self):
55 class A(HasTraitsStub):
55 class A(HasTraitsStub):
56 a = TraitType
56 a = TraitType
57
57
58 a = A()
58 a = A()
59 a.a = 10
59 a.a = 10
60 self.assertEqual(a.a, 10)
60 self.assertEqual(a.a, 10)
61 self.assertEqual(a._notify_name, 'a')
61 self.assertEqual(a._notify_name, 'a')
62 self.assertEqual(a._notify_old, Undefined)
62 self.assertEqual(a._notify_old, Undefined)
63 self.assertEqual(a._notify_new, 10)
63 self.assertEqual(a._notify_new, 10)
64
64
65 def test_validate(self):
65 def test_validate(self):
66 class MyTT(TraitType):
66 class MyTT(TraitType):
67 def validate(self, inst, value):
67 def validate(self, inst, value):
68 return -1
68 return -1
69 class A(HasTraitsStub):
69 class A(HasTraitsStub):
70 tt = MyTT
70 tt = MyTT
71
71
72 a = A()
72 a = A()
73 a.tt = 10
73 a.tt = 10
74 self.assertEqual(a.tt, -1)
74 self.assertEqual(a.tt, -1)
75
75
76 def test_default_validate(self):
76 def test_default_validate(self):
77 class MyIntTT(TraitType):
77 class MyIntTT(TraitType):
78 def validate(self, obj, value):
78 def validate(self, obj, value):
79 if isinstance(value, int):
79 if isinstance(value, int):
80 return value
80 return value
81 self.error(obj, value)
81 self.error(obj, value)
82 class A(HasTraits):
82 class A(HasTraits):
83 tt = MyIntTT(10)
83 tt = MyIntTT(10)
84 a = A()
84 a = A()
85 self.assertEqual(a.tt, 10)
85 self.assertEqual(a.tt, 10)
86
86
87 # Defaults are validated when the HasTraits is instantiated
87 # Defaults are validated when the HasTraits is instantiated
88 class B(HasTraits):
88 class B(HasTraits):
89 tt = MyIntTT('bad default')
89 tt = MyIntTT('bad default')
90 self.assertRaises(TraitError, B)
90 self.assertRaises(TraitError, B)
91
91
92 def test_info(self):
92 def test_info(self):
93 class A(HasTraits):
93 class A(HasTraits):
94 tt = TraitType
94 tt = TraitType
95 a = A()
95 a = A()
96 self.assertEqual(A.tt.info(), 'any value')
96 self.assertEqual(A.tt.info(), 'any value')
97
97
98 def test_error(self):
98 def test_error(self):
99 class A(HasTraits):
99 class A(HasTraits):
100 tt = TraitType
100 tt = TraitType
101 a = A()
101 a = A()
102 self.assertRaises(TraitError, A.tt.error, a, 10)
102 self.assertRaises(TraitError, A.tt.error, a, 10)
103
103
104 def test_dynamic_initializer(self):
104 def test_dynamic_initializer(self):
105 class A(HasTraits):
105 class A(HasTraits):
106 x = Int(10)
106 x = Int(10)
107 def _x_default(self):
107 def _x_default(self):
108 return 11
108 return 11
109 class B(A):
109 class B(A):
110 x = Int(20)
110 x = Int(20)
111 class C(A):
111 class C(A):
112 def _x_default(self):
112 def _x_default(self):
113 return 21
113 return 21
114
114
115 a = A()
115 a = A()
116 self.assertEqual(a._trait_values, {})
116 self.assertEqual(a._trait_values, {})
117 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
117 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
118 self.assertEqual(a.x, 11)
118 self.assertEqual(a.x, 11)
119 self.assertEqual(a._trait_values, {'x': 11})
119 self.assertEqual(a._trait_values, {'x': 11})
120 b = B()
120 b = B()
121 self.assertEqual(b._trait_values, {'x': 20})
121 self.assertEqual(b._trait_values, {'x': 20})
122 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
122 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
123 self.assertEqual(b.x, 20)
123 self.assertEqual(b.x, 20)
124 c = C()
124 c = C()
125 self.assertEqual(c._trait_values, {})
125 self.assertEqual(c._trait_values, {})
126 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
126 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
127 self.assertEqual(c.x, 21)
127 self.assertEqual(c.x, 21)
128 self.assertEqual(c._trait_values, {'x': 21})
128 self.assertEqual(c._trait_values, {'x': 21})
129 # Ensure that the base class remains unmolested when the _default
129 # Ensure that the base class remains unmolested when the _default
130 # initializer gets overridden in a subclass.
130 # initializer gets overridden in a subclass.
131 a = A()
131 a = A()
132 c = C()
132 c = C()
133 self.assertEqual(a._trait_values, {})
133 self.assertEqual(a._trait_values, {})
134 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
134 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
135 self.assertEqual(a.x, 11)
135 self.assertEqual(a.x, 11)
136 self.assertEqual(a._trait_values, {'x': 11})
136 self.assertEqual(a._trait_values, {'x': 11})
137
137
138
138
139
139
140 class TestHasTraitsMeta(TestCase):
140 class TestHasTraitsMeta(TestCase):
141
141
142 def test_metaclass(self):
142 def test_metaclass(self):
143 self.assertEqual(type(HasTraits), MetaHasTraits)
143 self.assertEqual(type(HasTraits), MetaHasTraits)
144
144
145 class A(HasTraits):
145 class A(HasTraits):
146 a = Int
146 a = Int
147
147
148 a = A()
148 a = A()
149 self.assertEqual(type(a.__class__), MetaHasTraits)
149 self.assertEqual(type(a.__class__), MetaHasTraits)
150 self.assertEqual(a.a,0)
150 self.assertEqual(a.a,0)
151 a.a = 10
151 a.a = 10
152 self.assertEqual(a.a,10)
152 self.assertEqual(a.a,10)
153
153
154 class B(HasTraits):
154 class B(HasTraits):
155 b = Int()
155 b = Int()
156
156
157 b = B()
157 b = B()
158 self.assertEqual(b.b,0)
158 self.assertEqual(b.b,0)
159 b.b = 10
159 b.b = 10
160 self.assertEqual(b.b,10)
160 self.assertEqual(b.b,10)
161
161
162 class C(HasTraits):
162 class C(HasTraits):
163 c = Int(30)
163 c = Int(30)
164
164
165 c = C()
165 c = C()
166 self.assertEqual(c.c,30)
166 self.assertEqual(c.c,30)
167 c.c = 10
167 c.c = 10
168 self.assertEqual(c.c,10)
168 self.assertEqual(c.c,10)
169
169
170 def test_this_class(self):
170 def test_this_class(self):
171 class A(HasTraits):
171 class A(HasTraits):
172 t = This()
172 t = This()
173 tt = This()
173 tt = This()
174 class B(A):
174 class B(A):
175 tt = This()
175 tt = This()
176 ttt = This()
176 ttt = This()
177 self.assertEqual(A.t.this_class, A)
177 self.assertEqual(A.t.this_class, A)
178 self.assertEqual(B.t.this_class, A)
178 self.assertEqual(B.t.this_class, A)
179 self.assertEqual(B.tt.this_class, B)
179 self.assertEqual(B.tt.this_class, B)
180 self.assertEqual(B.ttt.this_class, B)
180 self.assertEqual(B.ttt.this_class, B)
181
181
182 class TestHasTraitsNotify(TestCase):
182 class TestHasTraitsNotify(TestCase):
183
183
184 def setUp(self):
184 def setUp(self):
185 self._notify1 = []
185 self._notify1 = []
186 self._notify2 = []
186 self._notify2 = []
187
187
188 def notify1(self, name, old, new):
188 def notify1(self, name, old, new):
189 self._notify1.append((name, old, new))
189 self._notify1.append((name, old, new))
190
190
191 def notify2(self, name, old, new):
191 def notify2(self, name, old, new):
192 self._notify2.append((name, old, new))
192 self._notify2.append((name, old, new))
193
193
194 def test_notify_all(self):
194 def test_notify_all(self):
195
195
196 class A(HasTraits):
196 class A(HasTraits):
197 a = Int
197 a = Int
198 b = Float
198 b = Float
199
199
200 a = A()
200 a = A()
201 a.on_trait_change(self.notify1)
201 a.on_trait_change(self.notify1)
202 a.a = 0
202 a.a = 0
203 self.assertEqual(len(self._notify1),0)
203 self.assertEqual(len(self._notify1),0)
204 a.b = 0.0
204 a.b = 0.0
205 self.assertEqual(len(self._notify1),0)
205 self.assertEqual(len(self._notify1),0)
206 a.a = 10
206 a.a = 10
207 self.assertTrue(('a',0,10) in self._notify1)
207 self.assertTrue(('a',0,10) in self._notify1)
208 a.b = 10.0
208 a.b = 10.0
209 self.assertTrue(('b',0.0,10.0) in self._notify1)
209 self.assertTrue(('b',0.0,10.0) in self._notify1)
210 self.assertRaises(TraitError,setattr,a,'a','bad string')
210 self.assertRaises(TraitError,setattr,a,'a','bad string')
211 self.assertRaises(TraitError,setattr,a,'b','bad string')
211 self.assertRaises(TraitError,setattr,a,'b','bad string')
212 self._notify1 = []
212 self._notify1 = []
213 a.on_trait_change(self.notify1,remove=True)
213 a.on_trait_change(self.notify1,remove=True)
214 a.a = 20
214 a.a = 20
215 a.b = 20.0
215 a.b = 20.0
216 self.assertEqual(len(self._notify1),0)
216 self.assertEqual(len(self._notify1),0)
217
217
218 def test_notify_one(self):
218 def test_notify_one(self):
219
219
220 class A(HasTraits):
220 class A(HasTraits):
221 a = Int
221 a = Int
222 b = Float
222 b = Float
223
223
224 a = A()
224 a = A()
225 a.on_trait_change(self.notify1, 'a')
225 a.on_trait_change(self.notify1, 'a')
226 a.a = 0
226 a.a = 0
227 self.assertEqual(len(self._notify1),0)
227 self.assertEqual(len(self._notify1),0)
228 a.a = 10
228 a.a = 10
229 self.assertTrue(('a',0,10) in self._notify1)
229 self.assertTrue(('a',0,10) in self._notify1)
230 self.assertRaises(TraitError,setattr,a,'a','bad string')
230 self.assertRaises(TraitError,setattr,a,'a','bad string')
231
231
232 def test_subclass(self):
232 def test_subclass(self):
233
233
234 class A(HasTraits):
234 class A(HasTraits):
235 a = Int
235 a = Int
236
236
237 class B(A):
237 class B(A):
238 b = Float
238 b = Float
239
239
240 b = B()
240 b = B()
241 self.assertEqual(b.a,0)
241 self.assertEqual(b.a,0)
242 self.assertEqual(b.b,0.0)
242 self.assertEqual(b.b,0.0)
243 b.a = 100
243 b.a = 100
244 b.b = 100.0
244 b.b = 100.0
245 self.assertEqual(b.a,100)
245 self.assertEqual(b.a,100)
246 self.assertEqual(b.b,100.0)
246 self.assertEqual(b.b,100.0)
247
247
248 def test_notify_subclass(self):
248 def test_notify_subclass(self):
249
249
250 class A(HasTraits):
250 class A(HasTraits):
251 a = Int
251 a = Int
252
252
253 class B(A):
253 class B(A):
254 b = Float
254 b = Float
255
255
256 b = B()
256 b = B()
257 b.on_trait_change(self.notify1, 'a')
257 b.on_trait_change(self.notify1, 'a')
258 b.on_trait_change(self.notify2, 'b')
258 b.on_trait_change(self.notify2, 'b')
259 b.a = 0
259 b.a = 0
260 b.b = 0.0
260 b.b = 0.0
261 self.assertEqual(len(self._notify1),0)
261 self.assertEqual(len(self._notify1),0)
262 self.assertEqual(len(self._notify2),0)
262 self.assertEqual(len(self._notify2),0)
263 b.a = 10
263 b.a = 10
264 b.b = 10.0
264 b.b = 10.0
265 self.assertTrue(('a',0,10) in self._notify1)
265 self.assertTrue(('a',0,10) in self._notify1)
266 self.assertTrue(('b',0.0,10.0) in self._notify2)
266 self.assertTrue(('b',0.0,10.0) in self._notify2)
267
267
268 def test_static_notify(self):
268 def test_static_notify(self):
269
269
270 class A(HasTraits):
270 class A(HasTraits):
271 a = Int
271 a = Int
272 _notify1 = []
272 _notify1 = []
273 def _a_changed(self, name, old, new):
273 def _a_changed(self, name, old, new):
274 self._notify1.append((name, old, new))
274 self._notify1.append((name, old, new))
275
275
276 a = A()
276 a = A()
277 a.a = 0
277 a.a = 0
278 # This is broken!!!
278 # This is broken!!!
279 self.assertEqual(len(a._notify1),0)
279 self.assertEqual(len(a._notify1),0)
280 a.a = 10
280 a.a = 10
281 self.assertTrue(('a',0,10) in a._notify1)
281 self.assertTrue(('a',0,10) in a._notify1)
282
282
283 class B(A):
283 class B(A):
284 b = Float
284 b = Float
285 _notify2 = []
285 _notify2 = []
286 def _b_changed(self, name, old, new):
286 def _b_changed(self, name, old, new):
287 self._notify2.append((name, old, new))
287 self._notify2.append((name, old, new))
288
288
289 b = B()
289 b = B()
290 b.a = 10
290 b.a = 10
291 b.b = 10.0
291 b.b = 10.0
292 self.assertTrue(('a',0,10) in b._notify1)
292 self.assertTrue(('a',0,10) in b._notify1)
293 self.assertTrue(('b',0.0,10.0) in b._notify2)
293 self.assertTrue(('b',0.0,10.0) in b._notify2)
294
294
295 def test_notify_args(self):
295 def test_notify_args(self):
296
296
297 def callback0():
297 def callback0():
298 self.cb = ()
298 self.cb = ()
299 def callback1(name):
299 def callback1(name):
300 self.cb = (name,)
300 self.cb = (name,)
301 def callback2(name, new):
301 def callback2(name, new):
302 self.cb = (name, new)
302 self.cb = (name, new)
303 def callback3(name, old, new):
303 def callback3(name, old, new):
304 self.cb = (name, old, new)
304 self.cb = (name, old, new)
305
305
306 class A(HasTraits):
306 class A(HasTraits):
307 a = Int
307 a = Int
308
308
309 a = A()
309 a = A()
310 a.on_trait_change(callback0, 'a')
310 a.on_trait_change(callback0, 'a')
311 a.a = 10
311 a.a = 10
312 self.assertEqual(self.cb,())
312 self.assertEqual(self.cb,())
313 a.on_trait_change(callback0, 'a', remove=True)
313 a.on_trait_change(callback0, 'a', remove=True)
314
314
315 a.on_trait_change(callback1, 'a')
315 a.on_trait_change(callback1, 'a')
316 a.a = 100
316 a.a = 100
317 self.assertEqual(self.cb,('a',))
317 self.assertEqual(self.cb,('a',))
318 a.on_trait_change(callback1, 'a', remove=True)
318 a.on_trait_change(callback1, 'a', remove=True)
319
319
320 a.on_trait_change(callback2, 'a')
320 a.on_trait_change(callback2, 'a')
321 a.a = 1000
321 a.a = 1000
322 self.assertEqual(self.cb,('a',1000))
322 self.assertEqual(self.cb,('a',1000))
323 a.on_trait_change(callback2, 'a', remove=True)
323 a.on_trait_change(callback2, 'a', remove=True)
324
324
325 a.on_trait_change(callback3, 'a')
325 a.on_trait_change(callback3, 'a')
326 a.a = 10000
326 a.a = 10000
327 self.assertEqual(self.cb,('a',1000,10000))
327 self.assertEqual(self.cb,('a',1000,10000))
328 a.on_trait_change(callback3, 'a', remove=True)
328 a.on_trait_change(callback3, 'a', remove=True)
329
329
330 self.assertEqual(len(a._trait_notifiers['a']),0)
330 self.assertEqual(len(a._trait_notifiers['a']),0)
331
331
332 def test_notify_only_once(self):
332 def test_notify_only_once(self):
333
333
334 class A(HasTraits):
334 class A(HasTraits):
335 listen_to = ['a']
335 listen_to = ['a']
336
336
337 a = Int(0)
337 a = Int(0)
338 b = 0
338 b = 0
339
339
340 def __init__(self, **kwargs):
340 def __init__(self, **kwargs):
341 super(A, self).__init__(**kwargs)
341 super(A, self).__init__(**kwargs)
342 self.on_trait_change(self.listener1, ['a'])
342 self.on_trait_change(self.listener1, ['a'])
343
343
344 def listener1(self, name, old, new):
344 def listener1(self, name, old, new):
345 self.b += 1
345 self.b += 1
346
346
347 class B(A):
347 class B(A):
348
348
349 c = 0
349 c = 0
350 d = 0
350 d = 0
351
351
352 def __init__(self, **kwargs):
352 def __init__(self, **kwargs):
353 super(B, self).__init__(**kwargs)
353 super(B, self).__init__(**kwargs)
354 self.on_trait_change(self.listener2)
354 self.on_trait_change(self.listener2)
355
355
356 def listener2(self, name, old, new):
356 def listener2(self, name, old, new):
357 self.c += 1
357 self.c += 1
358
358
359 def _a_changed(self, name, old, new):
359 def _a_changed(self, name, old, new):
360 self.d += 1
360 self.d += 1
361
361
362 b = B()
362 b = B()
363 b.a += 1
363 b.a += 1
364 self.assertEqual(b.b, b.c)
364 self.assertEqual(b.b, b.c)
365 self.assertEqual(b.b, b.d)
365 self.assertEqual(b.b, b.d)
366 b.a += 1
366 b.a += 1
367 self.assertEqual(b.b, b.c)
367 self.assertEqual(b.b, b.c)
368 self.assertEqual(b.b, b.d)
368 self.assertEqual(b.b, b.d)
369
369
370
370
371 class TestHasTraits(TestCase):
371 class TestHasTraits(TestCase):
372
372
373 def test_trait_names(self):
373 def test_trait_names(self):
374 class A(HasTraits):
374 class A(HasTraits):
375 i = Int
375 i = Int
376 f = Float
376 f = Float
377 a = A()
377 a = A()
378 self.assertEqual(sorted(a.trait_names()),['f','i'])
378 self.assertEqual(sorted(a.trait_names()),['f','i'])
379 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
379 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
380
380
381 def test_trait_metadata(self):
381 def test_trait_metadata(self):
382 class A(HasTraits):
382 class A(HasTraits):
383 i = Int(config_key='MY_VALUE')
383 i = Int(config_key='MY_VALUE')
384 a = A()
384 a = A()
385 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
385 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
386
386
387 def test_trait_metadata_default(self):
387 def test_trait_metadata_default(self):
388 class A(HasTraits):
388 class A(HasTraits):
389 i = Int()
389 i = Int()
390 a = A()
390 a = A()
391 self.assertEqual(a.trait_metadata('i', 'config_key'), None)
391 self.assertEqual(a.trait_metadata('i', 'config_key'), None)
392 self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
392 self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
393
393
394 def test_traits(self):
394 def test_traits(self):
395 class A(HasTraits):
395 class A(HasTraits):
396 i = Int
396 i = Int
397 f = Float
397 f = Float
398 a = A()
398 a = A()
399 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
399 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
400 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
400 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
401
401
402 def test_traits_metadata(self):
402 def test_traits_metadata(self):
403 class A(HasTraits):
403 class A(HasTraits):
404 i = Int(config_key='VALUE1', other_thing='VALUE2')
404 i = Int(config_key='VALUE1', other_thing='VALUE2')
405 f = Float(config_key='VALUE3', other_thing='VALUE2')
405 f = Float(config_key='VALUE3', other_thing='VALUE2')
406 j = Int(0)
406 j = Int(0)
407 a = A()
407 a = A()
408 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
408 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
409 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
409 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
410 self.assertEqual(traits, dict(i=A.i))
410 self.assertEqual(traits, dict(i=A.i))
411
411
412 # This passes, but it shouldn't because I am replicating a bug in
412 # This passes, but it shouldn't because I am replicating a bug in
413 # traits.
413 # traits.
414 traits = a.traits(config_key=lambda v: True)
414 traits = a.traits(config_key=lambda v: True)
415 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
415 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
416
416
417 def test_init(self):
417 def test_init(self):
418 class A(HasTraits):
418 class A(HasTraits):
419 i = Int()
419 i = Int()
420 x = Float()
420 x = Float()
421 a = A(i=1, x=10.0)
421 a = A(i=1, x=10.0)
422 self.assertEqual(a.i, 1)
422 self.assertEqual(a.i, 1)
423 self.assertEqual(a.x, 10.0)
423 self.assertEqual(a.x, 10.0)
424
424
425 def test_positional_args(self):
425 def test_positional_args(self):
426 class A(HasTraits):
426 class A(HasTraits):
427 i = Int(0)
427 i = Int(0)
428 def __init__(self, i):
428 def __init__(self, i):
429 super(A, self).__init__()
429 super(A, self).__init__()
430 self.i = i
430 self.i = i
431
431
432 a = A(5)
432 a = A(5)
433 self.assertEqual(a.i, 5)
433 self.assertEqual(a.i, 5)
434 # should raise TypeError if no positional arg given
434 # should raise TypeError if no positional arg given
435 self.assertRaises(TypeError, A)
435 self.assertRaises(TypeError, A)
436
436
437 #-----------------------------------------------------------------------------
437 #-----------------------------------------------------------------------------
438 # Tests for specific trait types
438 # Tests for specific trait types
439 #-----------------------------------------------------------------------------
439 #-----------------------------------------------------------------------------
440
440
441
441
442 class TestType(TestCase):
442 class TestType(TestCase):
443
443
444 def test_default(self):
444 def test_default(self):
445
445
446 class B(object): pass
446 class B(object): pass
447 class A(HasTraits):
447 class A(HasTraits):
448 klass = Type(allow_none=True)
448 klass = Type(allow_none=True)
449
449
450 a = A()
450 a = A()
451 self.assertEqual(a.klass, None)
451 self.assertEqual(a.klass, None)
452
452
453 a.klass = B
453 a.klass = B
454 self.assertEqual(a.klass, B)
454 self.assertEqual(a.klass, B)
455 self.assertRaises(TraitError, setattr, a, 'klass', 10)
455 self.assertRaises(TraitError, setattr, a, 'klass', 10)
456
456
457 def test_value(self):
457 def test_value(self):
458
458
459 class B(object): pass
459 class B(object): pass
460 class C(object): pass
460 class C(object): pass
461 class A(HasTraits):
461 class A(HasTraits):
462 klass = Type(B)
462 klass = Type(B)
463
463
464 a = A()
464 a = A()
465 self.assertEqual(a.klass, B)
465 self.assertEqual(a.klass, B)
466 self.assertRaises(TraitError, setattr, a, 'klass', C)
466 self.assertRaises(TraitError, setattr, a, 'klass', C)
467 self.assertRaises(TraitError, setattr, a, 'klass', object)
467 self.assertRaises(TraitError, setattr, a, 'klass', object)
468 a.klass = B
468 a.klass = B
469
469
470 def test_allow_none(self):
470 def test_allow_none(self):
471
471
472 class B(object): pass
472 class B(object): pass
473 class C(B): pass
473 class C(B): pass
474 class A(HasTraits):
474 class A(HasTraits):
475 klass = Type(B)
475 klass = Type(B)
476
476
477 a = A()
477 a = A()
478 self.assertEqual(a.klass, B)
478 self.assertEqual(a.klass, B)
479 self.assertRaises(TraitError, setattr, a, 'klass', None)
479 self.assertRaises(TraitError, setattr, a, 'klass', None)
480 a.klass = C
480 a.klass = C
481 self.assertEqual(a.klass, C)
481 self.assertEqual(a.klass, C)
482
482
483 def test_validate_klass(self):
483 def test_validate_klass(self):
484
484
485 class A(HasTraits):
485 class A(HasTraits):
486 klass = Type('no strings allowed')
486 klass = Type('no strings allowed')
487
487
488 self.assertRaises(ImportError, A)
488 self.assertRaises(ImportError, A)
489
489
490 class A(HasTraits):
490 class A(HasTraits):
491 klass = Type('rub.adub.Duck')
491 klass = Type('rub.adub.Duck')
492
492
493 self.assertRaises(ImportError, A)
493 self.assertRaises(ImportError, A)
494
494
495 def test_validate_default(self):
495 def test_validate_default(self):
496
496
497 class B(object): pass
497 class B(object): pass
498 class A(HasTraits):
498 class A(HasTraits):
499 klass = Type('bad default', B)
499 klass = Type('bad default', B)
500
500
501 self.assertRaises(ImportError, A)
501 self.assertRaises(ImportError, A)
502
502
503 class C(HasTraits):
503 class C(HasTraits):
504 klass = Type(None, B)
504 klass = Type(None, B)
505
505
506 self.assertRaises(TraitError, C)
506 self.assertRaises(TraitError, C)
507
507
508 def test_str_klass(self):
508 def test_str_klass(self):
509
509
510 class A(HasTraits):
510 class A(HasTraits):
511 klass = Type('IPython.utils.ipstruct.Struct')
511 klass = Type('IPython.utils.ipstruct.Struct')
512
512
513 from IPython.utils.ipstruct import Struct
513 from IPython.utils.ipstruct import Struct
514 a = A()
514 a = A()
515 a.klass = Struct
515 a.klass = Struct
516 self.assertEqual(a.klass, Struct)
516 self.assertEqual(a.klass, Struct)
517
517
518 self.assertRaises(TraitError, setattr, a, 'klass', 10)
518 self.assertRaises(TraitError, setattr, a, 'klass', 10)
519
519
520 def test_set_str_klass(self):
520 def test_set_str_klass(self):
521
521
522 class A(HasTraits):
522 class A(HasTraits):
523 klass = Type()
523 klass = Type()
524
524
525 a = A(klass='IPython.utils.ipstruct.Struct')
525 a = A(klass='IPython.utils.ipstruct.Struct')
526 from IPython.utils.ipstruct import Struct
526 from IPython.utils.ipstruct import Struct
527 self.assertEqual(a.klass, Struct)
527 self.assertEqual(a.klass, Struct)
528
528
529 class TestInstance(TestCase):
529 class TestInstance(TestCase):
530
530
531 def test_basic(self):
531 def test_basic(self):
532 class Foo(object): pass
532 class Foo(object): pass
533 class Bar(Foo): pass
533 class Bar(Foo): pass
534 class Bah(object): pass
534 class Bah(object): pass
535
535
536 class A(HasTraits):
536 class A(HasTraits):
537 inst = Instance(Foo, allow_none=True)
537 inst = Instance(Foo, allow_none=True)
538
538
539 a = A()
539 a = A()
540 self.assertTrue(a.inst is None)
540 self.assertTrue(a.inst is None)
541 a.inst = Foo()
541 a.inst = Foo()
542 self.assertTrue(isinstance(a.inst, Foo))
542 self.assertTrue(isinstance(a.inst, Foo))
543 a.inst = Bar()
543 a.inst = Bar()
544 self.assertTrue(isinstance(a.inst, Foo))
544 self.assertTrue(isinstance(a.inst, Foo))
545 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
545 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
546 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
546 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
547 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
547 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
548
548
549 def test_default_klass(self):
549 def test_default_klass(self):
550 class Foo(object): pass
550 class Foo(object): pass
551 class Bar(Foo): pass
551 class Bar(Foo): pass
552 class Bah(object): pass
552 class Bah(object): pass
553
553
554 class FooInstance(Instance):
554 class FooInstance(Instance):
555 klass = Foo
555 klass = Foo
556
556
557 class A(HasTraits):
557 class A(HasTraits):
558 inst = FooInstance(allow_none=True)
558 inst = FooInstance(allow_none=True)
559
559
560 a = A()
560 a = A()
561 self.assertTrue(a.inst is None)
561 self.assertTrue(a.inst is None)
562 a.inst = Foo()
562 a.inst = Foo()
563 self.assertTrue(isinstance(a.inst, Foo))
563 self.assertTrue(isinstance(a.inst, Foo))
564 a.inst = Bar()
564 a.inst = Bar()
565 self.assertTrue(isinstance(a.inst, Foo))
565 self.assertTrue(isinstance(a.inst, Foo))
566 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
566 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
567 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
567 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
568 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
568 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
569
569
570 def test_unique_default_value(self):
570 def test_unique_default_value(self):
571 class Foo(object): pass
571 class Foo(object): pass
572 class A(HasTraits):
572 class A(HasTraits):
573 inst = Instance(Foo,(),{})
573 inst = Instance(Foo,(),{})
574
574
575 a = A()
575 a = A()
576 b = A()
576 b = A()
577 self.assertTrue(a.inst is not b.inst)
577 self.assertTrue(a.inst is not b.inst)
578
578
579 def test_args_kw(self):
579 def test_args_kw(self):
580 class Foo(object):
580 class Foo(object):
581 def __init__(self, c): self.c = c
581 def __init__(self, c): self.c = c
582 class Bar(object): pass
582 class Bar(object): pass
583 class Bah(object):
583 class Bah(object):
584 def __init__(self, c, d):
584 def __init__(self, c, d):
585 self.c = c; self.d = d
585 self.c = c; self.d = d
586
586
587 class A(HasTraits):
587 class A(HasTraits):
588 inst = Instance(Foo, (10,))
588 inst = Instance(Foo, (10,))
589 a = A()
589 a = A()
590 self.assertEqual(a.inst.c, 10)
590 self.assertEqual(a.inst.c, 10)
591
591
592 class B(HasTraits):
592 class B(HasTraits):
593 inst = Instance(Bah, args=(10,), kw=dict(d=20))
593 inst = Instance(Bah, args=(10,), kw=dict(d=20))
594 b = B()
594 b = B()
595 self.assertEqual(b.inst.c, 10)
595 self.assertEqual(b.inst.c, 10)
596 self.assertEqual(b.inst.d, 20)
596 self.assertEqual(b.inst.d, 20)
597
597
598 class C(HasTraits):
598 class C(HasTraits):
599 inst = Instance(Foo, allow_none=True)
599 inst = Instance(Foo, allow_none=True)
600 c = C()
600 c = C()
601 self.assertTrue(c.inst is None)
601 self.assertTrue(c.inst is None)
602
602
603 def test_bad_default(self):
603 def test_bad_default(self):
604 class Foo(object): pass
604 class Foo(object): pass
605
605
606 class A(HasTraits):
606 class A(HasTraits):
607 inst = Instance(Foo)
607 inst = Instance(Foo)
608
608
609 self.assertRaises(TraitError, A)
609 self.assertRaises(TraitError, A)
610
610
611 def test_instance(self):
611 def test_instance(self):
612 class Foo(object): pass
612 class Foo(object): pass
613
613
614 def inner():
614 def inner():
615 class A(HasTraits):
615 class A(HasTraits):
616 inst = Instance(Foo())
616 inst = Instance(Foo())
617
617
618 self.assertRaises(TraitError, inner)
618 self.assertRaises(TraitError, inner)
619
619
620
620
621 class TestThis(TestCase):
621 class TestThis(TestCase):
622
622
623 def test_this_class(self):
623 def test_this_class(self):
624 class Foo(HasTraits):
624 class Foo(HasTraits):
625 this = This
625 this = This
626
626
627 f = Foo()
627 f = Foo()
628 self.assertEqual(f.this, None)
628 self.assertEqual(f.this, None)
629 g = Foo()
629 g = Foo()
630 f.this = g
630 f.this = g
631 self.assertEqual(f.this, g)
631 self.assertEqual(f.this, g)
632 self.assertRaises(TraitError, setattr, f, 'this', 10)
632 self.assertRaises(TraitError, setattr, f, 'this', 10)
633
633
634 def test_this_inst(self):
634 def test_this_inst(self):
635 class Foo(HasTraits):
635 class Foo(HasTraits):
636 this = This()
636 this = This()
637
637
638 f = Foo()
638 f = Foo()
639 f.this = Foo()
639 f.this = Foo()
640 self.assertTrue(isinstance(f.this, Foo))
640 self.assertTrue(isinstance(f.this, Foo))
641
641
642 def test_subclass(self):
642 def test_subclass(self):
643 class Foo(HasTraits):
643 class Foo(HasTraits):
644 t = This()
644 t = This()
645 class Bar(Foo):
645 class Bar(Foo):
646 pass
646 pass
647 f = Foo()
647 f = Foo()
648 b = Bar()
648 b = Bar()
649 f.t = b
649 f.t = b
650 b.t = f
650 b.t = f
651 self.assertEqual(f.t, b)
651 self.assertEqual(f.t, b)
652 self.assertEqual(b.t, f)
652 self.assertEqual(b.t, f)
653
653
654 def test_subclass_override(self):
654 def test_subclass_override(self):
655 class Foo(HasTraits):
655 class Foo(HasTraits):
656 t = This()
656 t = This()
657 class Bar(Foo):
657 class Bar(Foo):
658 t = This()
658 t = This()
659 f = Foo()
659 f = Foo()
660 b = Bar()
660 b = Bar()
661 f.t = b
661 f.t = b
662 self.assertEqual(f.t, b)
662 self.assertEqual(f.t, b)
663 self.assertRaises(TraitError, setattr, b, 't', f)
663 self.assertRaises(TraitError, setattr, b, 't', f)
664
664
665 def test_this_in_container(self):
665 def test_this_in_container(self):
666
666
667 class Tree(HasTraits):
667 class Tree(HasTraits):
668 value = Unicode()
668 value = Unicode()
669 leaves = List(This())
669 leaves = List(This())
670
670
671 tree = Tree(
671 tree = Tree(
672 value='foo',
672 value='foo',
673 leaves=[Tree('bar'), Tree('buzz')]
673 leaves=[Tree('bar'), Tree('buzz')]
674 )
674 )
675
675
676 with self.assertRaises(TraitError):
676 with self.assertRaises(TraitError):
677 tree.leaves = [1, 2]
677 tree.leaves = [1, 2]
678
678
679 class TraitTestBase(TestCase):
679 class TraitTestBase(TestCase):
680 """A best testing class for basic trait types."""
680 """A best testing class for basic trait types."""
681
681
682 def assign(self, value):
682 def assign(self, value):
683 self.obj.value = value
683 self.obj.value = value
684
684
685 def coerce(self, value):
685 def coerce(self, value):
686 return value
686 return value
687
687
688 def test_good_values(self):
688 def test_good_values(self):
689 if hasattr(self, '_good_values'):
689 if hasattr(self, '_good_values'):
690 for value in self._good_values:
690 for value in self._good_values:
691 self.assign(value)
691 self.assign(value)
692 self.assertEqual(self.obj.value, self.coerce(value))
692 self.assertEqual(self.obj.value, self.coerce(value))
693
693
694 def test_bad_values(self):
694 def test_bad_values(self):
695 if hasattr(self, '_bad_values'):
695 if hasattr(self, '_bad_values'):
696 for value in self._bad_values:
696 for value in self._bad_values:
697 try:
697 try:
698 self.assertRaises(TraitError, self.assign, value)
698 self.assertRaises(TraitError, self.assign, value)
699 except AssertionError:
699 except AssertionError:
700 assert False, value
700 assert False, value
701
701
702 def test_default_value(self):
702 def test_default_value(self):
703 if hasattr(self, '_default_value'):
703 if hasattr(self, '_default_value'):
704 self.assertEqual(self._default_value, self.obj.value)
704 self.assertEqual(self._default_value, self.obj.value)
705
705
706 def test_allow_none(self):
706 def test_allow_none(self):
707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
708 None in self._bad_values):
708 None in self._bad_values):
709 trait=self.obj.traits()['value']
709 trait=self.obj.traits()['value']
710 try:
710 try:
711 trait.allow_none = True
711 trait.allow_none = True
712 self._bad_values.remove(None)
712 self._bad_values.remove(None)
713 #skip coerce. Allow None casts None to None.
713 #skip coerce. Allow None casts None to None.
714 self.assign(None)
714 self.assign(None)
715 self.assertEqual(self.obj.value,None)
715 self.assertEqual(self.obj.value,None)
716 self.test_good_values()
716 self.test_good_values()
717 self.test_bad_values()
717 self.test_bad_values()
718 finally:
718 finally:
719 #tear down
719 #tear down
720 trait.allow_none = False
720 trait.allow_none = False
721 self._bad_values.append(None)
721 self._bad_values.append(None)
722
722
723 def tearDown(self):
723 def tearDown(self):
724 # restore default value after tests, if set
724 # restore default value after tests, if set
725 if hasattr(self, '_default_value'):
725 if hasattr(self, '_default_value'):
726 self.obj.value = self._default_value
726 self.obj.value = self._default_value
727
727
728
728
729 class AnyTrait(HasTraits):
729 class AnyTrait(HasTraits):
730
730
731 value = Any
731 value = Any
732
732
733 class AnyTraitTest(TraitTestBase):
733 class AnyTraitTest(TraitTestBase):
734
734
735 obj = AnyTrait()
735 obj = AnyTrait()
736
736
737 _default_value = None
737 _default_value = None
738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
739 _bad_values = []
739 _bad_values = []
740
740
741 class UnionTrait(HasTraits):
741 class UnionTrait(HasTraits):
742
742
743 value = Union([Type(), Bool()])
743 value = Union([Type(), Bool()])
744
744
745 class UnionTraitTest(TraitTestBase):
745 class UnionTraitTest(TraitTestBase):
746
746
747 obj = UnionTrait(value='IPython.utils.ipstruct.Struct')
747 obj = UnionTrait(value='IPython.utils.ipstruct.Struct')
748 _good_values = [int, float, True]
748 _good_values = [int, float, True]
749 _bad_values = [[], (0,), 1j]
749 _bad_values = [[], (0,), 1j]
750
750
751 class OrTrait(HasTraits):
751 class OrTrait(HasTraits):
752
752
753 value = Bool() | Unicode()
753 value = Bool() | Unicode()
754
754
755 class OrTraitTest(TraitTestBase):
755 class OrTraitTest(TraitTestBase):
756
756
757 obj = OrTrait()
757 obj = OrTrait()
758 _good_values = [True, False, 'ten']
758 _good_values = [True, False, 'ten']
759 _bad_values = [[], (0,), 1j]
759 _bad_values = [[], (0,), 1j]
760
760
761 class IntTrait(HasTraits):
761 class IntTrait(HasTraits):
762
762
763 value = Int(99)
763 value = Int(99)
764
764
765 class TestInt(TraitTestBase):
765 class TestInt(TraitTestBase):
766
766
767 obj = IntTrait()
767 obj = IntTrait()
768 _default_value = 99
768 _default_value = 99
769 _good_values = [10, -10]
769 _good_values = [10, -10]
770 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
770 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
771 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
771 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
772 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
772 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
773 if not py3compat.PY3:
773 if not py3compat.PY3:
774 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
774 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
775
775
776
776
777 class LongTrait(HasTraits):
777 class LongTrait(HasTraits):
778
778
779 value = Long(99 if py3compat.PY3 else long(99))
779 value = Long(99 if py3compat.PY3 else long(99))
780
780
781 class TestLong(TraitTestBase):
781 class TestLong(TraitTestBase):
782
782
783 obj = LongTrait()
783 obj = LongTrait()
784
784
785 _default_value = 99 if py3compat.PY3 else long(99)
785 _default_value = 99 if py3compat.PY3 else long(99)
786 _good_values = [10, -10]
786 _good_values = [10, -10]
787 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
787 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
788 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
788 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
789 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
789 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
790 u'-10.1']
790 u'-10.1']
791 if not py3compat.PY3:
791 if not py3compat.PY3:
792 # maxint undefined on py3, because int == long
792 # maxint undefined on py3, because int == long
793 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
793 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
794 _bad_values.extend([[long(10)], (long(10),)])
794 _bad_values.extend([[long(10)], (long(10),)])
795
795
796 @skipif(py3compat.PY3, "not relevant on py3")
796 @skipif(py3compat.PY3, "not relevant on py3")
797 def test_cast_small(self):
797 def test_cast_small(self):
798 """Long casts ints to long"""
798 """Long casts ints to long"""
799 self.obj.value = 10
799 self.obj.value = 10
800 self.assertEqual(type(self.obj.value), long)
800 self.assertEqual(type(self.obj.value), long)
801
801
802
802
803 class IntegerTrait(HasTraits):
803 class IntegerTrait(HasTraits):
804 value = Integer(1)
804 value = Integer(1)
805
805
806 class TestInteger(TestLong):
806 class TestInteger(TestLong):
807 obj = IntegerTrait()
807 obj = IntegerTrait()
808 _default_value = 1
808 _default_value = 1
809
809
810 def coerce(self, n):
810 def coerce(self, n):
811 return int(n)
811 return int(n)
812
812
813 @skipif(py3compat.PY3, "not relevant on py3")
813 @skipif(py3compat.PY3, "not relevant on py3")
814 def test_cast_small(self):
814 def test_cast_small(self):
815 """Integer casts small longs to int"""
815 """Integer casts small longs to int"""
816 if py3compat.PY3:
816 if py3compat.PY3:
817 raise SkipTest("not relevant on py3")
817 raise SkipTest("not relevant on py3")
818
818
819 self.obj.value = long(100)
819 self.obj.value = long(100)
820 self.assertEqual(type(self.obj.value), int)
820 self.assertEqual(type(self.obj.value), int)
821
821
822
822
823 class FloatTrait(HasTraits):
823 class FloatTrait(HasTraits):
824
824
825 value = Float(99.0)
825 value = Float(99.0)
826
826
827 class TestFloat(TraitTestBase):
827 class TestFloat(TraitTestBase):
828
828
829 obj = FloatTrait()
829 obj = FloatTrait()
830
830
831 _default_value = 99.0
831 _default_value = 99.0
832 _good_values = [10, -10, 10.1, -10.1]
832 _good_values = [10, -10, 10.1, -10.1]
833 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
833 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
834 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
834 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
835 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
835 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
836 if not py3compat.PY3:
836 if not py3compat.PY3:
837 _bad_values.extend([long(10), long(-10)])
837 _bad_values.extend([long(10), long(-10)])
838
838
839
839
840 class ComplexTrait(HasTraits):
840 class ComplexTrait(HasTraits):
841
841
842 value = Complex(99.0-99.0j)
842 value = Complex(99.0-99.0j)
843
843
844 class TestComplex(TraitTestBase):
844 class TestComplex(TraitTestBase):
845
845
846 obj = ComplexTrait()
846 obj = ComplexTrait()
847
847
848 _default_value = 99.0-99.0j
848 _default_value = 99.0-99.0j
849 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
849 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
850 10.1j, 10.1+10.1j, 10.1-10.1j]
850 10.1j, 10.1+10.1j, 10.1-10.1j]
851 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
851 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
852 if not py3compat.PY3:
852 if not py3compat.PY3:
853 _bad_values.extend([long(10), long(-10)])
853 _bad_values.extend([long(10), long(-10)])
854
854
855
855
856 class BytesTrait(HasTraits):
856 class BytesTrait(HasTraits):
857
857
858 value = Bytes(b'string')
858 value = Bytes(b'string')
859
859
860 class TestBytes(TraitTestBase):
860 class TestBytes(TraitTestBase):
861
861
862 obj = BytesTrait()
862 obj = BytesTrait()
863
863
864 _default_value = b'string'
864 _default_value = b'string'
865 _good_values = [b'10', b'-10', b'10L',
865 _good_values = [b'10', b'-10', b'10L',
866 b'-10L', b'10.1', b'-10.1', b'string']
866 b'-10L', b'10.1', b'-10.1', b'string']
867 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
867 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
868 ['ten'],{'ten': 10},(10,), None, u'string']
868 ['ten'],{'ten': 10},(10,), None, u'string']
869 if not py3compat.PY3:
869 if not py3compat.PY3:
870 _bad_values.extend([long(10), long(-10)])
870 _bad_values.extend([long(10), long(-10)])
871
871
872
872
873 class UnicodeTrait(HasTraits):
873 class UnicodeTrait(HasTraits):
874
874
875 value = Unicode(u'unicode')
875 value = Unicode(u'unicode')
876
876
877 class TestUnicode(TraitTestBase):
877 class TestUnicode(TraitTestBase):
878
878
879 obj = UnicodeTrait()
879 obj = UnicodeTrait()
880
880
881 _default_value = u'unicode'
881 _default_value = u'unicode'
882 _good_values = ['10', '-10', '10L', '-10L', '10.1',
882 _good_values = ['10', '-10', '10L', '-10L', '10.1',
883 '-10.1', '', u'', 'string', u'string', u"€"]
883 '-10.1', '', u'', 'string', u'string', u"€"]
884 _bad_values = [10, -10, 10.1, -10.1, 1j,
884 _bad_values = [10, -10, 10.1, -10.1, 1j,
885 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
885 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
886 if not py3compat.PY3:
886 if not py3compat.PY3:
887 _bad_values.extend([long(10), long(-10)])
887 _bad_values.extend([long(10), long(-10)])
888
888
889
889
890 class ObjectNameTrait(HasTraits):
890 class ObjectNameTrait(HasTraits):
891 value = ObjectName("abc")
891 value = ObjectName("abc")
892
892
893 class TestObjectName(TraitTestBase):
893 class TestObjectName(TraitTestBase):
894 obj = ObjectNameTrait()
894 obj = ObjectNameTrait()
895
895
896 _default_value = "abc"
896 _default_value = "abc"
897 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
897 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
898 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
898 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
899 None, object(), object]
899 None, object(), object]
900 if sys.version_info[0] < 3:
900 if sys.version_info[0] < 3:
901 _bad_values.append(u"ΓΎ")
901 _bad_values.append(u"ΓΎ")
902 else:
902 else:
903 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
903 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
904
904
905
905
906 class DottedObjectNameTrait(HasTraits):
906 class DottedObjectNameTrait(HasTraits):
907 value = DottedObjectName("a.b")
907 value = DottedObjectName("a.b")
908
908
909 class TestDottedObjectName(TraitTestBase):
909 class TestDottedObjectName(TraitTestBase):
910 obj = DottedObjectNameTrait()
910 obj = DottedObjectNameTrait()
911
911
912 _default_value = "a.b"
912 _default_value = "a.b"
913 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
913 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
914 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
914 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
915 if sys.version_info[0] < 3:
915 if sys.version_info[0] < 3:
916 _bad_values.append(u"t.ΓΎ")
916 _bad_values.append(u"t.ΓΎ")
917 else:
917 else:
918 _good_values.append(u"t.ΓΎ")
918 _good_values.append(u"t.ΓΎ")
919
919
920
920
921 class TCPAddressTrait(HasTraits):
921 class TCPAddressTrait(HasTraits):
922 value = TCPAddress()
922 value = TCPAddress()
923
923
924 class TestTCPAddress(TraitTestBase):
924 class TestTCPAddress(TraitTestBase):
925
925
926 obj = TCPAddressTrait()
926 obj = TCPAddressTrait()
927
927
928 _default_value = ('127.0.0.1',0)
928 _default_value = ('127.0.0.1',0)
929 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
929 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
930 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
930 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
931
931
932 class ListTrait(HasTraits):
932 class ListTrait(HasTraits):
933
933
934 value = List(Int)
934 value = List(Int)
935
935
936 class TestList(TraitTestBase):
936 class TestList(TraitTestBase):
937
937
938 obj = ListTrait()
938 obj = ListTrait()
939
939
940 _default_value = []
940 _default_value = []
941 _good_values = [[], [1], list(range(10)), (1,2)]
941 _good_values = [[], [1], list(range(10)), (1,2)]
942 _bad_values = [10, [1,'a'], 'a']
942 _bad_values = [10, [1,'a'], 'a']
943
943
944 def coerce(self, value):
944 def coerce(self, value):
945 if value is not None:
945 if value is not None:
946 value = list(value)
946 value = list(value)
947 return value
947 return value
948
948
949 class Foo(object):
949 class Foo(object):
950 pass
950 pass
951
951
952 class NoneInstanceListTrait(HasTraits):
952 class NoneInstanceListTrait(HasTraits):
953
953
954 value = List(Instance(Foo))
954 value = List(Instance(Foo))
955
955
956 class TestNoneInstanceList(TraitTestBase):
956 class TestNoneInstanceList(TraitTestBase):
957
957
958 obj = NoneInstanceListTrait()
958 obj = NoneInstanceListTrait()
959
959
960 _default_value = []
960 _default_value = []
961 _good_values = [[Foo(), Foo()], []]
961 _good_values = [[Foo(), Foo()], []]
962 _bad_values = [[None], [Foo(), None]]
962 _bad_values = [[None], [Foo(), None]]
963
963
964
964
965 class InstanceListTrait(HasTraits):
965 class InstanceListTrait(HasTraits):
966
966
967 value = List(Instance(__name__+'.Foo'))
967 value = List(Instance(__name__+'.Foo'))
968
968
969 class TestInstanceList(TraitTestBase):
969 class TestInstanceList(TraitTestBase):
970
970
971 obj = InstanceListTrait()
971 obj = InstanceListTrait()
972
972
973 def test_klass(self):
973 def test_klass(self):
974 """Test that the instance klass is properly assigned."""
974 """Test that the instance klass is properly assigned."""
975 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
975 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
976
976
977 _default_value = []
977 _default_value = []
978 _good_values = [[Foo(), Foo()], []]
978 _good_values = [[Foo(), Foo()], []]
979 _bad_values = [['1', 2,], '1', [Foo], None]
979 _bad_values = [['1', 2,], '1', [Foo], None]
980
980
981 class UnionListTrait(HasTraits):
981 class UnionListTrait(HasTraits):
982
982
983 value = List(Int() | Bool())
983 value = List(Int() | Bool())
984
984
985 class TestUnionListTrait(HasTraits):
985 class TestUnionListTrait(HasTraits):
986
986
987 obj = UnionListTrait()
987 obj = UnionListTrait()
988
988
989 _default_value = []
989 _default_value = []
990 _good_values = [[True, 1], [False, True]]
990 _good_values = [[True, 1], [False, True]]
991 _bad_values = [[1, 'True'], False]
991 _bad_values = [[1, 'True'], False]
992
992
993
993
994 class LenListTrait(HasTraits):
994 class LenListTrait(HasTraits):
995
995
996 value = List(Int, [0], minlen=1, maxlen=2)
996 value = List(Int, [0], minlen=1, maxlen=2)
997
997
998 class TestLenList(TraitTestBase):
998 class TestLenList(TraitTestBase):
999
999
1000 obj = LenListTrait()
1000 obj = LenListTrait()
1001
1001
1002 _default_value = [0]
1002 _default_value = [0]
1003 _good_values = [[1], [1,2], (1,2)]
1003 _good_values = [[1], [1,2], (1,2)]
1004 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
1004 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
1005
1005
1006 def coerce(self, value):
1006 def coerce(self, value):
1007 if value is not None:
1007 if value is not None:
1008 value = list(value)
1008 value = list(value)
1009 return value
1009 return value
1010
1010
1011 class TupleTrait(HasTraits):
1011 class TupleTrait(HasTraits):
1012
1012
1013 value = Tuple(Int(allow_none=True))
1013 value = Tuple(Int(allow_none=True))
1014
1014
1015 class TestTupleTrait(TraitTestBase):
1015 class TestTupleTrait(TraitTestBase):
1016
1016
1017 obj = TupleTrait()
1017 obj = TupleTrait()
1018
1018
1019 _default_value = None
1019 _default_value = None
1020 _good_values = [(1,), None, (0,), [1], (None,)]
1020 _good_values = [(1,), None, (0,), [1], (None,)]
1021 _bad_values = [10, (1,2), ('a'), ()]
1021 _bad_values = [10, (1,2), ('a'), ()]
1022
1022
1023 def coerce(self, value):
1023 def coerce(self, value):
1024 if value is not None:
1024 if value is not None:
1025 value = tuple(value)
1025 value = tuple(value)
1026 return value
1026 return value
1027
1027
1028 def test_invalid_args(self):
1028 def test_invalid_args(self):
1029 self.assertRaises(TypeError, Tuple, 5)
1029 self.assertRaises(TypeError, Tuple, 5)
1030 self.assertRaises(TypeError, Tuple, default_value='hello')
1030 self.assertRaises(TypeError, Tuple, default_value='hello')
1031 t = Tuple(Int, CBytes, default_value=(1,5))
1031 t = Tuple(Int, CBytes, default_value=(1,5))
1032
1032
1033 class LooseTupleTrait(HasTraits):
1033 class LooseTupleTrait(HasTraits):
1034
1034
1035 value = Tuple((1,2,3))
1035 value = Tuple((1,2,3))
1036
1036
1037 class TestLooseTupleTrait(TraitTestBase):
1037 class TestLooseTupleTrait(TraitTestBase):
1038
1038
1039 obj = LooseTupleTrait()
1039 obj = LooseTupleTrait()
1040
1040
1041 _default_value = (1,2,3)
1041 _default_value = (1,2,3)
1042 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
1042 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
1043 _bad_values = [10, 'hello', {}]
1043 _bad_values = [10, 'hello', {}]
1044
1044
1045 def coerce(self, value):
1045 def coerce(self, value):
1046 if value is not None:
1046 if value is not None:
1047 value = tuple(value)
1047 value = tuple(value)
1048 return value
1048 return value
1049
1049
1050 def test_invalid_args(self):
1050 def test_invalid_args(self):
1051 self.assertRaises(TypeError, Tuple, 5)
1051 self.assertRaises(TypeError, Tuple, 5)
1052 self.assertRaises(TypeError, Tuple, default_value='hello')
1052 self.assertRaises(TypeError, Tuple, default_value='hello')
1053 t = Tuple(Int, CBytes, default_value=(1,5))
1053 t = Tuple(Int, CBytes, default_value=(1,5))
1054
1054
1055
1055
1056 class MultiTupleTrait(HasTraits):
1056 class MultiTupleTrait(HasTraits):
1057
1057
1058 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1058 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1059
1059
1060 class TestMultiTuple(TraitTestBase):
1060 class TestMultiTuple(TraitTestBase):
1061
1061
1062 obj = MultiTupleTrait()
1062 obj = MultiTupleTrait()
1063
1063
1064 _default_value = (99,b'bottles')
1064 _default_value = (99,b'bottles')
1065 _good_values = [(1,b'a'), (2,b'b')]
1065 _good_values = [(1,b'a'), (2,b'b')]
1066 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1066 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1067
1067
1068 class CRegExpTrait(HasTraits):
1068 class CRegExpTrait(HasTraits):
1069
1069
1070 value = CRegExp(r'')
1070 value = CRegExp(r'')
1071
1071
1072 class TestCRegExp(TraitTestBase):
1072 class TestCRegExp(TraitTestBase):
1073
1073
1074 def coerce(self, value):
1074 def coerce(self, value):
1075 return re.compile(value)
1075 return re.compile(value)
1076
1076
1077 obj = CRegExpTrait()
1077 obj = CRegExpTrait()
1078
1078
1079 _default_value = re.compile(r'')
1079 _default_value = re.compile(r'')
1080 _good_values = [r'\d+', re.compile(r'\d+')]
1080 _good_values = [r'\d+', re.compile(r'\d+')]
1081 _bad_values = ['(', None, ()]
1081 _bad_values = ['(', None, ()]
1082
1082
1083 class DictTrait(HasTraits):
1083 class DictTrait(HasTraits):
1084 value = Dict()
1084 value = Dict()
1085
1085
1086 def test_dict_assignment():
1086 def test_dict_assignment():
1087 d = dict()
1087 d = dict()
1088 c = DictTrait()
1088 c = DictTrait()
1089 c.value = d
1089 c.value = d
1090 d['a'] = 5
1090 d['a'] = 5
1091 nt.assert_equal(d, c.value)
1091 nt.assert_equal(d, c.value)
1092 nt.assert_true(c.value is d)
1092 nt.assert_true(c.value is d)
1093
1093
1094 class ValidatedDictTrait(HasTraits):
1094 class ValidatedDictTrait(HasTraits):
1095
1095
1096 value = Dict(Unicode())
1096 value = Dict(Unicode())
1097
1097
1098 class TestInstanceDict(TraitTestBase):
1098 class TestInstanceDict(TraitTestBase):
1099
1099
1100 obj = ValidatedDictTrait()
1100 obj = ValidatedDictTrait()
1101
1101
1102 _default_value = {}
1102 _default_value = {}
1103 _good_values = [{'0': 'foo'}, {'1': 'bar'}]
1103 _good_values = [{'0': 'foo'}, {'1': 'bar'}]
1104 _bad_values = [{'0': 0}, {'1': 1}]
1104 _bad_values = [{'0': 0}, {'1': 1}]
1105
1105
1106
1106
1107 def test_dict_default_value():
1107 def test_dict_default_value():
1108 """Check that the `{}` default value of the Dict traitlet constructor is
1108 """Check that the `{}` default value of the Dict traitlet constructor is
1109 actually copied."""
1109 actually copied."""
1110
1110
1111 d1, d2 = Dict(), Dict()
1111 d1, d2 = Dict(), Dict()
1112 nt.assert_false(d1.get_default_value() is d2.get_default_value())
1112 nt.assert_false(d1.get_default_value() is d2.get_default_value())
1113
1113
1114
1114
1115 class TestValidationHook(TestCase):
1115 class TestValidationHook(TestCase):
1116
1116
1117 def test_parity_trait(self):
1117 def test_parity_trait(self):
1118 """Verify that the early validation hook is effective"""
1118 """Verify that the early validation hook is effective"""
1119
1119
1120 class Parity(HasTraits):
1120 class Parity(HasTraits):
1121
1121
1122 value = Int(0)
1122 value = Int(0)
1123 parity = Enum(['odd', 'even'], default_value='even')
1123 parity = Enum(['odd', 'even'], default_value='even')
1124
1124
1125 def _value_validate(self, value, trait):
1125 def _value_validate(self, value, trait):
1126 if self.parity == 'even' and value % 2:
1126 if self.parity == 'even' and value % 2:
1127 raise TraitError('Expected an even number')
1127 raise TraitError('Expected an even number')
1128 if self.parity == 'odd' and (value % 2 == 0):
1128 if self.parity == 'odd' and (value % 2 == 0):
1129 raise TraitError('Expected an odd number')
1129 raise TraitError('Expected an odd number')
1130 return value
1130 return value
1131
1131
1132 u = Parity()
1132 u = Parity()
1133 u.parity = 'odd'
1133 u.parity = 'odd'
1134 u.value = 1 # OK
1134 u.value = 1 # OK
1135 with self.assertRaises(TraitError):
1135 with self.assertRaises(TraitError):
1136 u.value = 2 # Trait Error
1136 u.value = 2 # Trait Error
1137
1137
1138 u.parity = 'even'
1138 u.parity = 'even'
1139 u.value = 2 # OK
1139 u.value = 2 # OK
1140
1140
1141
1141
1142 class TestLink(TestCase):
1142 class TestLink(TestCase):
1143
1143
1144 def test_connect_same(self):
1144 def test_connect_same(self):
1145 """Verify two traitlets of the same type can be linked together using link."""
1145 """Verify two traitlets of the same type can be linked together using link."""
1146
1146
1147 # Create two simple classes with Int traitlets.
1147 # Create two simple classes with Int traitlets.
1148 class A(HasTraits):
1148 class A(HasTraits):
1149 value = Int()
1149 value = Int()
1150 a = A(value=9)
1150 a = A(value=9)
1151 b = A(value=8)
1151 b = A(value=8)
1152
1152
1153 # Conenct the two classes.
1153 # Conenct the two classes.
1154 c = link((a, 'value'), (b, 'value'))
1154 c = link((a, 'value'), (b, 'value'))
1155
1155
1156 # Make sure the values are the same at the point of linking.
1156 # Make sure the values are the same at the point of linking.
1157 self.assertEqual(a.value, b.value)
1157 self.assertEqual(a.value, b.value)
1158
1158
1159 # Change one of the values to make sure they stay in sync.
1159 # Change one of the values to make sure they stay in sync.
1160 a.value = 5
1160 a.value = 5
1161 self.assertEqual(a.value, b.value)
1161 self.assertEqual(a.value, b.value)
1162 b.value = 6
1162 b.value = 6
1163 self.assertEqual(a.value, b.value)
1163 self.assertEqual(a.value, b.value)
1164
1164
1165 def test_link_different(self):
1165 def test_link_different(self):
1166 """Verify two traitlets of different types can be linked together using link."""
1166 """Verify two traitlets of different types can be linked together using link."""
1167
1167
1168 # Create two simple classes with Int traitlets.
1168 # Create two simple classes with Int traitlets.
1169 class A(HasTraits):
1169 class A(HasTraits):
1170 value = Int()
1170 value = Int()
1171 class B(HasTraits):
1171 class B(HasTraits):
1172 count = Int()
1172 count = Int()
1173 a = A(value=9)
1173 a = A(value=9)
1174 b = B(count=8)
1174 b = B(count=8)
1175
1175
1176 # Conenct the two classes.
1176 # Conenct the two classes.
1177 c = link((a, 'value'), (b, 'count'))
1177 c = link((a, 'value'), (b, 'count'))
1178
1178
1179 # Make sure the values are the same at the point of linking.
1179 # Make sure the values are the same at the point of linking.
1180 self.assertEqual(a.value, b.count)
1180 self.assertEqual(a.value, b.count)
1181
1181
1182 # Change one of the values to make sure they stay in sync.
1182 # Change one of the values to make sure they stay in sync.
1183 a.value = 5
1183 a.value = 5
1184 self.assertEqual(a.value, b.count)
1184 self.assertEqual(a.value, b.count)
1185 b.count = 4
1185 b.count = 4
1186 self.assertEqual(a.value, b.count)
1186 self.assertEqual(a.value, b.count)
1187
1187
1188 def test_unlink(self):
1188 def test_unlink(self):
1189 """Verify two linked traitlets can be unlinked."""
1189 """Verify two linked traitlets can be unlinked."""
1190
1190
1191 # Create two simple classes with Int traitlets.
1191 # Create two simple classes with Int traitlets.
1192 class A(HasTraits):
1192 class A(HasTraits):
1193 value = Int()
1193 value = Int()
1194 a = A(value=9)
1194 a = A(value=9)
1195 b = A(value=8)
1195 b = A(value=8)
1196
1196
1197 # Connect the two classes.
1197 # Connect the two classes.
1198 c = link((a, 'value'), (b, 'value'))
1198 c = link((a, 'value'), (b, 'value'))
1199 a.value = 4
1199 a.value = 4
1200 c.unlink()
1200 c.unlink()
1201
1201
1202 # Change one of the values to make sure they don't stay in sync.
1202 # Change one of the values to make sure they don't stay in sync.
1203 a.value = 5
1203 a.value = 5
1204 self.assertNotEqual(a.value, b.value)
1204 self.assertNotEqual(a.value, b.value)
1205
1205
1206 def test_callbacks(self):
1206 def test_callbacks(self):
1207 """Verify two linked traitlets have their callbacks called once."""
1207 """Verify two linked traitlets have their callbacks called once."""
1208
1208
1209 # Create two simple classes with Int traitlets.
1209 # Create two simple classes with Int traitlets.
1210 class A(HasTraits):
1210 class A(HasTraits):
1211 value = Int()
1211 value = Int()
1212 class B(HasTraits):
1212 class B(HasTraits):
1213 count = Int()
1213 count = Int()
1214 a = A(value=9)
1214 a = A(value=9)
1215 b = B(count=8)
1215 b = B(count=8)
1216
1216
1217 # Register callbacks that count.
1217 # Register callbacks that count.
1218 callback_count = []
1218 callback_count = []
1219 def a_callback(name, old, new):
1219 def a_callback(name, old, new):
1220 callback_count.append('a')
1220 callback_count.append('a')
1221 a.on_trait_change(a_callback, 'value')
1221 a.on_trait_change(a_callback, 'value')
1222 def b_callback(name, old, new):
1222 def b_callback(name, old, new):
1223 callback_count.append('b')
1223 callback_count.append('b')
1224 b.on_trait_change(b_callback, 'count')
1224 b.on_trait_change(b_callback, 'count')
1225
1225
1226 # Connect the two classes.
1226 # Connect the two classes.
1227 c = link((a, 'value'), (b, 'count'))
1227 c = link((a, 'value'), (b, 'count'))
1228
1228
1229 # Make sure b's count was set to a's value once.
1229 # Make sure b's count was set to a's value once.
1230 self.assertEqual(''.join(callback_count), 'b')
1230 self.assertEqual(''.join(callback_count), 'b')
1231 del callback_count[:]
1231 del callback_count[:]
1232
1232
1233 # Make sure a's value was set to b's count once.
1233 # Make sure a's value was set to b's count once.
1234 b.count = 5
1234 b.count = 5
1235 self.assertEqual(''.join(callback_count), 'ba')
1235 self.assertEqual(''.join(callback_count), 'ba')
1236 del callback_count[:]
1236 del callback_count[:]
1237
1237
1238 # Make sure b's count was set to a's value once.
1238 # Make sure b's count was set to a's value once.
1239 a.value = 4
1239 a.value = 4
1240 self.assertEqual(''.join(callback_count), 'ab')
1240 self.assertEqual(''.join(callback_count), 'ab')
1241 del callback_count[:]
1241 del callback_count[:]
1242
1242
1243 class TestDirectionalLink(TestCase):
1243 class TestDirectionalLink(TestCase):
1244 def test_connect_same(self):
1244 def test_connect_same(self):
1245 """Verify two traitlets of the same type can be linked together using directional_link."""
1245 """Verify two traitlets of the same type can be linked together using directional_link."""
1246
1246
1247 # Create two simple classes with Int traitlets.
1247 # Create two simple classes with Int traitlets.
1248 class A(HasTraits):
1248 class A(HasTraits):
1249 value = Int()
1249 value = Int()
1250 a = A(value=9)
1250 a = A(value=9)
1251 b = A(value=8)
1251 b = A(value=8)
1252
1252
1253 # Conenct the two classes.
1253 # Conenct the two classes.
1254 c = directional_link((a, 'value'), (b, 'value'))
1254 c = directional_link((a, 'value'), (b, 'value'))
1255
1255
1256 # Make sure the values are the same at the point of linking.
1256 # Make sure the values are the same at the point of linking.
1257 self.assertEqual(a.value, b.value)
1257 self.assertEqual(a.value, b.value)
1258
1258
1259 # Change one the value of the source and check that it synchronizes the target.
1259 # Change one the value of the source and check that it synchronizes the target.
1260 a.value = 5
1260 a.value = 5
1261 self.assertEqual(b.value, 5)
1261 self.assertEqual(b.value, 5)
1262 # Change one the value of the target and check that it has no impact on the source
1262 # Change one the value of the target and check that it has no impact on the source
1263 b.value = 6
1263 b.value = 6
1264 self.assertEqual(a.value, 5)
1264 self.assertEqual(a.value, 5)
1265
1265
1266 def test_link_different(self):
1266 def test_link_different(self):
1267 """Verify two traitlets of different types can be linked together using link."""
1267 """Verify two traitlets of different types can be linked together using link."""
1268
1268
1269 # Create two simple classes with Int traitlets.
1269 # Create two simple classes with Int traitlets.
1270 class A(HasTraits):
1270 class A(HasTraits):
1271 value = Int()
1271 value = Int()
1272 class B(HasTraits):
1272 class B(HasTraits):
1273 count = Int()
1273 count = Int()
1274 a = A(value=9)
1274 a = A(value=9)
1275 b = B(count=8)
1275 b = B(count=8)
1276
1276
1277 # Conenct the two classes.
1277 # Conenct the two classes.
1278 c = directional_link((a, 'value'), (b, 'count'))
1278 c = directional_link((a, 'value'), (b, 'count'))
1279
1279
1280 # Make sure the values are the same at the point of linking.
1280 # Make sure the values are the same at the point of linking.
1281 self.assertEqual(a.value, b.count)
1281 self.assertEqual(a.value, b.count)
1282
1282
1283 # Change one the value of the source and check that it synchronizes the target.
1283 # Change one the value of the source and check that it synchronizes the target.
1284 a.value = 5
1284 a.value = 5
1285 self.assertEqual(b.count, 5)
1285 self.assertEqual(b.count, 5)
1286 # Change one the value of the target and check that it has no impact on the source
1286 # Change one the value of the target and check that it has no impact on the source
1287 b.value = 6
1287 b.value = 6
1288 self.assertEqual(a.value, 5)
1288 self.assertEqual(a.value, 5)
1289
1289
1290 def test_unlink(self):
1290 def test_unlink(self):
1291 """Verify two linked traitlets can be unlinked."""
1291 """Verify two linked traitlets can be unlinked."""
1292
1292
1293 # Create two simple classes with Int traitlets.
1293 # Create two simple classes with Int traitlets.
1294 class A(HasTraits):
1294 class A(HasTraits):
1295 value = Int()
1295 value = Int()
1296 a = A(value=9)
1296 a = A(value=9)
1297 b = A(value=8)
1297 b = A(value=8)
1298
1298
1299 # Connect the two classes.
1299 # Connect the two classes.
1300 c = directional_link((a, 'value'), (b, 'value'))
1300 c = directional_link((a, 'value'), (b, 'value'))
1301 a.value = 4
1301 a.value = 4
1302 c.unlink()
1302 c.unlink()
1303
1303
1304 # Change one of the values to make sure they don't stay in sync.
1304 # Change one of the values to make sure they don't stay in sync.
1305 a.value = 5
1305 a.value = 5
1306 self.assertNotEqual(a.value, b.value)
1306 self.assertNotEqual(a.value, b.value)
1307
1307
1308 class Pickleable(HasTraits):
1308 class Pickleable(HasTraits):
1309 i = Int()
1309 i = Int()
1310 j = Int()
1310 j = Int()
1311
1311
1312 def _i_default(self):
1312 def _i_default(self):
1313 return 1
1313 return 1
1314
1314
1315 def _i_changed(self, name, old, new):
1315 def _i_changed(self, name, old, new):
1316 self.j = new
1316 self.j = new
1317
1317
1318 def test_pickle_hastraits():
1318 def test_pickle_hastraits():
1319 c = Pickleable()
1319 c = Pickleable()
1320 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1320 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1321 p = pickle.dumps(c, protocol)
1321 p = pickle.dumps(c, protocol)
1322 c2 = pickle.loads(p)
1322 c2 = pickle.loads(p)
1323 nt.assert_equal(c2.i, c.i)
1323 nt.assert_equal(c2.i, c.i)
1324 nt.assert_equal(c2.j, c.j)
1324 nt.assert_equal(c2.j, c.j)
1325
1325
1326 c.i = 5
1326 c.i = 5
1327 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1327 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1328 p = pickle.dumps(c, protocol)
1328 p = pickle.dumps(c, protocol)
1329 c2 = pickle.loads(p)
1329 c2 = pickle.loads(p)
1330 nt.assert_equal(c2.i, c.i)
1330 nt.assert_equal(c2.i, c.i)
1331 nt.assert_equal(c2.j, c.j)
1331 nt.assert_equal(c2.j, c.j)
1332
1332
1333
1333
1334 def test_hold_trait_notifications():
1334 def test_hold_trait_notifications():
1335 changes = []
1335 changes = []
1336
1336 class Test(HasTraits):
1337 class Test(HasTraits):
1337 a = Integer(0)
1338 a = Integer(0)
1339 b = Integer(0)
1340
1338 def _a_changed(self, name, old, new):
1341 def _a_changed(self, name, old, new):
1339 changes.append((old, new))
1342 changes.append((old, new))
1340
1343
1344 def _b_validate(self, value, trait):
1345 if value != 0:
1346 raise TraitError('Only 0 is a valid value')
1347 return value
1348
1349 # Test context manager and nesting
1341 t = Test()
1350 t = Test()
1342 with t.hold_trait_notifications():
1351 with t.hold_trait_notifications():
1343 with t.hold_trait_notifications():
1352 with t.hold_trait_notifications():
1344 t.a = 1
1353 t.a = 1
1345 nt.assert_equal(t.a, 1)
1354 nt.assert_equal(t.a, 1)
1346 nt.assert_equal(changes, [])
1355 nt.assert_equal(changes, [])
1347 t.a = 2
1356 t.a = 2
1348 nt.assert_equal(t.a, 2)
1357 nt.assert_equal(t.a, 2)
1349 with t.hold_trait_notifications():
1358 with t.hold_trait_notifications():
1350 t.a = 3
1359 t.a = 3
1351 nt.assert_equal(t.a, 3)
1360 nt.assert_equal(t.a, 3)
1352 nt.assert_equal(changes, [])
1361 nt.assert_equal(changes, [])
1353 t.a = 4
1362 t.a = 4
1354 nt.assert_equal(t.a, 4)
1363 nt.assert_equal(t.a, 4)
1355 nt.assert_equal(changes, [])
1364 nt.assert_equal(changes, [])
1356 t.a = 4
1365 t.a = 4
1357 nt.assert_equal(t.a, 4)
1366 nt.assert_equal(t.a, 4)
1358 nt.assert_equal(changes, [])
1367 nt.assert_equal(changes, [])
1359 nt.assert_equal(changes, [(0,1), (1,2), (2,3), (3,4)])
1360
1368
1369 nt.assert_equal(changes, [(3,4)])
1370 # Test roll-back
1371 try:
1372 with t.hold_trait_notifications():
1373 t.b = 1 # raises a Trait error
1374 except:
1375 pass
1376 nt.assert_equal(t.b, 0)
1377
1361
1378
1362 class OrderTraits(HasTraits):
1379 class OrderTraits(HasTraits):
1363 notified = Dict()
1380 notified = Dict()
1364
1381
1365 a = Unicode()
1382 a = Unicode()
1366 b = Unicode()
1383 b = Unicode()
1367 c = Unicode()
1384 c = Unicode()
1368 d = Unicode()
1385 d = Unicode()
1369 e = Unicode()
1386 e = Unicode()
1370 f = Unicode()
1387 f = Unicode()
1371 g = Unicode()
1388 g = Unicode()
1372 h = Unicode()
1389 h = Unicode()
1373 i = Unicode()
1390 i = Unicode()
1374 j = Unicode()
1391 j = Unicode()
1375 k = Unicode()
1392 k = Unicode()
1376 l = Unicode()
1393 l = Unicode()
1377
1394
1378 def _notify(self, name, old, new):
1395 def _notify(self, name, old, new):
1379 """check the value of all traits when each trait change is triggered
1396 """check the value of all traits when each trait change is triggered
1380
1397
1381 This verifies that the values are not sensitive
1398 This verifies that the values are not sensitive
1382 to dict ordering when loaded from kwargs
1399 to dict ordering when loaded from kwargs
1383 """
1400 """
1384 # check the value of the other traits
1401 # check the value of the other traits
1385 # when a given trait change notification fires
1402 # when a given trait change notification fires
1386 self.notified[name] = {
1403 self.notified[name] = {
1387 c: getattr(self, c) for c in 'abcdefghijkl'
1404 c: getattr(self, c) for c in 'abcdefghijkl'
1388 }
1405 }
1389
1406
1390 def __init__(self, **kwargs):
1407 def __init__(self, **kwargs):
1391 self.on_trait_change(self._notify)
1408 self.on_trait_change(self._notify)
1392 super(OrderTraits, self).__init__(**kwargs)
1409 super(OrderTraits, self).__init__(**kwargs)
1393
1410
1394 def test_notification_order():
1411 def test_notification_order():
1395 d = {c:c for c in 'abcdefghijkl'}
1412 d = {c:c for c in 'abcdefghijkl'}
1396 obj = OrderTraits()
1413 obj = OrderTraits()
1397 nt.assert_equal(obj.notified, {})
1414 nt.assert_equal(obj.notified, {})
1398 obj = OrderTraits(**d)
1415 obj = OrderTraits(**d)
1399 notifications = {
1416 notifications = {
1400 c: d for c in 'abcdefghijkl'
1417 c: d for c in 'abcdefghijkl'
1401 }
1418 }
1402 nt.assert_equal(obj.notified, notifications)
1419 nt.assert_equal(obj.notified, notifications)
1403
1420
1404
1421
1405 class TestEventful(TestCase):
1422 class TestEventful(TestCase):
1406
1423
1407 def test_list(self):
1424 def test_list(self):
1408 """Does the EventfulList work?"""
1425 """Does the EventfulList work?"""
1409 event_cache = []
1426 event_cache = []
1410
1427
1411 class A(HasTraits):
1428 class A(HasTraits):
1412 x = EventfulList([c for c in 'abc'])
1429 x = EventfulList([c for c in 'abc'])
1413 a = A()
1430 a = A()
1414 a.x.on_events(lambda i, x: event_cache.append('insert'), \
1431 a.x.on_events(lambda i, x: event_cache.append('insert'), \
1415 lambda i, x: event_cache.append('set'), \
1432 lambda i, x: event_cache.append('set'), \
1416 lambda i: event_cache.append('del'), \
1433 lambda i: event_cache.append('del'), \
1417 lambda: event_cache.append('reverse'), \
1434 lambda: event_cache.append('reverse'), \
1418 lambda *p, **k: event_cache.append('sort'))
1435 lambda *p, **k: event_cache.append('sort'))
1419
1436
1420 a.x.remove('c')
1437 a.x.remove('c')
1421 # ab
1438 # ab
1422 a.x.insert(0, 'z')
1439 a.x.insert(0, 'z')
1423 # zab
1440 # zab
1424 del a.x[1]
1441 del a.x[1]
1425 # zb
1442 # zb
1426 a.x.reverse()
1443 a.x.reverse()
1427 # bz
1444 # bz
1428 a.x[1] = 'o'
1445 a.x[1] = 'o'
1429 # bo
1446 # bo
1430 a.x.append('a')
1447 a.x.append('a')
1431 # boa
1448 # boa
1432 a.x.sort()
1449 a.x.sort()
1433 # abo
1450 # abo
1434
1451
1435 # Were the correct events captured?
1452 # Were the correct events captured?
1436 self.assertEqual(event_cache, ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort'])
1453 self.assertEqual(event_cache, ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort'])
1437
1454
1438 # Is the output correct?
1455 # Is the output correct?
1439 self.assertEqual(a.x, [c for c in 'abo'])
1456 self.assertEqual(a.x, [c for c in 'abo'])
1440
1457
1441 def test_dict(self):
1458 def test_dict(self):
1442 """Does the EventfulDict work?"""
1459 """Does the EventfulDict work?"""
1443 event_cache = []
1460 event_cache = []
1444
1461
1445 class A(HasTraits):
1462 class A(HasTraits):
1446 x = EventfulDict({c: c for c in 'abc'})
1463 x = EventfulDict({c: c for c in 'abc'})
1447 a = A()
1464 a = A()
1448 a.x.on_events(lambda k, v: event_cache.append('add'), \
1465 a.x.on_events(lambda k, v: event_cache.append('add'), \
1449 lambda k, v: event_cache.append('set'), \
1466 lambda k, v: event_cache.append('set'), \
1450 lambda k: event_cache.append('del'))
1467 lambda k: event_cache.append('del'))
1451
1468
1452 del a.x['c']
1469 del a.x['c']
1453 # ab
1470 # ab
1454 a.x['z'] = 1
1471 a.x['z'] = 1
1455 # abz
1472 # abz
1456 a.x['z'] = 'z'
1473 a.x['z'] = 'z'
1457 # abz
1474 # abz
1458 a.x.pop('a')
1475 a.x.pop('a')
1459 # bz
1476 # bz
1460
1477
1461 # Were the correct events captured?
1478 # Were the correct events captured?
1462 self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])
1479 self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])
1463
1480
1464 # Is the output correct?
1481 # Is the output correct?
1465 self.assertEqual(a.x, {c: c for c in 'bz'})
1482 self.assertEqual(a.x, {c: c for c in 'bz'})
1466
1483
1467 ###
1484 ###
1468 # Traits for Forward Declaration Tests
1485 # Traits for Forward Declaration Tests
1469 ###
1486 ###
1470 class ForwardDeclaredInstanceTrait(HasTraits):
1487 class ForwardDeclaredInstanceTrait(HasTraits):
1471
1488
1472 value = ForwardDeclaredInstance('ForwardDeclaredBar', allow_none=True)
1489 value = ForwardDeclaredInstance('ForwardDeclaredBar', allow_none=True)
1473
1490
1474 class ForwardDeclaredTypeTrait(HasTraits):
1491 class ForwardDeclaredTypeTrait(HasTraits):
1475
1492
1476 value = ForwardDeclaredType('ForwardDeclaredBar', allow_none=True)
1493 value = ForwardDeclaredType('ForwardDeclaredBar', allow_none=True)
1477
1494
1478 class ForwardDeclaredInstanceListTrait(HasTraits):
1495 class ForwardDeclaredInstanceListTrait(HasTraits):
1479
1496
1480 value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
1497 value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
1481
1498
1482 class ForwardDeclaredTypeListTrait(HasTraits):
1499 class ForwardDeclaredTypeListTrait(HasTraits):
1483
1500
1484 value = List(ForwardDeclaredType('ForwardDeclaredBar'))
1501 value = List(ForwardDeclaredType('ForwardDeclaredBar'))
1485 ###
1502 ###
1486 # End Traits for Forward Declaration Tests
1503 # End Traits for Forward Declaration Tests
1487 ###
1504 ###
1488
1505
1489 ###
1506 ###
1490 # Classes for Forward Declaration Tests
1507 # Classes for Forward Declaration Tests
1491 ###
1508 ###
1492 class ForwardDeclaredBar(object):
1509 class ForwardDeclaredBar(object):
1493 pass
1510 pass
1494
1511
1495 class ForwardDeclaredBarSub(ForwardDeclaredBar):
1512 class ForwardDeclaredBarSub(ForwardDeclaredBar):
1496 pass
1513 pass
1497 ###
1514 ###
1498 # End Classes for Forward Declaration Tests
1515 # End Classes for Forward Declaration Tests
1499 ###
1516 ###
1500
1517
1501 ###
1518 ###
1502 # Forward Declaration Tests
1519 # Forward Declaration Tests
1503 ###
1520 ###
1504 class TestForwardDeclaredInstanceTrait(TraitTestBase):
1521 class TestForwardDeclaredInstanceTrait(TraitTestBase):
1505
1522
1506 obj = ForwardDeclaredInstanceTrait()
1523 obj = ForwardDeclaredInstanceTrait()
1507 _default_value = None
1524 _default_value = None
1508 _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1525 _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1509 _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
1526 _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
1510
1527
1511 class TestForwardDeclaredTypeTrait(TraitTestBase):
1528 class TestForwardDeclaredTypeTrait(TraitTestBase):
1512
1529
1513 obj = ForwardDeclaredTypeTrait()
1530 obj = ForwardDeclaredTypeTrait()
1514 _default_value = None
1531 _default_value = None
1515 _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
1532 _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
1516 _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1533 _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1517
1534
1518 class TestForwardDeclaredInstanceList(TraitTestBase):
1535 class TestForwardDeclaredInstanceList(TraitTestBase):
1519
1536
1520 obj = ForwardDeclaredInstanceListTrait()
1537 obj = ForwardDeclaredInstanceListTrait()
1521
1538
1522 def test_klass(self):
1539 def test_klass(self):
1523 """Test that the instance klass is properly assigned."""
1540 """Test that the instance klass is properly assigned."""
1524 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1541 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1525
1542
1526 _default_value = []
1543 _default_value = []
1527 _good_values = [
1544 _good_values = [
1528 [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
1545 [ForwardDeclaredBar(), ForwardDeclaredBarSub()],
1529 [],
1546 [],
1530 ]
1547 ]
1531 _bad_values = [
1548 _bad_values = [
1532 ForwardDeclaredBar(),
1549 ForwardDeclaredBar(),
1533 [ForwardDeclaredBar(), 3, None],
1550 [ForwardDeclaredBar(), 3, None],
1534 '1',
1551 '1',
1535 # Note that this is the type, not an instance.
1552 # Note that this is the type, not an instance.
1536 [ForwardDeclaredBar],
1553 [ForwardDeclaredBar],
1537 [None],
1554 [None],
1538 None,
1555 None,
1539 ]
1556 ]
1540
1557
1541 class TestForwardDeclaredTypeList(TraitTestBase):
1558 class TestForwardDeclaredTypeList(TraitTestBase):
1542
1559
1543 obj = ForwardDeclaredTypeListTrait()
1560 obj = ForwardDeclaredTypeListTrait()
1544
1561
1545 def test_klass(self):
1562 def test_klass(self):
1546 """Test that the instance klass is properly assigned."""
1563 """Test that the instance klass is properly assigned."""
1547 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1564 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1548
1565
1549 _default_value = []
1566 _default_value = []
1550 _good_values = [
1567 _good_values = [
1551 [ForwardDeclaredBar, ForwardDeclaredBarSub],
1568 [ForwardDeclaredBar, ForwardDeclaredBarSub],
1552 [],
1569 [],
1553 ]
1570 ]
1554 _bad_values = [
1571 _bad_values = [
1555 ForwardDeclaredBar,
1572 ForwardDeclaredBar,
1556 [ForwardDeclaredBar, 3],
1573 [ForwardDeclaredBar, 3],
1557 '1',
1574 '1',
1558 # Note that this is an instance, not the type.
1575 # Note that this is an instance, not the type.
1559 [ForwardDeclaredBar()],
1576 [ForwardDeclaredBar()],
1560 [None],
1577 [None],
1561 None,
1578 None,
1562 ]
1579 ]
1563 ###
1580 ###
1564 # End Forward Declaration Tests
1581 # End Forward Declaration Tests
1565 ###
1582 ###
1566
1583
1567 class TestDynamicTraits(TestCase):
1584 class TestDynamicTraits(TestCase):
1568
1585
1569 def setUp(self):
1586 def setUp(self):
1570 self._notify1 = []
1587 self._notify1 = []
1571
1588
1572 def notify1(self, name, old, new):
1589 def notify1(self, name, old, new):
1573 self._notify1.append((name, old, new))
1590 self._notify1.append((name, old, new))
1574
1591
1575 def test_notify_all(self):
1592 def test_notify_all(self):
1576
1593
1577 class A(HasTraits):
1594 class A(HasTraits):
1578 pass
1595 pass
1579
1596
1580 a = A()
1597 a = A()
1581 self.assertTrue(not hasattr(a, 'x'))
1598 self.assertTrue(not hasattr(a, 'x'))
1582 self.assertTrue(not hasattr(a, 'y'))
1599 self.assertTrue(not hasattr(a, 'y'))
1583
1600
1584 # Dynamically add trait x.
1601 # Dynamically add trait x.
1585 a.add_trait('x', Int())
1602 a.add_trait('x', Int())
1586 self.assertTrue(hasattr(a, 'x'))
1603 self.assertTrue(hasattr(a, 'x'))
1587 self.assertTrue(isinstance(a, (A, )))
1604 self.assertTrue(isinstance(a, (A, )))
1588
1605
1589 # Dynamically add trait y.
1606 # Dynamically add trait y.
1590 a.add_trait('y', Float())
1607 a.add_trait('y', Float())
1591 self.assertTrue(hasattr(a, 'y'))
1608 self.assertTrue(hasattr(a, 'y'))
1592 self.assertTrue(isinstance(a, (A, )))
1609 self.assertTrue(isinstance(a, (A, )))
1593 self.assertEqual(a.__class__.__name__, A.__name__)
1610 self.assertEqual(a.__class__.__name__, A.__name__)
1594
1611
1595 # Create a new instance and verify that x and y
1612 # Create a new instance and verify that x and y
1596 # aren't defined.
1613 # aren't defined.
1597 b = A()
1614 b = A()
1598 self.assertTrue(not hasattr(b, 'x'))
1615 self.assertTrue(not hasattr(b, 'x'))
1599 self.assertTrue(not hasattr(b, 'y'))
1616 self.assertTrue(not hasattr(b, 'y'))
1600
1617
1601 # Verify that notification works like normal.
1618 # Verify that notification works like normal.
1602 a.on_trait_change(self.notify1)
1619 a.on_trait_change(self.notify1)
1603 a.x = 0
1620 a.x = 0
1604 self.assertEqual(len(self._notify1), 0)
1621 self.assertEqual(len(self._notify1), 0)
1605 a.y = 0.0
1622 a.y = 0.0
1606 self.assertEqual(len(self._notify1), 0)
1623 self.assertEqual(len(self._notify1), 0)
1607 a.x = 10
1624 a.x = 10
1608 self.assertTrue(('x', 0, 10) in self._notify1)
1625 self.assertTrue(('x', 0, 10) in self._notify1)
1609 a.y = 10.0
1626 a.y = 10.0
1610 self.assertTrue(('y', 0.0, 10.0) in self._notify1)
1627 self.assertTrue(('y', 0.0, 10.0) in self._notify1)
1611 self.assertRaises(TraitError, setattr, a, 'x', 'bad string')
1628 self.assertRaises(TraitError, setattr, a, 'x', 'bad string')
1612 self.assertRaises(TraitError, setattr, a, 'y', 'bad string')
1629 self.assertRaises(TraitError, setattr, a, 'y', 'bad string')
1613 self._notify1 = []
1630 self._notify1 = []
1614 a.on_trait_change(self.notify1, remove=True)
1631 a.on_trait_change(self.notify1, remove=True)
1615 a.x = 20
1632 a.x = 20
1616 a.y = 20.0
1633 a.y = 20.0
1617 self.assertEqual(len(self._notify1), 0)
1634 self.assertEqual(len(self._notify1), 0)
@@ -1,1841 +1,1874 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Inheritance diagram:
31 Inheritance diagram:
32
32
33 .. inheritance-diagram:: IPython.utils.traitlets
33 .. inheritance-diagram:: IPython.utils.traitlets
34 :parts: 3
34 :parts: 3
35 """
35 """
36
36
37 # Copyright (c) IPython Development Team.
37 # Copyright (c) IPython Development Team.
38 # Distributed under the terms of the Modified BSD License.
38 # Distributed under the terms of the Modified BSD License.
39 #
39 #
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 # also under the terms of the Modified BSD License.
41 # also under the terms of the Modified BSD License.
42
42
43 import contextlib
43 import contextlib
44 import inspect
44 import inspect
45 import re
45 import re
46 import sys
46 import sys
47 import types
47 import types
48 from types import FunctionType
48 from types import FunctionType
49 try:
49 try:
50 from types import ClassType, InstanceType
50 from types import ClassType, InstanceType
51 ClassTypes = (ClassType, type)
51 ClassTypes = (ClassType, type)
52 except:
52 except:
53 ClassTypes = (type,)
53 ClassTypes = (type,)
54 from warnings import warn
54 from warnings import warn
55
55
56 from .getargspec import getargspec
56 from .getargspec import getargspec
57 from .importstring import import_item
57 from .importstring import import_item
58 from IPython.utils import py3compat
58 from IPython.utils import py3compat
59 from IPython.utils import eventful
59 from IPython.utils import eventful
60 from IPython.utils.py3compat import iteritems, string_types
60 from IPython.utils.py3compat import iteritems, string_types
61 from IPython.testing.skipdoctest import skip_doctest
61 from IPython.testing.skipdoctest import skip_doctest
62
62
63 SequenceTypes = (list, tuple, set, frozenset)
63 SequenceTypes = (list, tuple, set, frozenset)
64
64
65 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
66 # Basic classes
66 # Basic classes
67 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
68
68
69
69
70 class NoDefaultSpecified ( object ): pass
70 class NoDefaultSpecified ( object ): pass
71 NoDefaultSpecified = NoDefaultSpecified()
71 NoDefaultSpecified = NoDefaultSpecified()
72
72
73
73
74 class Undefined ( object ): pass
74 class Undefined ( object ): pass
75 Undefined = Undefined()
75 Undefined = Undefined()
76
76
77 class TraitError(Exception):
77 class TraitError(Exception):
78 pass
78 pass
79
79
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81 # Utilities
81 # Utilities
82 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
83
83
84
84
85 def class_of ( object ):
85 def class_of ( object ):
86 """ Returns a string containing the class name of an object with the
86 """ Returns a string containing the class name of an object with the
87 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
87 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
88 'a PlotValue').
88 'a PlotValue').
89 """
89 """
90 if isinstance( object, py3compat.string_types ):
90 if isinstance( object, py3compat.string_types ):
91 return add_article( object )
91 return add_article( object )
92
92
93 return add_article( object.__class__.__name__ )
93 return add_article( object.__class__.__name__ )
94
94
95
95
96 def add_article ( name ):
96 def add_article ( name ):
97 """ Returns a string containing the correct indefinite article ('a' or 'an')
97 """ Returns a string containing the correct indefinite article ('a' or 'an')
98 prefixed to the specified string.
98 prefixed to the specified string.
99 """
99 """
100 if name[:1].lower() in 'aeiou':
100 if name[:1].lower() in 'aeiou':
101 return 'an ' + name
101 return 'an ' + name
102
102
103 return 'a ' + name
103 return 'a ' + name
104
104
105
105
106 def repr_type(obj):
106 def repr_type(obj):
107 """ Return a string representation of a value and its type for readable
107 """ Return a string representation of a value and its type for readable
108 error messages.
108 error messages.
109 """
109 """
110 the_type = type(obj)
110 the_type = type(obj)
111 if (not py3compat.PY3) and the_type is InstanceType:
111 if (not py3compat.PY3) and the_type is InstanceType:
112 # Old-style class.
112 # Old-style class.
113 the_type = obj.__class__
113 the_type = obj.__class__
114 msg = '%r %r' % (obj, the_type)
114 msg = '%r %r' % (obj, the_type)
115 return msg
115 return msg
116
116
117
117
118 def is_trait(t):
118 def is_trait(t):
119 """ Returns whether the given value is an instance or subclass of TraitType.
119 """ Returns whether the given value is an instance or subclass of TraitType.
120 """
120 """
121 return (isinstance(t, TraitType) or
121 return (isinstance(t, TraitType) or
122 (isinstance(t, type) and issubclass(t, TraitType)))
122 (isinstance(t, type) and issubclass(t, TraitType)))
123
123
124
124
125 def parse_notifier_name(name):
125 def parse_notifier_name(name):
126 """Convert the name argument to a list of names.
126 """Convert the name argument to a list of names.
127
127
128 Examples
128 Examples
129 --------
129 --------
130
130
131 >>> parse_notifier_name('a')
131 >>> parse_notifier_name('a')
132 ['a']
132 ['a']
133 >>> parse_notifier_name(['a','b'])
133 >>> parse_notifier_name(['a','b'])
134 ['a', 'b']
134 ['a', 'b']
135 >>> parse_notifier_name(None)
135 >>> parse_notifier_name(None)
136 ['anytrait']
136 ['anytrait']
137 """
137 """
138 if isinstance(name, string_types):
138 if isinstance(name, string_types):
139 return [name]
139 return [name]
140 elif name is None:
140 elif name is None:
141 return ['anytrait']
141 return ['anytrait']
142 elif isinstance(name, (list, tuple)):
142 elif isinstance(name, (list, tuple)):
143 for n in name:
143 for n in name:
144 assert isinstance(n, string_types), "names must be strings"
144 assert isinstance(n, string_types), "names must be strings"
145 return name
145 return name
146
146
147
147
148 class _SimpleTest:
148 class _SimpleTest:
149 def __init__ ( self, value ): self.value = value
149 def __init__ ( self, value ): self.value = value
150 def __call__ ( self, test ):
150 def __call__ ( self, test ):
151 return test == self.value
151 return test == self.value
152 def __repr__(self):
152 def __repr__(self):
153 return "<SimpleTest(%r)" % self.value
153 return "<SimpleTest(%r)" % self.value
154 def __str__(self):
154 def __str__(self):
155 return self.__repr__()
155 return self.__repr__()
156
156
157
157
158 def getmembers(object, predicate=None):
158 def getmembers(object, predicate=None):
159 """A safe version of inspect.getmembers that handles missing attributes.
159 """A safe version of inspect.getmembers that handles missing attributes.
160
160
161 This is useful when there are descriptor based attributes that for
161 This is useful when there are descriptor based attributes that for
162 some reason raise AttributeError even though they exist. This happens
162 some reason raise AttributeError even though they exist. This happens
163 in zope.inteface with the __provides__ attribute.
163 in zope.inteface with the __provides__ attribute.
164 """
164 """
165 results = []
165 results = []
166 for key in dir(object):
166 for key in dir(object):
167 try:
167 try:
168 value = getattr(object, key)
168 value = getattr(object, key)
169 except AttributeError:
169 except AttributeError:
170 pass
170 pass
171 else:
171 else:
172 if not predicate or predicate(value):
172 if not predicate or predicate(value):
173 results.append((key, value))
173 results.append((key, value))
174 results.sort()
174 results.sort()
175 return results
175 return results
176
176
177 def _validate_link(*tuples):
177 def _validate_link(*tuples):
178 """Validate arguments for traitlet link functions"""
178 """Validate arguments for traitlet link functions"""
179 for t in tuples:
179 for t in tuples:
180 if not len(t) == 2:
180 if not len(t) == 2:
181 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
181 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
182 obj, trait_name = t
182 obj, trait_name = t
183 if not isinstance(obj, HasTraits):
183 if not isinstance(obj, HasTraits):
184 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
184 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
185 if not trait_name in obj.traits():
185 if not trait_name in obj.traits():
186 raise TypeError("%r has no trait %r" % (obj, trait_name))
186 raise TypeError("%r has no trait %r" % (obj, trait_name))
187
187
188 @skip_doctest
188 @skip_doctest
189 class link(object):
189 class link(object):
190 """Link traits from different objects together so they remain in sync.
190 """Link traits from different objects together so they remain in sync.
191
191
192 Parameters
192 Parameters
193 ----------
193 ----------
194 *args : pairs of objects/attributes
194 *args : pairs of objects/attributes
195
195
196 Examples
196 Examples
197 --------
197 --------
198
198
199 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
199 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
200 >>> obj1.value = 5 # updates other objects as well
200 >>> obj1.value = 5 # updates other objects as well
201 """
201 """
202 updating = False
202 updating = False
203 def __init__(self, *args):
203 def __init__(self, *args):
204 if len(args) < 2:
204 if len(args) < 2:
205 raise TypeError('At least two traitlets must be provided.')
205 raise TypeError('At least two traitlets must be provided.')
206 _validate_link(*args)
206 _validate_link(*args)
207
207
208 self.objects = {}
208 self.objects = {}
209
209
210 initial = getattr(args[0][0], args[0][1])
210 initial = getattr(args[0][0], args[0][1])
211 for obj, attr in args:
211 for obj, attr in args:
212 setattr(obj, attr, initial)
212 setattr(obj, attr, initial)
213
213
214 callback = self._make_closure(obj, attr)
214 callback = self._make_closure(obj, attr)
215 obj.on_trait_change(callback, attr)
215 obj.on_trait_change(callback, attr)
216 self.objects[(obj, attr)] = callback
216 self.objects[(obj, attr)] = callback
217
217
218 @contextlib.contextmanager
218 @contextlib.contextmanager
219 def _busy_updating(self):
219 def _busy_updating(self):
220 self.updating = True
220 self.updating = True
221 try:
221 try:
222 yield
222 yield
223 finally:
223 finally:
224 self.updating = False
224 self.updating = False
225
225
226 def _make_closure(self, sending_obj, sending_attr):
226 def _make_closure(self, sending_obj, sending_attr):
227 def update(name, old, new):
227 def update(name, old, new):
228 self._update(sending_obj, sending_attr, new)
228 self._update(sending_obj, sending_attr, new)
229 return update
229 return update
230
230
231 def _update(self, sending_obj, sending_attr, new):
231 def _update(self, sending_obj, sending_attr, new):
232 if self.updating:
232 if self.updating:
233 return
233 return
234 with self._busy_updating():
234 with self._busy_updating():
235 for obj, attr in self.objects.keys():
235 for obj, attr in self.objects.keys():
236 setattr(obj, attr, new)
236 setattr(obj, attr, new)
237
237
238 def unlink(self):
238 def unlink(self):
239 for key, callback in self.objects.items():
239 for key, callback in self.objects.items():
240 (obj, attr) = key
240 (obj, attr) = key
241 obj.on_trait_change(callback, attr, remove=True)
241 obj.on_trait_change(callback, attr, remove=True)
242
242
243 @skip_doctest
243 @skip_doctest
244 class directional_link(object):
244 class directional_link(object):
245 """Link the trait of a source object with traits of target objects.
245 """Link the trait of a source object with traits of target objects.
246
246
247 Parameters
247 Parameters
248 ----------
248 ----------
249 source : pair of object, name
249 source : pair of object, name
250 targets : pairs of objects/attributes
250 targets : pairs of objects/attributes
251
251
252 Examples
252 Examples
253 --------
253 --------
254
254
255 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
255 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
256 >>> src.value = 5 # updates target objects
256 >>> src.value = 5 # updates target objects
257 >>> tgt1.value = 6 # does not update other objects
257 >>> tgt1.value = 6 # does not update other objects
258 """
258 """
259 updating = False
259 updating = False
260
260
261 def __init__(self, source, *targets):
261 def __init__(self, source, *targets):
262 if len(targets) < 1:
262 if len(targets) < 1:
263 raise TypeError('At least two traitlets must be provided.')
263 raise TypeError('At least two traitlets must be provided.')
264 _validate_link(source, *targets)
264 _validate_link(source, *targets)
265 self.source = source
265 self.source = source
266 self.targets = targets
266 self.targets = targets
267
267
268 # Update current value
268 # Update current value
269 src_attr_value = getattr(source[0], source[1])
269 src_attr_value = getattr(source[0], source[1])
270 for obj, attr in targets:
270 for obj, attr in targets:
271 setattr(obj, attr, src_attr_value)
271 setattr(obj, attr, src_attr_value)
272
272
273 # Wire
273 # Wire
274 self.source[0].on_trait_change(self._update, self.source[1])
274 self.source[0].on_trait_change(self._update, self.source[1])
275
275
276 @contextlib.contextmanager
276 @contextlib.contextmanager
277 def _busy_updating(self):
277 def _busy_updating(self):
278 self.updating = True
278 self.updating = True
279 try:
279 try:
280 yield
280 yield
281 finally:
281 finally:
282 self.updating = False
282 self.updating = False
283
283
284 def _update(self, name, old, new):
284 def _update(self, name, old, new):
285 if self.updating:
285 if self.updating:
286 return
286 return
287 with self._busy_updating():
287 with self._busy_updating():
288 for obj, attr in self.targets:
288 for obj, attr in self.targets:
289 setattr(obj, attr, new)
289 setattr(obj, attr, new)
290
290
291 def unlink(self):
291 def unlink(self):
292 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
292 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
293 self.source = None
293 self.source = None
294 self.targets = []
294 self.targets = []
295
295
296 dlink = directional_link
296 dlink = directional_link
297
297
298
298
299 #-----------------------------------------------------------------------------
299 #-----------------------------------------------------------------------------
300 # Base TraitType for all traits
300 # Base TraitType for all traits
301 #-----------------------------------------------------------------------------
301 #-----------------------------------------------------------------------------
302
302
303
303
304 class TraitType(object):
304 class TraitType(object):
305 """A base class for all trait descriptors.
305 """A base class for all trait descriptors.
306
306
307 Notes
307 Notes
308 -----
308 -----
309 Our implementation of traits is based on Python's descriptor
309 Our implementation of traits is based on Python's descriptor
310 prototol. This class is the base class for all such descriptors. The
310 prototol. This class is the base class for all such descriptors. The
311 only magic we use is a custom metaclass for the main :class:`HasTraits`
311 only magic we use is a custom metaclass for the main :class:`HasTraits`
312 class that does the following:
312 class that does the following:
313
313
314 1. Sets the :attr:`name` attribute of every :class:`TraitType`
314 1. Sets the :attr:`name` attribute of every :class:`TraitType`
315 instance in the class dict to the name of the attribute.
315 instance in the class dict to the name of the attribute.
316 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
316 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
317 instance in the class dict to the *class* that declared the trait.
317 instance in the class dict to the *class* that declared the trait.
318 This is used by the :class:`This` trait to allow subclasses to
318 This is used by the :class:`This` trait to allow subclasses to
319 accept superclasses for :class:`This` values.
319 accept superclasses for :class:`This` values.
320 """
320 """
321
321
322
323 metadata = {}
322 metadata = {}
324 default_value = Undefined
323 default_value = Undefined
325 allow_none = False
324 allow_none = False
326 info_text = 'any value'
325 info_text = 'any value'
327
326
328 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
327 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
329 """Create a TraitType.
328 """Create a TraitType.
330 """
329 """
331 if default_value is not NoDefaultSpecified:
330 if default_value is not NoDefaultSpecified:
332 self.default_value = default_value
331 self.default_value = default_value
333 if allow_none is not None:
332 if allow_none is not None:
334 self.allow_none = allow_none
333 self.allow_none = allow_none
335
334
336 if 'default' in metadata:
335 if 'default' in metadata:
337 # Warn the user that they probably meant default_value.
336 # Warn the user that they probably meant default_value.
338 warn(
337 warn(
339 "Parameter 'default' passed to TraitType. "
338 "Parameter 'default' passed to TraitType. "
340 "Did you mean 'default_value'?"
339 "Did you mean 'default_value'?"
341 )
340 )
342
341
343 if len(metadata) > 0:
342 if len(metadata) > 0:
344 if len(self.metadata) > 0:
343 if len(self.metadata) > 0:
345 self._metadata = self.metadata.copy()
344 self._metadata = self.metadata.copy()
346 self._metadata.update(metadata)
345 self._metadata.update(metadata)
347 else:
346 else:
348 self._metadata = metadata
347 self._metadata = metadata
349 else:
348 else:
350 self._metadata = self.metadata
349 self._metadata = self.metadata
351
350
352 self.init()
351 self.init()
353
352
354 def init(self):
353 def init(self):
355 pass
354 pass
356
355
357 def get_default_value(self):
356 def get_default_value(self):
358 """Create a new instance of the default value."""
357 """Create a new instance of the default value."""
359 return self.default_value
358 return self.default_value
360
359
361 def instance_init(self):
360 def instance_init(self):
362 """Part of the initialization which may depends on the underlying
361 """Part of the initialization which may depends on the underlying
363 HasTraits instance.
362 HasTraits instance.
364
363
365 It is typically overloaded for specific trait types.
364 It is typically overloaded for specific trait types.
366
365
367 This method is called by :meth:`HasTraits.__new__` and in the
366 This method is called by :meth:`HasTraits.__new__` and in the
368 :meth:`TraitType.instance_init` method of trait types holding
367 :meth:`TraitType.instance_init` method of trait types holding
369 other trait types.
368 other trait types.
370 """
369 """
371 pass
370 pass
372
371
373 def init_default_value(self, obj):
372 def init_default_value(self, obj):
374 """Instantiate the default value for the trait type.
373 """Instantiate the default value for the trait type.
375
374
376 This method is called by :meth:`TraitType.set_default_value` in the
375 This method is called by :meth:`TraitType.set_default_value` in the
377 case a default value is provided at construction time or later when
376 case a default value is provided at construction time or later when
378 accessing the trait value for the first time in
377 accessing the trait value for the first time in
379 :meth:`HasTraits.__get__`.
378 :meth:`HasTraits.__get__`.
380 """
379 """
381 value = self.get_default_value()
380 value = self.get_default_value()
382 value = self._validate(obj, value)
381 value = self._validate(obj, value)
383 obj._trait_values[self.name] = value
382 obj._trait_values[self.name] = value
384 return value
383 return value
385
384
386 def set_default_value(self, obj):
385 def set_default_value(self, obj):
387 """Set the default value on a per instance basis.
386 """Set the default value on a per instance basis.
388
387
389 This method is called by :meth:`HasTraits.__new__` to instantiate and
388 This method is called by :meth:`HasTraits.__new__` to instantiate and
390 validate the default value. The creation and validation of
389 validate the default value. The creation and validation of
391 default values must be delayed until the parent :class:`HasTraits`
390 default values must be delayed until the parent :class:`HasTraits`
392 class has been instantiated.
391 class has been instantiated.
393 Parameters
392 Parameters
394 ----------
393 ----------
395 obj : :class:`HasTraits` instance
394 obj : :class:`HasTraits` instance
396 The parent :class:`HasTraits` instance that has just been
395 The parent :class:`HasTraits` instance that has just been
397 created.
396 created.
398 """
397 """
399 # Check for a deferred initializer defined in the same class as the
398 # Check for a deferred initializer defined in the same class as the
400 # trait declaration or above.
399 # trait declaration or above.
401 mro = type(obj).mro()
400 mro = type(obj).mro()
402 meth_name = '_%s_default' % self.name
401 meth_name = '_%s_default' % self.name
403 for cls in mro[:mro.index(self.this_class)+1]:
402 for cls in mro[:mro.index(self.this_class)+1]:
404 if meth_name in cls.__dict__:
403 if meth_name in cls.__dict__:
405 break
404 break
406 else:
405 else:
407 # We didn't find one. Do static initialization.
406 # We didn't find one. Do static initialization.
408 self.init_default_value(obj)
407 self.init_default_value(obj)
409 return
408 return
410 # Complete the dynamic initialization.
409 # Complete the dynamic initialization.
411 obj._trait_dyn_inits[self.name] = meth_name
410 obj._trait_dyn_inits[self.name] = meth_name
412
411
413 def __get__(self, obj, cls=None):
412 def __get__(self, obj, cls=None):
414 """Get the value of the trait by self.name for the instance.
413 """Get the value of the trait by self.name for the instance.
415
414
416 Default values are instantiated when :meth:`HasTraits.__new__`
415 Default values are instantiated when :meth:`HasTraits.__new__`
417 is called. Thus by the time this method gets called either the
416 is called. Thus by the time this method gets called either the
418 default value or a user defined value (they called :meth:`__set__`)
417 default value or a user defined value (they called :meth:`__set__`)
419 is in the :class:`HasTraits` instance.
418 is in the :class:`HasTraits` instance.
420 """
419 """
421 if obj is None:
420 if obj is None:
422 return self
421 return self
423 else:
422 else:
424 try:
423 try:
425 value = obj._trait_values[self.name]
424 value = obj._trait_values[self.name]
426 except KeyError:
425 except KeyError:
427 # Check for a dynamic initializer.
426 # Check for a dynamic initializer.
428 if self.name in obj._trait_dyn_inits:
427 if self.name in obj._trait_dyn_inits:
429 method = getattr(obj, obj._trait_dyn_inits[self.name])
428 method = getattr(obj, obj._trait_dyn_inits[self.name])
430 value = method()
429 value = method()
431 # FIXME: Do we really validate here?
430 # FIXME: Do we really validate here?
432 value = self._validate(obj, value)
431 value = self._validate(obj, value)
433 obj._trait_values[self.name] = value
432 obj._trait_values[self.name] = value
434 return value
433 return value
435 else:
434 else:
436 return self.init_default_value(obj)
435 return self.init_default_value(obj)
437 except Exception:
436 except Exception:
438 # HasTraits should call set_default_value to populate
437 # HasTraits should call set_default_value to populate
439 # this. So this should never be reached.
438 # this. So this should never be reached.
440 raise TraitError('Unexpected error in TraitType: '
439 raise TraitError('Unexpected error in TraitType: '
441 'default value not set properly')
440 'default value not set properly')
442 else:
441 else:
443 return value
442 return value
444
443
445 def __set__(self, obj, value):
444 def __set__(self, obj, value):
446 new_value = self._validate(obj, value)
445 new_value = self._validate(obj, value)
447 try:
446 try:
448 old_value = obj._trait_values[self.name]
447 old_value = obj._trait_values[self.name]
449 except KeyError:
448 except KeyError:
450 old_value = None
449 old_value = Undefined
451
450
452 obj._trait_values[self.name] = new_value
451 obj._trait_values[self.name] = new_value
453 try:
452 try:
454 silent = bool(old_value == new_value)
453 silent = bool(old_value == new_value)
455 except:
454 except:
456 # if there is an error in comparing, default to notify
455 # if there is an error in comparing, default to notify
457 silent = False
456 silent = False
458 if silent is not True:
457 if silent is not True:
459 # we explicitly compare silent to True just in case the equality
458 # we explicitly compare silent to True just in case the equality
460 # comparison above returns something other than True/False
459 # comparison above returns something other than True/False
461 obj._notify_trait(self.name, old_value, new_value)
460 obj._notify_trait(self.name, old_value, new_value)
462
461
463 def _validate(self, obj, value):
462 def _validate(self, obj, value):
464 if value is None and self.allow_none:
463 if value is None and self.allow_none:
465 return value
464 return value
466 if hasattr(self, 'validate'):
465 if hasattr(self, 'validate'):
467 value = self.validate(obj, value)
466 value = self.validate(obj, value)
468 try:
467 if obj._cross_validation_lock is False:
469 obj_validate = getattr(obj, '_%s_validate' % self.name)
468 value = self._cross_validate(obj, value)
470 except (AttributeError, RuntimeError):
469 return value
471 # Qt mixins raise RuntimeError on missing attrs accessed before __init__
470
472 pass
471 def _cross_validate(self, obj, value):
473 else:
472 if hasattr(obj, '_%s_validate' % self.name):
474 value = obj_validate(value, self)
473 cross_validate = getattr(obj, '_%s_validate' % self.name)
474 value = cross_validate(value, self)
475 return value
475 return value
476
476
477 def __or__(self, other):
477 def __or__(self, other):
478 if isinstance(other, Union):
478 if isinstance(other, Union):
479 return Union([self] + other.trait_types)
479 return Union([self] + other.trait_types)
480 else:
480 else:
481 return Union([self, other])
481 return Union([self, other])
482
482
483 def info(self):
483 def info(self):
484 return self.info_text
484 return self.info_text
485
485
486 def error(self, obj, value):
486 def error(self, obj, value):
487 if obj is not None:
487 if obj is not None:
488 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
488 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
489 % (self.name, class_of(obj),
489 % (self.name, class_of(obj),
490 self.info(), repr_type(value))
490 self.info(), repr_type(value))
491 else:
491 else:
492 e = "The '%s' trait must be %s, but a value of %r was specified." \
492 e = "The '%s' trait must be %s, but a value of %r was specified." \
493 % (self.name, self.info(), repr_type(value))
493 % (self.name, self.info(), repr_type(value))
494 raise TraitError(e)
494 raise TraitError(e)
495
495
496 def get_metadata(self, key, default=None):
496 def get_metadata(self, key, default=None):
497 return getattr(self, '_metadata', {}).get(key, default)
497 return getattr(self, '_metadata', {}).get(key, default)
498
498
499 def set_metadata(self, key, value):
499 def set_metadata(self, key, value):
500 getattr(self, '_metadata', {})[key] = value
500 getattr(self, '_metadata', {})[key] = value
501
501
502
502
503 #-----------------------------------------------------------------------------
503 #-----------------------------------------------------------------------------
504 # The HasTraits implementation
504 # The HasTraits implementation
505 #-----------------------------------------------------------------------------
505 #-----------------------------------------------------------------------------
506
506
507
507
508 class MetaHasTraits(type):
508 class MetaHasTraits(type):
509 """A metaclass for HasTraits.
509 """A metaclass for HasTraits.
510
510
511 This metaclass makes sure that any TraitType class attributes are
511 This metaclass makes sure that any TraitType class attributes are
512 instantiated and sets their name attribute.
512 instantiated and sets their name attribute.
513 """
513 """
514
514
515 def __new__(mcls, name, bases, classdict):
515 def __new__(mcls, name, bases, classdict):
516 """Create the HasTraits class.
516 """Create the HasTraits class.
517
517
518 This instantiates all TraitTypes in the class dict and sets their
518 This instantiates all TraitTypes in the class dict and sets their
519 :attr:`name` attribute.
519 :attr:`name` attribute.
520 """
520 """
521 # print "MetaHasTraitlets (mcls, name): ", mcls, name
521 # print "MetaHasTraitlets (mcls, name): ", mcls, name
522 # print "MetaHasTraitlets (bases): ", bases
522 # print "MetaHasTraitlets (bases): ", bases
523 # print "MetaHasTraitlets (classdict): ", classdict
523 # print "MetaHasTraitlets (classdict): ", classdict
524 for k,v in iteritems(classdict):
524 for k,v in iteritems(classdict):
525 if isinstance(v, TraitType):
525 if isinstance(v, TraitType):
526 v.name = k
526 v.name = k
527 elif inspect.isclass(v):
527 elif inspect.isclass(v):
528 if issubclass(v, TraitType):
528 if issubclass(v, TraitType):
529 vinst = v()
529 vinst = v()
530 vinst.name = k
530 vinst.name = k
531 classdict[k] = vinst
531 classdict[k] = vinst
532 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
532 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
533
533
534 def __init__(cls, name, bases, classdict):
534 def __init__(cls, name, bases, classdict):
535 """Finish initializing the HasTraits class.
535 """Finish initializing the HasTraits class.
536
536
537 This sets the :attr:`this_class` attribute of each TraitType in the
537 This sets the :attr:`this_class` attribute of each TraitType in the
538 class dict to the newly created class ``cls``.
538 class dict to the newly created class ``cls``.
539 """
539 """
540 for k, v in iteritems(classdict):
540 for k, v in iteritems(classdict):
541 if isinstance(v, TraitType):
541 if isinstance(v, TraitType):
542 v.this_class = cls
542 v.this_class = cls
543 super(MetaHasTraits, cls).__init__(name, bases, classdict)
543 super(MetaHasTraits, cls).__init__(name, bases, classdict)
544
544
545
545 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
546 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
546
547
547 def __new__(cls, *args, **kw):
548 def __new__(cls, *args, **kw):
548 # This is needed because object.__new__ only accepts
549 # This is needed because object.__new__ only accepts
549 # the cls argument.
550 # the cls argument.
550 new_meth = super(HasTraits, cls).__new__
551 new_meth = super(HasTraits, cls).__new__
551 if new_meth is object.__new__:
552 if new_meth is object.__new__:
552 inst = new_meth(cls)
553 inst = new_meth(cls)
553 else:
554 else:
554 inst = new_meth(cls, **kw)
555 inst = new_meth(cls, **kw)
555 inst._trait_values = {}
556 inst._trait_values = {}
556 inst._trait_notifiers = {}
557 inst._trait_notifiers = {}
557 inst._trait_dyn_inits = {}
558 inst._trait_dyn_inits = {}
559 inst._cross_validation_lock = True
558 # Here we tell all the TraitType instances to set their default
560 # Here we tell all the TraitType instances to set their default
559 # values on the instance.
561 # values on the instance.
560 for key in dir(cls):
562 for key in dir(cls):
561 # Some descriptors raise AttributeError like zope.interface's
563 # Some descriptors raise AttributeError like zope.interface's
562 # __provides__ attributes even though they exist. This causes
564 # __provides__ attributes even though they exist. This causes
563 # AttributeErrors even though they are listed in dir(cls).
565 # AttributeErrors even though they are listed in dir(cls).
564 try:
566 try:
565 value = getattr(cls, key)
567 value = getattr(cls, key)
566 except AttributeError:
568 except AttributeError:
567 pass
569 pass
568 else:
570 else:
569 if isinstance(value, TraitType):
571 if isinstance(value, TraitType):
570 value.instance_init()
572 value.instance_init()
571 if key not in kw:
573 if key not in kw:
572 value.set_default_value(inst)
574 value.set_default_value(inst)
573
575 inst._cross_validation_lock = False
574 return inst
576 return inst
575
577
576 def __init__(self, *args, **kw):
578 def __init__(self, *args, **kw):
577 # Allow trait values to be set using keyword arguments.
579 # Allow trait values to be set using keyword arguments.
578 # We need to use setattr for this to trigger validation and
580 # We need to use setattr for this to trigger validation and
579 # notifications.
581 # notifications.
580
581 with self.hold_trait_notifications():
582 with self.hold_trait_notifications():
582 for key, value in iteritems(kw):
583 for key, value in iteritems(kw):
583 setattr(self, key, value)
584 setattr(self, key, value)
584
585
585 @contextlib.contextmanager
586 @contextlib.contextmanager
586 def hold_trait_notifications(self):
587 def hold_trait_notifications(self):
587 """Context manager for bundling trait change notifications
588 """Context manager for bundling trait change notifications and cross
588
589 validation.
589 Use this when doing multiple trait assignments (init, config),
590
590 to avoid race conditions in trait notifiers requesting other trait values.
591 Use this when doing multiple trait assignments (init, config), to avoid
592 race conditions in trait notifiers requesting other trait values.
591 All trait notifications will fire after all values have been assigned.
593 All trait notifications will fire after all values have been assigned.
592 """
594 """
593 _notify_trait = self._notify_trait
595 if self._cross_validation_lock is True:
594 notifications = []
595 self._notify_trait = lambda *a: notifications.append(a)
596
597 try:
598 yield
596 yield
599 finally:
597 return
600 self._notify_trait = _notify_trait
598 else:
601 if isinstance(_notify_trait, types.MethodType):
599 self._cross_validation_lock = True
602 # FIXME: remove when support is bumped to 3.4.
600 cache = {}
603 # when original method is restored,
601 notifications = {}
604 # remove the redundant value from __dict__
602 _notify_trait = self._notify_trait
605 # (only used to preserve pickleability on Python < 3.4)
603
606 self.__dict__.pop('_notify_trait', None)
604 def cache_values(*a):
607 # trigger delayed notifications
605 cache[a[0]] = a
608 for args in notifications:
606
609 self._notify_trait(*args)
607 def hold_notifications(*a):
608 notifications[a[0]] = a
609
610 self._notify_trait = cache_values
611
612 try:
613 yield
614 finally:
615 try:
616 self._notify_trait = hold_notifications
617 for name in cache:
618 if hasattr(self, '_%s_validate' % name):
619 cross_validate = getattr(self, '_%s_validate' % name)
620 setattr(self, name, cross_validate(getattr(self, name), self))
621 except TraitError as e:
622 self._notify_trait = lambda *x: None
623 for name in cache:
624 if cache[name][1] is not Undefined:
625 setattr(self, name, cache[name][1])
626 else:
627 delattr(self, name)
628 cache = {}
629 notifications = {}
630 raise e
631 finally:
632 self._notify_trait = _notify_trait
633 self._cross_validation_lock = False
634 if isinstance(_notify_trait, types.MethodType):
635 # FIXME: remove when support is bumped to 3.4.
636 # when original method is restored,
637 # remove the redundant value from __dict__
638 # (only used to preserve pickleability on Python < 3.4)
639 self.__dict__.pop('_notify_trait', None)
640 # trigger delayed notifications
641 for v in dict(cache, **notifications).values():
642 self._notify_trait(*v)
610
643
611 def _notify_trait(self, name, old_value, new_value):
644 def _notify_trait(self, name, old_value, new_value):
612
645
613 # First dynamic ones
646 # First dynamic ones
614 callables = []
647 callables = []
615 callables.extend(self._trait_notifiers.get(name,[]))
648 callables.extend(self._trait_notifiers.get(name,[]))
616 callables.extend(self._trait_notifiers.get('anytrait',[]))
649 callables.extend(self._trait_notifiers.get('anytrait',[]))
617
650
618 # Now static ones
651 # Now static ones
619 try:
652 try:
620 cb = getattr(self, '_%s_changed' % name)
653 cb = getattr(self, '_%s_changed' % name)
621 except:
654 except:
622 pass
655 pass
623 else:
656 else:
624 callables.append(cb)
657 callables.append(cb)
625
658
626 # Call them all now
659 # Call them all now
627 for c in callables:
660 for c in callables:
628 # Traits catches and logs errors here. I allow them to raise
661 # Traits catches and logs errors here. I allow them to raise
629 if callable(c):
662 if callable(c):
630 argspec = getargspec(c)
663 argspec = getargspec(c)
631
664
632 nargs = len(argspec[0])
665 nargs = len(argspec[0])
633 # Bound methods have an additional 'self' argument
666 # Bound methods have an additional 'self' argument
634 # I don't know how to treat unbound methods, but they
667 # I don't know how to treat unbound methods, but they
635 # can't really be used for callbacks.
668 # can't really be used for callbacks.
636 if isinstance(c, types.MethodType):
669 if isinstance(c, types.MethodType):
637 offset = -1
670 offset = -1
638 else:
671 else:
639 offset = 0
672 offset = 0
640 if nargs + offset == 0:
673 if nargs + offset == 0:
641 c()
674 c()
642 elif nargs + offset == 1:
675 elif nargs + offset == 1:
643 c(name)
676 c(name)
644 elif nargs + offset == 2:
677 elif nargs + offset == 2:
645 c(name, new_value)
678 c(name, new_value)
646 elif nargs + offset == 3:
679 elif nargs + offset == 3:
647 c(name, old_value, new_value)
680 c(name, old_value, new_value)
648 else:
681 else:
649 raise TraitError('a trait changed callback '
682 raise TraitError('a trait changed callback '
650 'must have 0-3 arguments.')
683 'must have 0-3 arguments.')
651 else:
684 else:
652 raise TraitError('a trait changed callback '
685 raise TraitError('a trait changed callback '
653 'must be callable.')
686 'must be callable.')
654
687
655
688
656 def _add_notifiers(self, handler, name):
689 def _add_notifiers(self, handler, name):
657 if name not in self._trait_notifiers:
690 if name not in self._trait_notifiers:
658 nlist = []
691 nlist = []
659 self._trait_notifiers[name] = nlist
692 self._trait_notifiers[name] = nlist
660 else:
693 else:
661 nlist = self._trait_notifiers[name]
694 nlist = self._trait_notifiers[name]
662 if handler not in nlist:
695 if handler not in nlist:
663 nlist.append(handler)
696 nlist.append(handler)
664
697
665 def _remove_notifiers(self, handler, name):
698 def _remove_notifiers(self, handler, name):
666 if name in self._trait_notifiers:
699 if name in self._trait_notifiers:
667 nlist = self._trait_notifiers[name]
700 nlist = self._trait_notifiers[name]
668 try:
701 try:
669 index = nlist.index(handler)
702 index = nlist.index(handler)
670 except ValueError:
703 except ValueError:
671 pass
704 pass
672 else:
705 else:
673 del nlist[index]
706 del nlist[index]
674
707
675 def on_trait_change(self, handler, name=None, remove=False):
708 def on_trait_change(self, handler, name=None, remove=False):
676 """Setup a handler to be called when a trait changes.
709 """Setup a handler to be called when a trait changes.
677
710
678 This is used to setup dynamic notifications of trait changes.
711 This is used to setup dynamic notifications of trait changes.
679
712
680 Static handlers can be created by creating methods on a HasTraits
713 Static handlers can be created by creating methods on a HasTraits
681 subclass with the naming convention '_[traitname]_changed'. Thus,
714 subclass with the naming convention '_[traitname]_changed'. Thus,
682 to create static handler for the trait 'a', create the method
715 to create static handler for the trait 'a', create the method
683 _a_changed(self, name, old, new) (fewer arguments can be used, see
716 _a_changed(self, name, old, new) (fewer arguments can be used, see
684 below).
717 below).
685
718
686 Parameters
719 Parameters
687 ----------
720 ----------
688 handler : callable
721 handler : callable
689 A callable that is called when a trait changes. Its
722 A callable that is called when a trait changes. Its
690 signature can be handler(), handler(name), handler(name, new)
723 signature can be handler(), handler(name), handler(name, new)
691 or handler(name, old, new).
724 or handler(name, old, new).
692 name : list, str, None
725 name : list, str, None
693 If None, the handler will apply to all traits. If a list
726 If None, the handler will apply to all traits. If a list
694 of str, handler will apply to all names in the list. If a
727 of str, handler will apply to all names in the list. If a
695 str, the handler will apply just to that name.
728 str, the handler will apply just to that name.
696 remove : bool
729 remove : bool
697 If False (the default), then install the handler. If True
730 If False (the default), then install the handler. If True
698 then unintall it.
731 then unintall it.
699 """
732 """
700 if remove:
733 if remove:
701 names = parse_notifier_name(name)
734 names = parse_notifier_name(name)
702 for n in names:
735 for n in names:
703 self._remove_notifiers(handler, n)
736 self._remove_notifiers(handler, n)
704 else:
737 else:
705 names = parse_notifier_name(name)
738 names = parse_notifier_name(name)
706 for n in names:
739 for n in names:
707 self._add_notifiers(handler, n)
740 self._add_notifiers(handler, n)
708
741
709 @classmethod
742 @classmethod
710 def class_trait_names(cls, **metadata):
743 def class_trait_names(cls, **metadata):
711 """Get a list of all the names of this class' traits.
744 """Get a list of all the names of this class' traits.
712
745
713 This method is just like the :meth:`trait_names` method,
746 This method is just like the :meth:`trait_names` method,
714 but is unbound.
747 but is unbound.
715 """
748 """
716 return cls.class_traits(**metadata).keys()
749 return cls.class_traits(**metadata).keys()
717
750
718 @classmethod
751 @classmethod
719 def class_traits(cls, **metadata):
752 def class_traits(cls, **metadata):
720 """Get a `dict` of all the traits of this class. The dictionary
753 """Get a `dict` of all the traits of this class. The dictionary
721 is keyed on the name and the values are the TraitType objects.
754 is keyed on the name and the values are the TraitType objects.
722
755
723 This method is just like the :meth:`traits` method, but is unbound.
756 This method is just like the :meth:`traits` method, but is unbound.
724
757
725 The TraitTypes returned don't know anything about the values
758 The TraitTypes returned don't know anything about the values
726 that the various HasTrait's instances are holding.
759 that the various HasTrait's instances are holding.
727
760
728 The metadata kwargs allow functions to be passed in which
761 The metadata kwargs allow functions to be passed in which
729 filter traits based on metadata values. The functions should
762 filter traits based on metadata values. The functions should
730 take a single value as an argument and return a boolean. If
763 take a single value as an argument and return a boolean. If
731 any function returns False, then the trait is not included in
764 any function returns False, then the trait is not included in
732 the output. This does not allow for any simple way of
765 the output. This does not allow for any simple way of
733 testing that a metadata name exists and has any
766 testing that a metadata name exists and has any
734 value because get_metadata returns None if a metadata key
767 value because get_metadata returns None if a metadata key
735 doesn't exist.
768 doesn't exist.
736 """
769 """
737 traits = dict([memb for memb in getmembers(cls) if
770 traits = dict([memb for memb in getmembers(cls) if
738 isinstance(memb[1], TraitType)])
771 isinstance(memb[1], TraitType)])
739
772
740 if len(metadata) == 0:
773 if len(metadata) == 0:
741 return traits
774 return traits
742
775
743 for meta_name, meta_eval in metadata.items():
776 for meta_name, meta_eval in metadata.items():
744 if type(meta_eval) is not FunctionType:
777 if type(meta_eval) is not FunctionType:
745 metadata[meta_name] = _SimpleTest(meta_eval)
778 metadata[meta_name] = _SimpleTest(meta_eval)
746
779
747 result = {}
780 result = {}
748 for name, trait in traits.items():
781 for name, trait in traits.items():
749 for meta_name, meta_eval in metadata.items():
782 for meta_name, meta_eval in metadata.items():
750 if not meta_eval(trait.get_metadata(meta_name)):
783 if not meta_eval(trait.get_metadata(meta_name)):
751 break
784 break
752 else:
785 else:
753 result[name] = trait
786 result[name] = trait
754
787
755 return result
788 return result
756
789
757 def trait_names(self, **metadata):
790 def trait_names(self, **metadata):
758 """Get a list of all the names of this class' traits."""
791 """Get a list of all the names of this class' traits."""
759 return self.traits(**metadata).keys()
792 return self.traits(**metadata).keys()
760
793
761 def traits(self, **metadata):
794 def traits(self, **metadata):
762 """Get a `dict` of all the traits of this class. The dictionary
795 """Get a `dict` of all the traits of this class. The dictionary
763 is keyed on the name and the values are the TraitType objects.
796 is keyed on the name and the values are the TraitType objects.
764
797
765 The TraitTypes returned don't know anything about the values
798 The TraitTypes returned don't know anything about the values
766 that the various HasTrait's instances are holding.
799 that the various HasTrait's instances are holding.
767
800
768 The metadata kwargs allow functions to be passed in which
801 The metadata kwargs allow functions to be passed in which
769 filter traits based on metadata values. The functions should
802 filter traits based on metadata values. The functions should
770 take a single value as an argument and return a boolean. If
803 take a single value as an argument and return a boolean. If
771 any function returns False, then the trait is not included in
804 any function returns False, then the trait is not included in
772 the output. This does not allow for any simple way of
805 the output. This does not allow for any simple way of
773 testing that a metadata name exists and has any
806 testing that a metadata name exists and has any
774 value because get_metadata returns None if a metadata key
807 value because get_metadata returns None if a metadata key
775 doesn't exist.
808 doesn't exist.
776 """
809 """
777 traits = dict([memb for memb in getmembers(self.__class__) if
810 traits = dict([memb for memb in getmembers(self.__class__) if
778 isinstance(memb[1], TraitType)])
811 isinstance(memb[1], TraitType)])
779
812
780 if len(metadata) == 0:
813 if len(metadata) == 0:
781 return traits
814 return traits
782
815
783 for meta_name, meta_eval in metadata.items():
816 for meta_name, meta_eval in metadata.items():
784 if type(meta_eval) is not FunctionType:
817 if type(meta_eval) is not FunctionType:
785 metadata[meta_name] = _SimpleTest(meta_eval)
818 metadata[meta_name] = _SimpleTest(meta_eval)
786
819
787 result = {}
820 result = {}
788 for name, trait in traits.items():
821 for name, trait in traits.items():
789 for meta_name, meta_eval in metadata.items():
822 for meta_name, meta_eval in metadata.items():
790 if not meta_eval(trait.get_metadata(meta_name)):
823 if not meta_eval(trait.get_metadata(meta_name)):
791 break
824 break
792 else:
825 else:
793 result[name] = trait
826 result[name] = trait
794
827
795 return result
828 return result
796
829
797 def trait_metadata(self, traitname, key, default=None):
830 def trait_metadata(self, traitname, key, default=None):
798 """Get metadata values for trait by key."""
831 """Get metadata values for trait by key."""
799 try:
832 try:
800 trait = getattr(self.__class__, traitname)
833 trait = getattr(self.__class__, traitname)
801 except AttributeError:
834 except AttributeError:
802 raise TraitError("Class %s does not have a trait named %s" %
835 raise TraitError("Class %s does not have a trait named %s" %
803 (self.__class__.__name__, traitname))
836 (self.__class__.__name__, traitname))
804 else:
837 else:
805 return trait.get_metadata(key, default)
838 return trait.get_metadata(key, default)
806
839
807 def add_trait(self, traitname, trait):
840 def add_trait(self, traitname, trait):
808 """Dynamically add a trait attribute to the HasTraits instance."""
841 """Dynamically add a trait attribute to the HasTraits instance."""
809 self.__class__ = type(self.__class__.__name__, (self.__class__,),
842 self.__class__ = type(self.__class__.__name__, (self.__class__,),
810 {traitname: trait})
843 {traitname: trait})
811 trait.set_default_value(self)
844 trait.set_default_value(self)
812
845
813 #-----------------------------------------------------------------------------
846 #-----------------------------------------------------------------------------
814 # Actual TraitTypes implementations/subclasses
847 # Actual TraitTypes implementations/subclasses
815 #-----------------------------------------------------------------------------
848 #-----------------------------------------------------------------------------
816
849
817 #-----------------------------------------------------------------------------
850 #-----------------------------------------------------------------------------
818 # TraitTypes subclasses for handling classes and instances of classes
851 # TraitTypes subclasses for handling classes and instances of classes
819 #-----------------------------------------------------------------------------
852 #-----------------------------------------------------------------------------
820
853
821
854
822 class ClassBasedTraitType(TraitType):
855 class ClassBasedTraitType(TraitType):
823 """
856 """
824 A trait with error reporting and string -> type resolution for Type,
857 A trait with error reporting and string -> type resolution for Type,
825 Instance and This.
858 Instance and This.
826 """
859 """
827
860
828 def _resolve_string(self, string):
861 def _resolve_string(self, string):
829 """
862 """
830 Resolve a string supplied for a type into an actual object.
863 Resolve a string supplied for a type into an actual object.
831 """
864 """
832 return import_item(string)
865 return import_item(string)
833
866
834 def error(self, obj, value):
867 def error(self, obj, value):
835 kind = type(value)
868 kind = type(value)
836 if (not py3compat.PY3) and kind is InstanceType:
869 if (not py3compat.PY3) and kind is InstanceType:
837 msg = 'class %s' % value.__class__.__name__
870 msg = 'class %s' % value.__class__.__name__
838 else:
871 else:
839 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
872 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
840
873
841 if obj is not None:
874 if obj is not None:
842 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
875 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
843 % (self.name, class_of(obj),
876 % (self.name, class_of(obj),
844 self.info(), msg)
877 self.info(), msg)
845 else:
878 else:
846 e = "The '%s' trait must be %s, but a value of %r was specified." \
879 e = "The '%s' trait must be %s, but a value of %r was specified." \
847 % (self.name, self.info(), msg)
880 % (self.name, self.info(), msg)
848
881
849 raise TraitError(e)
882 raise TraitError(e)
850
883
851
884
852 class Type(ClassBasedTraitType):
885 class Type(ClassBasedTraitType):
853 """A trait whose value must be a subclass of a specified class."""
886 """A trait whose value must be a subclass of a specified class."""
854
887
855 def __init__ (self, default_value=None, klass=None, allow_none=False,
888 def __init__ (self, default_value=None, klass=None, allow_none=False,
856 **metadata):
889 **metadata):
857 """Construct a Type trait
890 """Construct a Type trait
858
891
859 A Type trait specifies that its values must be subclasses of
892 A Type trait specifies that its values must be subclasses of
860 a particular class.
893 a particular class.
861
894
862 If only ``default_value`` is given, it is used for the ``klass`` as
895 If only ``default_value`` is given, it is used for the ``klass`` as
863 well.
896 well.
864
897
865 Parameters
898 Parameters
866 ----------
899 ----------
867 default_value : class, str or None
900 default_value : class, str or None
868 The default value must be a subclass of klass. If an str,
901 The default value must be a subclass of klass. If an str,
869 the str must be a fully specified class name, like 'foo.bar.Bah'.
902 the str must be a fully specified class name, like 'foo.bar.Bah'.
870 The string is resolved into real class, when the parent
903 The string is resolved into real class, when the parent
871 :class:`HasTraits` class is instantiated.
904 :class:`HasTraits` class is instantiated.
872 klass : class, str, None
905 klass : class, str, None
873 Values of this trait must be a subclass of klass. The klass
906 Values of this trait must be a subclass of klass. The klass
874 may be specified in a string like: 'foo.bar.MyClass'.
907 may be specified in a string like: 'foo.bar.MyClass'.
875 The string is resolved into real class, when the parent
908 The string is resolved into real class, when the parent
876 :class:`HasTraits` class is instantiated.
909 :class:`HasTraits` class is instantiated.
877 allow_none : bool [ default True ]
910 allow_none : bool [ default True ]
878 Indicates whether None is allowed as an assignable value. Even if
911 Indicates whether None is allowed as an assignable value. Even if
879 ``False``, the default value may be ``None``.
912 ``False``, the default value may be ``None``.
880 """
913 """
881 if default_value is None:
914 if default_value is None:
882 if klass is None:
915 if klass is None:
883 klass = object
916 klass = object
884 elif klass is None:
917 elif klass is None:
885 klass = default_value
918 klass = default_value
886
919
887 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
920 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
888 raise TraitError("A Type trait must specify a class.")
921 raise TraitError("A Type trait must specify a class.")
889
922
890 self.klass = klass
923 self.klass = klass
891
924
892 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
925 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
893
926
894 def validate(self, obj, value):
927 def validate(self, obj, value):
895 """Validates that the value is a valid object instance."""
928 """Validates that the value is a valid object instance."""
896 if isinstance(value, py3compat.string_types):
929 if isinstance(value, py3compat.string_types):
897 try:
930 try:
898 value = self._resolve_string(value)
931 value = self._resolve_string(value)
899 except ImportError:
932 except ImportError:
900 raise TraitError("The '%s' trait of %s instance must be a type, but "
933 raise TraitError("The '%s' trait of %s instance must be a type, but "
901 "%r could not be imported" % (self.name, obj, value))
934 "%r could not be imported" % (self.name, obj, value))
902 try:
935 try:
903 if issubclass(value, self.klass):
936 if issubclass(value, self.klass):
904 return value
937 return value
905 except:
938 except:
906 pass
939 pass
907
940
908 self.error(obj, value)
941 self.error(obj, value)
909
942
910 def info(self):
943 def info(self):
911 """ Returns a description of the trait."""
944 """ Returns a description of the trait."""
912 if isinstance(self.klass, py3compat.string_types):
945 if isinstance(self.klass, py3compat.string_types):
913 klass = self.klass
946 klass = self.klass
914 else:
947 else:
915 klass = self.klass.__name__
948 klass = self.klass.__name__
916 result = 'a subclass of ' + klass
949 result = 'a subclass of ' + klass
917 if self.allow_none:
950 if self.allow_none:
918 return result + ' or None'
951 return result + ' or None'
919 return result
952 return result
920
953
921 def instance_init(self):
954 def instance_init(self):
922 self._resolve_classes()
955 self._resolve_classes()
923 super(Type, self).instance_init()
956 super(Type, self).instance_init()
924
957
925 def _resolve_classes(self):
958 def _resolve_classes(self):
926 if isinstance(self.klass, py3compat.string_types):
959 if isinstance(self.klass, py3compat.string_types):
927 self.klass = self._resolve_string(self.klass)
960 self.klass = self._resolve_string(self.klass)
928 if isinstance(self.default_value, py3compat.string_types):
961 if isinstance(self.default_value, py3compat.string_types):
929 self.default_value = self._resolve_string(self.default_value)
962 self.default_value = self._resolve_string(self.default_value)
930
963
931 def get_default_value(self):
964 def get_default_value(self):
932 return self.default_value
965 return self.default_value
933
966
934
967
935 class DefaultValueGenerator(object):
968 class DefaultValueGenerator(object):
936 """A class for generating new default value instances."""
969 """A class for generating new default value instances."""
937
970
938 def __init__(self, *args, **kw):
971 def __init__(self, *args, **kw):
939 self.args = args
972 self.args = args
940 self.kw = kw
973 self.kw = kw
941
974
942 def generate(self, klass):
975 def generate(self, klass):
943 return klass(*self.args, **self.kw)
976 return klass(*self.args, **self.kw)
944
977
945
978
946 class Instance(ClassBasedTraitType):
979 class Instance(ClassBasedTraitType):
947 """A trait whose value must be an instance of a specified class.
980 """A trait whose value must be an instance of a specified class.
948
981
949 The value can also be an instance of a subclass of the specified class.
982 The value can also be an instance of a subclass of the specified class.
950
983
951 Subclasses can declare default classes by overriding the klass attribute
984 Subclasses can declare default classes by overriding the klass attribute
952 """
985 """
953
986
954 klass = None
987 klass = None
955
988
956 def __init__(self, klass=None, args=None, kw=None, allow_none=False,
989 def __init__(self, klass=None, args=None, kw=None, allow_none=False,
957 **metadata ):
990 **metadata ):
958 """Construct an Instance trait.
991 """Construct an Instance trait.
959
992
960 This trait allows values that are instances of a particular
993 This trait allows values that are instances of a particular
961 class or its subclasses. Our implementation is quite different
994 class or its subclasses. Our implementation is quite different
962 from that of enthough.traits as we don't allow instances to be used
995 from that of enthough.traits as we don't allow instances to be used
963 for klass and we handle the ``args`` and ``kw`` arguments differently.
996 for klass and we handle the ``args`` and ``kw`` arguments differently.
964
997
965 Parameters
998 Parameters
966 ----------
999 ----------
967 klass : class, str
1000 klass : class, str
968 The class that forms the basis for the trait. Class names
1001 The class that forms the basis for the trait. Class names
969 can also be specified as strings, like 'foo.bar.Bar'.
1002 can also be specified as strings, like 'foo.bar.Bar'.
970 args : tuple
1003 args : tuple
971 Positional arguments for generating the default value.
1004 Positional arguments for generating the default value.
972 kw : dict
1005 kw : dict
973 Keyword arguments for generating the default value.
1006 Keyword arguments for generating the default value.
974 allow_none : bool [default True]
1007 allow_none : bool [default True]
975 Indicates whether None is allowed as a value.
1008 Indicates whether None is allowed as a value.
976
1009
977 Notes
1010 Notes
978 -----
1011 -----
979 If both ``args`` and ``kw`` are None, then the default value is None.
1012 If both ``args`` and ``kw`` are None, then the default value is None.
980 If ``args`` is a tuple and ``kw`` is a dict, then the default is
1013 If ``args`` is a tuple and ``kw`` is a dict, then the default is
981 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
1014 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
982 None, the None is replaced by ``()`` or ``{}``, respectively.
1015 None, the None is replaced by ``()`` or ``{}``, respectively.
983 """
1016 """
984 if klass is None:
1017 if klass is None:
985 klass = self.klass
1018 klass = self.klass
986
1019
987 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
1020 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
988 self.klass = klass
1021 self.klass = klass
989 else:
1022 else:
990 raise TraitError('The klass attribute must be a class'
1023 raise TraitError('The klass attribute must be a class'
991 ' not: %r' % klass)
1024 ' not: %r' % klass)
992
1025
993 # self.klass is a class, so handle default_value
1026 # self.klass is a class, so handle default_value
994 if args is None and kw is None:
1027 if args is None and kw is None:
995 default_value = None
1028 default_value = None
996 else:
1029 else:
997 if args is None:
1030 if args is None:
998 # kw is not None
1031 # kw is not None
999 args = ()
1032 args = ()
1000 elif kw is None:
1033 elif kw is None:
1001 # args is not None
1034 # args is not None
1002 kw = {}
1035 kw = {}
1003
1036
1004 if not isinstance(kw, dict):
1037 if not isinstance(kw, dict):
1005 raise TraitError("The 'kw' argument must be a dict or None.")
1038 raise TraitError("The 'kw' argument must be a dict or None.")
1006 if not isinstance(args, tuple):
1039 if not isinstance(args, tuple):
1007 raise TraitError("The 'args' argument must be a tuple or None.")
1040 raise TraitError("The 'args' argument must be a tuple or None.")
1008
1041
1009 default_value = DefaultValueGenerator(*args, **kw)
1042 default_value = DefaultValueGenerator(*args, **kw)
1010
1043
1011 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
1044 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
1012
1045
1013 def validate(self, obj, value):
1046 def validate(self, obj, value):
1014 if isinstance(value, self.klass):
1047 if isinstance(value, self.klass):
1015 return value
1048 return value
1016 else:
1049 else:
1017 self.error(obj, value)
1050 self.error(obj, value)
1018
1051
1019 def info(self):
1052 def info(self):
1020 if isinstance(self.klass, py3compat.string_types):
1053 if isinstance(self.klass, py3compat.string_types):
1021 klass = self.klass
1054 klass = self.klass
1022 else:
1055 else:
1023 klass = self.klass.__name__
1056 klass = self.klass.__name__
1024 result = class_of(klass)
1057 result = class_of(klass)
1025 if self.allow_none:
1058 if self.allow_none:
1026 return result + ' or None'
1059 return result + ' or None'
1027
1060
1028 return result
1061 return result
1029
1062
1030 def instance_init(self):
1063 def instance_init(self):
1031 self._resolve_classes()
1064 self._resolve_classes()
1032 super(Instance, self).instance_init()
1065 super(Instance, self).instance_init()
1033
1066
1034 def _resolve_classes(self):
1067 def _resolve_classes(self):
1035 if isinstance(self.klass, py3compat.string_types):
1068 if isinstance(self.klass, py3compat.string_types):
1036 self.klass = self._resolve_string(self.klass)
1069 self.klass = self._resolve_string(self.klass)
1037
1070
1038 def get_default_value(self):
1071 def get_default_value(self):
1039 """Instantiate a default value instance.
1072 """Instantiate a default value instance.
1040
1073
1041 This is called when the containing HasTraits classes'
1074 This is called when the containing HasTraits classes'
1042 :meth:`__new__` method is called to ensure that a unique instance
1075 :meth:`__new__` method is called to ensure that a unique instance
1043 is created for each HasTraits instance.
1076 is created for each HasTraits instance.
1044 """
1077 """
1045 dv = self.default_value
1078 dv = self.default_value
1046 if isinstance(dv, DefaultValueGenerator):
1079 if isinstance(dv, DefaultValueGenerator):
1047 return dv.generate(self.klass)
1080 return dv.generate(self.klass)
1048 else:
1081 else:
1049 return dv
1082 return dv
1050
1083
1051
1084
1052 class ForwardDeclaredMixin(object):
1085 class ForwardDeclaredMixin(object):
1053 """
1086 """
1054 Mixin for forward-declared versions of Instance and Type.
1087 Mixin for forward-declared versions of Instance and Type.
1055 """
1088 """
1056 def _resolve_string(self, string):
1089 def _resolve_string(self, string):
1057 """
1090 """
1058 Find the specified class name by looking for it in the module in which
1091 Find the specified class name by looking for it in the module in which
1059 our this_class attribute was defined.
1092 our this_class attribute was defined.
1060 """
1093 """
1061 modname = self.this_class.__module__
1094 modname = self.this_class.__module__
1062 return import_item('.'.join([modname, string]))
1095 return import_item('.'.join([modname, string]))
1063
1096
1064
1097
1065 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1098 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1066 """
1099 """
1067 Forward-declared version of Type.
1100 Forward-declared version of Type.
1068 """
1101 """
1069 pass
1102 pass
1070
1103
1071
1104
1072 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1105 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1073 """
1106 """
1074 Forward-declared version of Instance.
1107 Forward-declared version of Instance.
1075 """
1108 """
1076 pass
1109 pass
1077
1110
1078
1111
1079 class This(ClassBasedTraitType):
1112 class This(ClassBasedTraitType):
1080 """A trait for instances of the class containing this trait.
1113 """A trait for instances of the class containing this trait.
1081
1114
1082 Because how how and when class bodies are executed, the ``This``
1115 Because how how and when class bodies are executed, the ``This``
1083 trait can only have a default value of None. This, and because we
1116 trait can only have a default value of None. This, and because we
1084 always validate default values, ``allow_none`` is *always* true.
1117 always validate default values, ``allow_none`` is *always* true.
1085 """
1118 """
1086
1119
1087 info_text = 'an instance of the same type as the receiver or None'
1120 info_text = 'an instance of the same type as the receiver or None'
1088
1121
1089 def __init__(self, **metadata):
1122 def __init__(self, **metadata):
1090 super(This, self).__init__(None, **metadata)
1123 super(This, self).__init__(None, **metadata)
1091
1124
1092 def validate(self, obj, value):
1125 def validate(self, obj, value):
1093 # What if value is a superclass of obj.__class__? This is
1126 # What if value is a superclass of obj.__class__? This is
1094 # complicated if it was the superclass that defined the This
1127 # complicated if it was the superclass that defined the This
1095 # trait.
1128 # trait.
1096 if isinstance(value, self.this_class) or (value is None):
1129 if isinstance(value, self.this_class) or (value is None):
1097 return value
1130 return value
1098 else:
1131 else:
1099 self.error(obj, value)
1132 self.error(obj, value)
1100
1133
1101
1134
1102 class Union(TraitType):
1135 class Union(TraitType):
1103 """A trait type representing a Union type."""
1136 """A trait type representing a Union type."""
1104
1137
1105 def __init__(self, trait_types, **metadata):
1138 def __init__(self, trait_types, **metadata):
1106 """Construct a Union trait.
1139 """Construct a Union trait.
1107
1140
1108 This trait allows values that are allowed by at least one of the
1141 This trait allows values that are allowed by at least one of the
1109 specified trait types. A Union traitlet cannot have metadata on
1142 specified trait types. A Union traitlet cannot have metadata on
1110 its own, besides the metadata of the listed types.
1143 its own, besides the metadata of the listed types.
1111
1144
1112 Parameters
1145 Parameters
1113 ----------
1146 ----------
1114 trait_types: sequence
1147 trait_types: sequence
1115 The list of trait types of length at least 1.
1148 The list of trait types of length at least 1.
1116
1149
1117 Notes
1150 Notes
1118 -----
1151 -----
1119 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1152 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1120 with the validation function of Float, then Bool, and finally Int.
1153 with the validation function of Float, then Bool, and finally Int.
1121 """
1154 """
1122 self.trait_types = trait_types
1155 self.trait_types = trait_types
1123 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1156 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1124 self.default_value = self.trait_types[0].get_default_value()
1157 self.default_value = self.trait_types[0].get_default_value()
1125 super(Union, self).__init__(**metadata)
1158 super(Union, self).__init__(**metadata)
1126
1159
1127 def instance_init(self):
1160 def instance_init(self):
1128 for trait_type in self.trait_types:
1161 for trait_type in self.trait_types:
1129 trait_type.name = self.name
1162 trait_type.name = self.name
1130 trait_type.this_class = self.this_class
1163 trait_type.this_class = self.this_class
1131 trait_type.instance_init()
1164 trait_type.instance_init()
1132 super(Union, self).instance_init()
1165 super(Union, self).instance_init()
1133
1166
1134 def validate(self, obj, value):
1167 def validate(self, obj, value):
1135 for trait_type in self.trait_types:
1168 for trait_type in self.trait_types:
1136 try:
1169 try:
1137 v = trait_type._validate(obj, value)
1170 v = trait_type._validate(obj, value)
1138 self._metadata = trait_type._metadata
1171 self._metadata = trait_type._metadata
1139 return v
1172 return v
1140 except TraitError:
1173 except TraitError:
1141 continue
1174 continue
1142 self.error(obj, value)
1175 self.error(obj, value)
1143
1176
1144 def __or__(self, other):
1177 def __or__(self, other):
1145 if isinstance(other, Union):
1178 if isinstance(other, Union):
1146 return Union(self.trait_types + other.trait_types)
1179 return Union(self.trait_types + other.trait_types)
1147 else:
1180 else:
1148 return Union(self.trait_types + [other])
1181 return Union(self.trait_types + [other])
1149
1182
1150 #-----------------------------------------------------------------------------
1183 #-----------------------------------------------------------------------------
1151 # Basic TraitTypes implementations/subclasses
1184 # Basic TraitTypes implementations/subclasses
1152 #-----------------------------------------------------------------------------
1185 #-----------------------------------------------------------------------------
1153
1186
1154
1187
1155 class Any(TraitType):
1188 class Any(TraitType):
1156 default_value = None
1189 default_value = None
1157 info_text = 'any value'
1190 info_text = 'any value'
1158
1191
1159
1192
1160 class Int(TraitType):
1193 class Int(TraitType):
1161 """An int trait."""
1194 """An int trait."""
1162
1195
1163 default_value = 0
1196 default_value = 0
1164 info_text = 'an int'
1197 info_text = 'an int'
1165
1198
1166 def validate(self, obj, value):
1199 def validate(self, obj, value):
1167 if isinstance(value, int):
1200 if isinstance(value, int):
1168 return value
1201 return value
1169 self.error(obj, value)
1202 self.error(obj, value)
1170
1203
1171 class CInt(Int):
1204 class CInt(Int):
1172 """A casting version of the int trait."""
1205 """A casting version of the int trait."""
1173
1206
1174 def validate(self, obj, value):
1207 def validate(self, obj, value):
1175 try:
1208 try:
1176 return int(value)
1209 return int(value)
1177 except:
1210 except:
1178 self.error(obj, value)
1211 self.error(obj, value)
1179
1212
1180 if py3compat.PY3:
1213 if py3compat.PY3:
1181 Long, CLong = Int, CInt
1214 Long, CLong = Int, CInt
1182 Integer = Int
1215 Integer = Int
1183 else:
1216 else:
1184 class Long(TraitType):
1217 class Long(TraitType):
1185 """A long integer trait."""
1218 """A long integer trait."""
1186
1219
1187 default_value = 0
1220 default_value = 0
1188 info_text = 'a long'
1221 info_text = 'a long'
1189
1222
1190 def validate(self, obj, value):
1223 def validate(self, obj, value):
1191 if isinstance(value, long):
1224 if isinstance(value, long):
1192 return value
1225 return value
1193 if isinstance(value, int):
1226 if isinstance(value, int):
1194 return long(value)
1227 return long(value)
1195 self.error(obj, value)
1228 self.error(obj, value)
1196
1229
1197
1230
1198 class CLong(Long):
1231 class CLong(Long):
1199 """A casting version of the long integer trait."""
1232 """A casting version of the long integer trait."""
1200
1233
1201 def validate(self, obj, value):
1234 def validate(self, obj, value):
1202 try:
1235 try:
1203 return long(value)
1236 return long(value)
1204 except:
1237 except:
1205 self.error(obj, value)
1238 self.error(obj, value)
1206
1239
1207 class Integer(TraitType):
1240 class Integer(TraitType):
1208 """An integer trait.
1241 """An integer trait.
1209
1242
1210 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1243 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1211
1244
1212 default_value = 0
1245 default_value = 0
1213 info_text = 'an integer'
1246 info_text = 'an integer'
1214
1247
1215 def validate(self, obj, value):
1248 def validate(self, obj, value):
1216 if isinstance(value, int):
1249 if isinstance(value, int):
1217 return value
1250 return value
1218 if isinstance(value, long):
1251 if isinstance(value, long):
1219 # downcast longs that fit in int:
1252 # downcast longs that fit in int:
1220 # note that int(n > sys.maxint) returns a long, so
1253 # note that int(n > sys.maxint) returns a long, so
1221 # we don't need a condition on this cast
1254 # we don't need a condition on this cast
1222 return int(value)
1255 return int(value)
1223 if sys.platform == "cli":
1256 if sys.platform == "cli":
1224 from System import Int64
1257 from System import Int64
1225 if isinstance(value, Int64):
1258 if isinstance(value, Int64):
1226 return int(value)
1259 return int(value)
1227 self.error(obj, value)
1260 self.error(obj, value)
1228
1261
1229
1262
1230 class Float(TraitType):
1263 class Float(TraitType):
1231 """A float trait."""
1264 """A float trait."""
1232
1265
1233 default_value = 0.0
1266 default_value = 0.0
1234 info_text = 'a float'
1267 info_text = 'a float'
1235
1268
1236 def validate(self, obj, value):
1269 def validate(self, obj, value):
1237 if isinstance(value, float):
1270 if isinstance(value, float):
1238 return value
1271 return value
1239 if isinstance(value, int):
1272 if isinstance(value, int):
1240 return float(value)
1273 return float(value)
1241 self.error(obj, value)
1274 self.error(obj, value)
1242
1275
1243
1276
1244 class CFloat(Float):
1277 class CFloat(Float):
1245 """A casting version of the float trait."""
1278 """A casting version of the float trait."""
1246
1279
1247 def validate(self, obj, value):
1280 def validate(self, obj, value):
1248 try:
1281 try:
1249 return float(value)
1282 return float(value)
1250 except:
1283 except:
1251 self.error(obj, value)
1284 self.error(obj, value)
1252
1285
1253 class Complex(TraitType):
1286 class Complex(TraitType):
1254 """A trait for complex numbers."""
1287 """A trait for complex numbers."""
1255
1288
1256 default_value = 0.0 + 0.0j
1289 default_value = 0.0 + 0.0j
1257 info_text = 'a complex number'
1290 info_text = 'a complex number'
1258
1291
1259 def validate(self, obj, value):
1292 def validate(self, obj, value):
1260 if isinstance(value, complex):
1293 if isinstance(value, complex):
1261 return value
1294 return value
1262 if isinstance(value, (float, int)):
1295 if isinstance(value, (float, int)):
1263 return complex(value)
1296 return complex(value)
1264 self.error(obj, value)
1297 self.error(obj, value)
1265
1298
1266
1299
1267 class CComplex(Complex):
1300 class CComplex(Complex):
1268 """A casting version of the complex number trait."""
1301 """A casting version of the complex number trait."""
1269
1302
1270 def validate (self, obj, value):
1303 def validate (self, obj, value):
1271 try:
1304 try:
1272 return complex(value)
1305 return complex(value)
1273 except:
1306 except:
1274 self.error(obj, value)
1307 self.error(obj, value)
1275
1308
1276 # We should always be explicit about whether we're using bytes or unicode, both
1309 # We should always be explicit about whether we're using bytes or unicode, both
1277 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1310 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1278 # we don't have a Str type.
1311 # we don't have a Str type.
1279 class Bytes(TraitType):
1312 class Bytes(TraitType):
1280 """A trait for byte strings."""
1313 """A trait for byte strings."""
1281
1314
1282 default_value = b''
1315 default_value = b''
1283 info_text = 'a bytes object'
1316 info_text = 'a bytes object'
1284
1317
1285 def validate(self, obj, value):
1318 def validate(self, obj, value):
1286 if isinstance(value, bytes):
1319 if isinstance(value, bytes):
1287 return value
1320 return value
1288 self.error(obj, value)
1321 self.error(obj, value)
1289
1322
1290
1323
1291 class CBytes(Bytes):
1324 class CBytes(Bytes):
1292 """A casting version of the byte string trait."""
1325 """A casting version of the byte string trait."""
1293
1326
1294 def validate(self, obj, value):
1327 def validate(self, obj, value):
1295 try:
1328 try:
1296 return bytes(value)
1329 return bytes(value)
1297 except:
1330 except:
1298 self.error(obj, value)
1331 self.error(obj, value)
1299
1332
1300
1333
1301 class Unicode(TraitType):
1334 class Unicode(TraitType):
1302 """A trait for unicode strings."""
1335 """A trait for unicode strings."""
1303
1336
1304 default_value = u''
1337 default_value = u''
1305 info_text = 'a unicode string'
1338 info_text = 'a unicode string'
1306
1339
1307 def validate(self, obj, value):
1340 def validate(self, obj, value):
1308 if isinstance(value, py3compat.unicode_type):
1341 if isinstance(value, py3compat.unicode_type):
1309 return value
1342 return value
1310 if isinstance(value, bytes):
1343 if isinstance(value, bytes):
1311 try:
1344 try:
1312 return value.decode('ascii', 'strict')
1345 return value.decode('ascii', 'strict')
1313 except UnicodeDecodeError:
1346 except UnicodeDecodeError:
1314 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1347 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1315 raise TraitError(msg.format(value, self.name, class_of(obj)))
1348 raise TraitError(msg.format(value, self.name, class_of(obj)))
1316 self.error(obj, value)
1349 self.error(obj, value)
1317
1350
1318
1351
1319 class CUnicode(Unicode):
1352 class CUnicode(Unicode):
1320 """A casting version of the unicode trait."""
1353 """A casting version of the unicode trait."""
1321
1354
1322 def validate(self, obj, value):
1355 def validate(self, obj, value):
1323 try:
1356 try:
1324 return py3compat.unicode_type(value)
1357 return py3compat.unicode_type(value)
1325 except:
1358 except:
1326 self.error(obj, value)
1359 self.error(obj, value)
1327
1360
1328
1361
1329 class ObjectName(TraitType):
1362 class ObjectName(TraitType):
1330 """A string holding a valid object name in this version of Python.
1363 """A string holding a valid object name in this version of Python.
1331
1364
1332 This does not check that the name exists in any scope."""
1365 This does not check that the name exists in any scope."""
1333 info_text = "a valid object identifier in Python"
1366 info_text = "a valid object identifier in Python"
1334
1367
1335 if py3compat.PY3:
1368 if py3compat.PY3:
1336 # Python 3:
1369 # Python 3:
1337 coerce_str = staticmethod(lambda _,s: s)
1370 coerce_str = staticmethod(lambda _,s: s)
1338
1371
1339 else:
1372 else:
1340 # Python 2:
1373 # Python 2:
1341 def coerce_str(self, obj, value):
1374 def coerce_str(self, obj, value):
1342 "In Python 2, coerce ascii-only unicode to str"
1375 "In Python 2, coerce ascii-only unicode to str"
1343 if isinstance(value, unicode):
1376 if isinstance(value, unicode):
1344 try:
1377 try:
1345 return str(value)
1378 return str(value)
1346 except UnicodeEncodeError:
1379 except UnicodeEncodeError:
1347 self.error(obj, value)
1380 self.error(obj, value)
1348 return value
1381 return value
1349
1382
1350 def validate(self, obj, value):
1383 def validate(self, obj, value):
1351 value = self.coerce_str(obj, value)
1384 value = self.coerce_str(obj, value)
1352
1385
1353 if isinstance(value, string_types) and py3compat.isidentifier(value):
1386 if isinstance(value, string_types) and py3compat.isidentifier(value):
1354 return value
1387 return value
1355 self.error(obj, value)
1388 self.error(obj, value)
1356
1389
1357 class DottedObjectName(ObjectName):
1390 class DottedObjectName(ObjectName):
1358 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1391 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1359 def validate(self, obj, value):
1392 def validate(self, obj, value):
1360 value = self.coerce_str(obj, value)
1393 value = self.coerce_str(obj, value)
1361
1394
1362 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1395 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1363 return value
1396 return value
1364 self.error(obj, value)
1397 self.error(obj, value)
1365
1398
1366
1399
1367 class Bool(TraitType):
1400 class Bool(TraitType):
1368 """A boolean (True, False) trait."""
1401 """A boolean (True, False) trait."""
1369
1402
1370 default_value = False
1403 default_value = False
1371 info_text = 'a boolean'
1404 info_text = 'a boolean'
1372
1405
1373 def validate(self, obj, value):
1406 def validate(self, obj, value):
1374 if isinstance(value, bool):
1407 if isinstance(value, bool):
1375 return value
1408 return value
1376 self.error(obj, value)
1409 self.error(obj, value)
1377
1410
1378
1411
1379 class CBool(Bool):
1412 class CBool(Bool):
1380 """A casting version of the boolean trait."""
1413 """A casting version of the boolean trait."""
1381
1414
1382 def validate(self, obj, value):
1415 def validate(self, obj, value):
1383 try:
1416 try:
1384 return bool(value)
1417 return bool(value)
1385 except:
1418 except:
1386 self.error(obj, value)
1419 self.error(obj, value)
1387
1420
1388
1421
1389 class Enum(TraitType):
1422 class Enum(TraitType):
1390 """An enum that whose value must be in a given sequence."""
1423 """An enum that whose value must be in a given sequence."""
1391
1424
1392 def __init__(self, values, default_value=None, **metadata):
1425 def __init__(self, values, default_value=None, **metadata):
1393 self.values = values
1426 self.values = values
1394 super(Enum, self).__init__(default_value, **metadata)
1427 super(Enum, self).__init__(default_value, **metadata)
1395
1428
1396 def validate(self, obj, value):
1429 def validate(self, obj, value):
1397 if value in self.values:
1430 if value in self.values:
1398 return value
1431 return value
1399 self.error(obj, value)
1432 self.error(obj, value)
1400
1433
1401 def info(self):
1434 def info(self):
1402 """ Returns a description of the trait."""
1435 """ Returns a description of the trait."""
1403 result = 'any of ' + repr(self.values)
1436 result = 'any of ' + repr(self.values)
1404 if self.allow_none:
1437 if self.allow_none:
1405 return result + ' or None'
1438 return result + ' or None'
1406 return result
1439 return result
1407
1440
1408 class CaselessStrEnum(Enum):
1441 class CaselessStrEnum(Enum):
1409 """An enum of strings that are caseless in validate."""
1442 """An enum of strings that are caseless in validate."""
1410
1443
1411 def validate(self, obj, value):
1444 def validate(self, obj, value):
1412 if not isinstance(value, py3compat.string_types):
1445 if not isinstance(value, py3compat.string_types):
1413 self.error(obj, value)
1446 self.error(obj, value)
1414
1447
1415 for v in self.values:
1448 for v in self.values:
1416 if v.lower() == value.lower():
1449 if v.lower() == value.lower():
1417 return v
1450 return v
1418 self.error(obj, value)
1451 self.error(obj, value)
1419
1452
1420 class Container(Instance):
1453 class Container(Instance):
1421 """An instance of a container (list, set, etc.)
1454 """An instance of a container (list, set, etc.)
1422
1455
1423 To be subclassed by overriding klass.
1456 To be subclassed by overriding klass.
1424 """
1457 """
1425 klass = None
1458 klass = None
1426 _cast_types = ()
1459 _cast_types = ()
1427 _valid_defaults = SequenceTypes
1460 _valid_defaults = SequenceTypes
1428 _trait = None
1461 _trait = None
1429
1462
1430 def __init__(self, trait=None, default_value=None, allow_none=False,
1463 def __init__(self, trait=None, default_value=None, allow_none=False,
1431 **metadata):
1464 **metadata):
1432 """Create a container trait type from a list, set, or tuple.
1465 """Create a container trait type from a list, set, or tuple.
1433
1466
1434 The default value is created by doing ``List(default_value)``,
1467 The default value is created by doing ``List(default_value)``,
1435 which creates a copy of the ``default_value``.
1468 which creates a copy of the ``default_value``.
1436
1469
1437 ``trait`` can be specified, which restricts the type of elements
1470 ``trait`` can be specified, which restricts the type of elements
1438 in the container to that TraitType.
1471 in the container to that TraitType.
1439
1472
1440 If only one arg is given and it is not a Trait, it is taken as
1473 If only one arg is given and it is not a Trait, it is taken as
1441 ``default_value``:
1474 ``default_value``:
1442
1475
1443 ``c = List([1,2,3])``
1476 ``c = List([1,2,3])``
1444
1477
1445 Parameters
1478 Parameters
1446 ----------
1479 ----------
1447
1480
1448 trait : TraitType [ optional ]
1481 trait : TraitType [ optional ]
1449 the type for restricting the contents of the Container. If unspecified,
1482 the type for restricting the contents of the Container. If unspecified,
1450 types are not checked.
1483 types are not checked.
1451
1484
1452 default_value : SequenceType [ optional ]
1485 default_value : SequenceType [ optional ]
1453 The default value for the Trait. Must be list/tuple/set, and
1486 The default value for the Trait. Must be list/tuple/set, and
1454 will be cast to the container type.
1487 will be cast to the container type.
1455
1488
1456 allow_none : bool [ default False ]
1489 allow_none : bool [ default False ]
1457 Whether to allow the value to be None
1490 Whether to allow the value to be None
1458
1491
1459 **metadata : any
1492 **metadata : any
1460 further keys for extensions to the Trait (e.g. config)
1493 further keys for extensions to the Trait (e.g. config)
1461
1494
1462 """
1495 """
1463 # allow List([values]):
1496 # allow List([values]):
1464 if default_value is None and not is_trait(trait):
1497 if default_value is None and not is_trait(trait):
1465 default_value = trait
1498 default_value = trait
1466 trait = None
1499 trait = None
1467
1500
1468 if default_value is None:
1501 if default_value is None:
1469 args = ()
1502 args = ()
1470 elif isinstance(default_value, self._valid_defaults):
1503 elif isinstance(default_value, self._valid_defaults):
1471 args = (default_value,)
1504 args = (default_value,)
1472 else:
1505 else:
1473 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1506 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1474
1507
1475 if is_trait(trait):
1508 if is_trait(trait):
1476 self._trait = trait() if isinstance(trait, type) else trait
1509 self._trait = trait() if isinstance(trait, type) else trait
1477 self._trait.name = 'element'
1510 self._trait.name = 'element'
1478 elif trait is not None:
1511 elif trait is not None:
1479 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1512 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1480
1513
1481 super(Container,self).__init__(klass=self.klass, args=args,
1514 super(Container,self).__init__(klass=self.klass, args=args,
1482 allow_none=allow_none, **metadata)
1515 allow_none=allow_none, **metadata)
1483
1516
1484 def element_error(self, obj, element, validator):
1517 def element_error(self, obj, element, validator):
1485 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1518 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1486 % (self.name, class_of(obj), validator.info(), repr_type(element))
1519 % (self.name, class_of(obj), validator.info(), repr_type(element))
1487 raise TraitError(e)
1520 raise TraitError(e)
1488
1521
1489 def validate(self, obj, value):
1522 def validate(self, obj, value):
1490 if isinstance(value, self._cast_types):
1523 if isinstance(value, self._cast_types):
1491 value = self.klass(value)
1524 value = self.klass(value)
1492 value = super(Container, self).validate(obj, value)
1525 value = super(Container, self).validate(obj, value)
1493 if value is None:
1526 if value is None:
1494 return value
1527 return value
1495
1528
1496 value = self.validate_elements(obj, value)
1529 value = self.validate_elements(obj, value)
1497
1530
1498 return value
1531 return value
1499
1532
1500 def validate_elements(self, obj, value):
1533 def validate_elements(self, obj, value):
1501 validated = []
1534 validated = []
1502 if self._trait is None or isinstance(self._trait, Any):
1535 if self._trait is None or isinstance(self._trait, Any):
1503 return value
1536 return value
1504 for v in value:
1537 for v in value:
1505 try:
1538 try:
1506 v = self._trait._validate(obj, v)
1539 v = self._trait._validate(obj, v)
1507 except TraitError:
1540 except TraitError:
1508 self.element_error(obj, v, self._trait)
1541 self.element_error(obj, v, self._trait)
1509 else:
1542 else:
1510 validated.append(v)
1543 validated.append(v)
1511 return self.klass(validated)
1544 return self.klass(validated)
1512
1545
1513 def instance_init(self):
1546 def instance_init(self):
1514 if isinstance(self._trait, TraitType):
1547 if isinstance(self._trait, TraitType):
1515 self._trait.this_class = self.this_class
1548 self._trait.this_class = self.this_class
1516 self._trait.instance_init()
1549 self._trait.instance_init()
1517 super(Container, self).instance_init()
1550 super(Container, self).instance_init()
1518
1551
1519
1552
1520 class List(Container):
1553 class List(Container):
1521 """An instance of a Python list."""
1554 """An instance of a Python list."""
1522 klass = list
1555 klass = list
1523 _cast_types = (tuple,)
1556 _cast_types = (tuple,)
1524
1557
1525 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1558 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1526 """Create a List trait type from a list, set, or tuple.
1559 """Create a List trait type from a list, set, or tuple.
1527
1560
1528 The default value is created by doing ``List(default_value)``,
1561 The default value is created by doing ``List(default_value)``,
1529 which creates a copy of the ``default_value``.
1562 which creates a copy of the ``default_value``.
1530
1563
1531 ``trait`` can be specified, which restricts the type of elements
1564 ``trait`` can be specified, which restricts the type of elements
1532 in the container to that TraitType.
1565 in the container to that TraitType.
1533
1566
1534 If only one arg is given and it is not a Trait, it is taken as
1567 If only one arg is given and it is not a Trait, it is taken as
1535 ``default_value``:
1568 ``default_value``:
1536
1569
1537 ``c = List([1,2,3])``
1570 ``c = List([1,2,3])``
1538
1571
1539 Parameters
1572 Parameters
1540 ----------
1573 ----------
1541
1574
1542 trait : TraitType [ optional ]
1575 trait : TraitType [ optional ]
1543 the type for restricting the contents of the Container. If unspecified,
1576 the type for restricting the contents of the Container. If unspecified,
1544 types are not checked.
1577 types are not checked.
1545
1578
1546 default_value : SequenceType [ optional ]
1579 default_value : SequenceType [ optional ]
1547 The default value for the Trait. Must be list/tuple/set, and
1580 The default value for the Trait. Must be list/tuple/set, and
1548 will be cast to the container type.
1581 will be cast to the container type.
1549
1582
1550 minlen : Int [ default 0 ]
1583 minlen : Int [ default 0 ]
1551 The minimum length of the input list
1584 The minimum length of the input list
1552
1585
1553 maxlen : Int [ default sys.maxsize ]
1586 maxlen : Int [ default sys.maxsize ]
1554 The maximum length of the input list
1587 The maximum length of the input list
1555
1588
1556 allow_none : bool [ default False ]
1589 allow_none : bool [ default False ]
1557 Whether to allow the value to be None
1590 Whether to allow the value to be None
1558
1591
1559 **metadata : any
1592 **metadata : any
1560 further keys for extensions to the Trait (e.g. config)
1593 further keys for extensions to the Trait (e.g. config)
1561
1594
1562 """
1595 """
1563 self._minlen = minlen
1596 self._minlen = minlen
1564 self._maxlen = maxlen
1597 self._maxlen = maxlen
1565 super(List, self).__init__(trait=trait, default_value=default_value,
1598 super(List, self).__init__(trait=trait, default_value=default_value,
1566 **metadata)
1599 **metadata)
1567
1600
1568 def length_error(self, obj, value):
1601 def length_error(self, obj, value):
1569 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1602 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1570 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1603 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1571 raise TraitError(e)
1604 raise TraitError(e)
1572
1605
1573 def validate_elements(self, obj, value):
1606 def validate_elements(self, obj, value):
1574 length = len(value)
1607 length = len(value)
1575 if length < self._minlen or length > self._maxlen:
1608 if length < self._minlen or length > self._maxlen:
1576 self.length_error(obj, value)
1609 self.length_error(obj, value)
1577
1610
1578 return super(List, self).validate_elements(obj, value)
1611 return super(List, self).validate_elements(obj, value)
1579
1612
1580 def validate(self, obj, value):
1613 def validate(self, obj, value):
1581 value = super(List, self).validate(obj, value)
1614 value = super(List, self).validate(obj, value)
1582 value = self.validate_elements(obj, value)
1615 value = self.validate_elements(obj, value)
1583 return value
1616 return value
1584
1617
1585
1618
1586 class Set(List):
1619 class Set(List):
1587 """An instance of a Python set."""
1620 """An instance of a Python set."""
1588 klass = set
1621 klass = set
1589 _cast_types = (tuple, list)
1622 _cast_types = (tuple, list)
1590
1623
1591
1624
1592 class Tuple(Container):
1625 class Tuple(Container):
1593 """An instance of a Python tuple."""
1626 """An instance of a Python tuple."""
1594 klass = tuple
1627 klass = tuple
1595 _cast_types = (list,)
1628 _cast_types = (list,)
1596
1629
1597 def __init__(self, *traits, **metadata):
1630 def __init__(self, *traits, **metadata):
1598 """Tuple(*traits, default_value=None, **medatata)
1631 """Tuple(*traits, default_value=None, **medatata)
1599
1632
1600 Create a tuple from a list, set, or tuple.
1633 Create a tuple from a list, set, or tuple.
1601
1634
1602 Create a fixed-type tuple with Traits:
1635 Create a fixed-type tuple with Traits:
1603
1636
1604 ``t = Tuple(Int, Str, CStr)``
1637 ``t = Tuple(Int, Str, CStr)``
1605
1638
1606 would be length 3, with Int,Str,CStr for each element.
1639 would be length 3, with Int,Str,CStr for each element.
1607
1640
1608 If only one arg is given and it is not a Trait, it is taken as
1641 If only one arg is given and it is not a Trait, it is taken as
1609 default_value:
1642 default_value:
1610
1643
1611 ``t = Tuple((1,2,3))``
1644 ``t = Tuple((1,2,3))``
1612
1645
1613 Otherwise, ``default_value`` *must* be specified by keyword.
1646 Otherwise, ``default_value`` *must* be specified by keyword.
1614
1647
1615 Parameters
1648 Parameters
1616 ----------
1649 ----------
1617
1650
1618 *traits : TraitTypes [ optional ]
1651 *traits : TraitTypes [ optional ]
1619 the types for restricting the contents of the Tuple. If unspecified,
1652 the types for restricting the contents of the Tuple. If unspecified,
1620 types are not checked. If specified, then each positional argument
1653 types are not checked. If specified, then each positional argument
1621 corresponds to an element of the tuple. Tuples defined with traits
1654 corresponds to an element of the tuple. Tuples defined with traits
1622 are of fixed length.
1655 are of fixed length.
1623
1656
1624 default_value : SequenceType [ optional ]
1657 default_value : SequenceType [ optional ]
1625 The default value for the Tuple. Must be list/tuple/set, and
1658 The default value for the Tuple. Must be list/tuple/set, and
1626 will be cast to a tuple. If `traits` are specified, the
1659 will be cast to a tuple. If `traits` are specified, the
1627 `default_value` must conform to the shape and type they specify.
1660 `default_value` must conform to the shape and type they specify.
1628
1661
1629 allow_none : bool [ default False ]
1662 allow_none : bool [ default False ]
1630 Whether to allow the value to be None
1663 Whether to allow the value to be None
1631
1664
1632 **metadata : any
1665 **metadata : any
1633 further keys for extensions to the Trait (e.g. config)
1666 further keys for extensions to the Trait (e.g. config)
1634
1667
1635 """
1668 """
1636 default_value = metadata.pop('default_value', None)
1669 default_value = metadata.pop('default_value', None)
1637 allow_none = metadata.pop('allow_none', True)
1670 allow_none = metadata.pop('allow_none', True)
1638
1671
1639 # allow Tuple((values,)):
1672 # allow Tuple((values,)):
1640 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1673 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1641 default_value = traits[0]
1674 default_value = traits[0]
1642 traits = ()
1675 traits = ()
1643
1676
1644 if default_value is None:
1677 if default_value is None:
1645 args = ()
1678 args = ()
1646 elif isinstance(default_value, self._valid_defaults):
1679 elif isinstance(default_value, self._valid_defaults):
1647 args = (default_value,)
1680 args = (default_value,)
1648 else:
1681 else:
1649 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1682 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1650
1683
1651 self._traits = []
1684 self._traits = []
1652 for trait in traits:
1685 for trait in traits:
1653 t = trait() if isinstance(trait, type) else trait
1686 t = trait() if isinstance(trait, type) else trait
1654 t.name = 'element'
1687 t.name = 'element'
1655 self._traits.append(t)
1688 self._traits.append(t)
1656
1689
1657 if self._traits and default_value is None:
1690 if self._traits and default_value is None:
1658 # don't allow default to be an empty container if length is specified
1691 # don't allow default to be an empty container if length is specified
1659 args = None
1692 args = None
1660 super(Container,self).__init__(klass=self.klass, args=args, allow_none=allow_none, **metadata)
1693 super(Container,self).__init__(klass=self.klass, args=args, allow_none=allow_none, **metadata)
1661
1694
1662 def validate_elements(self, obj, value):
1695 def validate_elements(self, obj, value):
1663 if not self._traits:
1696 if not self._traits:
1664 # nothing to validate
1697 # nothing to validate
1665 return value
1698 return value
1666 if len(value) != len(self._traits):
1699 if len(value) != len(self._traits):
1667 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1700 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1668 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1701 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1669 raise TraitError(e)
1702 raise TraitError(e)
1670
1703
1671 validated = []
1704 validated = []
1672 for t, v in zip(self._traits, value):
1705 for t, v in zip(self._traits, value):
1673 try:
1706 try:
1674 v = t._validate(obj, v)
1707 v = t._validate(obj, v)
1675 except TraitError:
1708 except TraitError:
1676 self.element_error(obj, v, t)
1709 self.element_error(obj, v, t)
1677 else:
1710 else:
1678 validated.append(v)
1711 validated.append(v)
1679 return tuple(validated)
1712 return tuple(validated)
1680
1713
1681 def instance_init(self):
1714 def instance_init(self):
1682 for trait in self._traits:
1715 for trait in self._traits:
1683 if isinstance(trait, TraitType):
1716 if isinstance(trait, TraitType):
1684 trait.this_class = self.this_class
1717 trait.this_class = self.this_class
1685 trait.instance_init()
1718 trait.instance_init()
1686 super(Container, self).instance_init()
1719 super(Container, self).instance_init()
1687
1720
1688
1721
1689 class Dict(Instance):
1722 class Dict(Instance):
1690 """An instance of a Python dict."""
1723 """An instance of a Python dict."""
1691 _trait = None
1724 _trait = None
1692
1725
1693 def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
1726 def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
1694 """Create a dict trait type from a dict.
1727 """Create a dict trait type from a dict.
1695
1728
1696 The default value is created by doing ``dict(default_value)``,
1729 The default value is created by doing ``dict(default_value)``,
1697 which creates a copy of the ``default_value``.
1730 which creates a copy of the ``default_value``.
1698
1731
1699 trait : TraitType [ optional ]
1732 trait : TraitType [ optional ]
1700 the type for restricting the contents of the Container. If unspecified,
1733 the type for restricting the contents of the Container. If unspecified,
1701 types are not checked.
1734 types are not checked.
1702
1735
1703 default_value : SequenceType [ optional ]
1736 default_value : SequenceType [ optional ]
1704 The default value for the Dict. Must be dict, tuple, or None, and
1737 The default value for the Dict. Must be dict, tuple, or None, and
1705 will be cast to a dict if not None. If `trait` is specified, the
1738 will be cast to a dict if not None. If `trait` is specified, the
1706 `default_value` must conform to the constraints it specifies.
1739 `default_value` must conform to the constraints it specifies.
1707
1740
1708 allow_none : bool [ default False ]
1741 allow_none : bool [ default False ]
1709 Whether to allow the value to be None
1742 Whether to allow the value to be None
1710
1743
1711 """
1744 """
1712 if default_value is NoDefaultSpecified and trait is not None:
1745 if default_value is NoDefaultSpecified and trait is not None:
1713 if not is_trait(trait):
1746 if not is_trait(trait):
1714 default_value = trait
1747 default_value = trait
1715 trait = None
1748 trait = None
1716 if default_value is NoDefaultSpecified:
1749 if default_value is NoDefaultSpecified:
1717 default_value = {}
1750 default_value = {}
1718 if default_value is None:
1751 if default_value is None:
1719 args = None
1752 args = None
1720 elif isinstance(default_value, dict):
1753 elif isinstance(default_value, dict):
1721 args = (default_value,)
1754 args = (default_value,)
1722 elif isinstance(default_value, SequenceTypes):
1755 elif isinstance(default_value, SequenceTypes):
1723 args = (default_value,)
1756 args = (default_value,)
1724 else:
1757 else:
1725 raise TypeError('default value of Dict was %s' % default_value)
1758 raise TypeError('default value of Dict was %s' % default_value)
1726
1759
1727 if is_trait(trait):
1760 if is_trait(trait):
1728 self._trait = trait() if isinstance(trait, type) else trait
1761 self._trait = trait() if isinstance(trait, type) else trait
1729 self._trait.name = 'element'
1762 self._trait.name = 'element'
1730 elif trait is not None:
1763 elif trait is not None:
1731 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1764 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1732
1765
1733 super(Dict,self).__init__(klass=dict, args=args,
1766 super(Dict,self).__init__(klass=dict, args=args,
1734 allow_none=allow_none, **metadata)
1767 allow_none=allow_none, **metadata)
1735
1768
1736 def element_error(self, obj, element, validator):
1769 def element_error(self, obj, element, validator):
1737 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1770 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1738 % (self.name, class_of(obj), validator.info(), repr_type(element))
1771 % (self.name, class_of(obj), validator.info(), repr_type(element))
1739 raise TraitError(e)
1772 raise TraitError(e)
1740
1773
1741 def validate(self, obj, value):
1774 def validate(self, obj, value):
1742 value = super(Dict, self).validate(obj, value)
1775 value = super(Dict, self).validate(obj, value)
1743 if value is None:
1776 if value is None:
1744 return value
1777 return value
1745 value = self.validate_elements(obj, value)
1778 value = self.validate_elements(obj, value)
1746 return value
1779 return value
1747
1780
1748 def validate_elements(self, obj, value):
1781 def validate_elements(self, obj, value):
1749 if self._trait is None or isinstance(self._trait, Any):
1782 if self._trait is None or isinstance(self._trait, Any):
1750 return value
1783 return value
1751 validated = {}
1784 validated = {}
1752 for key in value:
1785 for key in value:
1753 v = value[key]
1786 v = value[key]
1754 try:
1787 try:
1755 v = self._trait._validate(obj, v)
1788 v = self._trait._validate(obj, v)
1756 except TraitError:
1789 except TraitError:
1757 self.element_error(obj, v, self._trait)
1790 self.element_error(obj, v, self._trait)
1758 else:
1791 else:
1759 validated[key] = v
1792 validated[key] = v
1760 return self.klass(validated)
1793 return self.klass(validated)
1761
1794
1762 def instance_init(self):
1795 def instance_init(self):
1763 if isinstance(self._trait, TraitType):
1796 if isinstance(self._trait, TraitType):
1764 self._trait.this_class = self.this_class
1797 self._trait.this_class = self.this_class
1765 self._trait.instance_init()
1798 self._trait.instance_init()
1766 super(Dict, self).instance_init()
1799 super(Dict, self).instance_init()
1767
1800
1768
1801
1769 class EventfulDict(Instance):
1802 class EventfulDict(Instance):
1770 """An instance of an EventfulDict."""
1803 """An instance of an EventfulDict."""
1771
1804
1772 def __init__(self, default_value={}, allow_none=False, **metadata):
1805 def __init__(self, default_value={}, allow_none=False, **metadata):
1773 """Create a EventfulDict trait type from a dict.
1806 """Create a EventfulDict trait type from a dict.
1774
1807
1775 The default value is created by doing
1808 The default value is created by doing
1776 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1809 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1777 ``default_value``.
1810 ``default_value``.
1778 """
1811 """
1779 if default_value is None:
1812 if default_value is None:
1780 args = None
1813 args = None
1781 elif isinstance(default_value, dict):
1814 elif isinstance(default_value, dict):
1782 args = (default_value,)
1815 args = (default_value,)
1783 elif isinstance(default_value, SequenceTypes):
1816 elif isinstance(default_value, SequenceTypes):
1784 args = (default_value,)
1817 args = (default_value,)
1785 else:
1818 else:
1786 raise TypeError('default value of EventfulDict was %s' % default_value)
1819 raise TypeError('default value of EventfulDict was %s' % default_value)
1787
1820
1788 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1821 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1789 allow_none=allow_none, **metadata)
1822 allow_none=allow_none, **metadata)
1790
1823
1791
1824
1792 class EventfulList(Instance):
1825 class EventfulList(Instance):
1793 """An instance of an EventfulList."""
1826 """An instance of an EventfulList."""
1794
1827
1795 def __init__(self, default_value=None, allow_none=False, **metadata):
1828 def __init__(self, default_value=None, allow_none=False, **metadata):
1796 """Create a EventfulList trait type from a dict.
1829 """Create a EventfulList trait type from a dict.
1797
1830
1798 The default value is created by doing
1831 The default value is created by doing
1799 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1832 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1800 ``default_value``.
1833 ``default_value``.
1801 """
1834 """
1802 if default_value is None:
1835 if default_value is None:
1803 args = ((),)
1836 args = ((),)
1804 else:
1837 else:
1805 args = (default_value,)
1838 args = (default_value,)
1806
1839
1807 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1840 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1808 allow_none=allow_none, **metadata)
1841 allow_none=allow_none, **metadata)
1809
1842
1810
1843
1811 class TCPAddress(TraitType):
1844 class TCPAddress(TraitType):
1812 """A trait for an (ip, port) tuple.
1845 """A trait for an (ip, port) tuple.
1813
1846
1814 This allows for both IPv4 IP addresses as well as hostnames.
1847 This allows for both IPv4 IP addresses as well as hostnames.
1815 """
1848 """
1816
1849
1817 default_value = ('127.0.0.1', 0)
1850 default_value = ('127.0.0.1', 0)
1818 info_text = 'an (ip, port) tuple'
1851 info_text = 'an (ip, port) tuple'
1819
1852
1820 def validate(self, obj, value):
1853 def validate(self, obj, value):
1821 if isinstance(value, tuple):
1854 if isinstance(value, tuple):
1822 if len(value) == 2:
1855 if len(value) == 2:
1823 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1856 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1824 port = value[1]
1857 port = value[1]
1825 if port >= 0 and port <= 65535:
1858 if port >= 0 and port <= 65535:
1826 return value
1859 return value
1827 self.error(obj, value)
1860 self.error(obj, value)
1828
1861
1829 class CRegExp(TraitType):
1862 class CRegExp(TraitType):
1830 """A casting compiled regular expression trait.
1863 """A casting compiled regular expression trait.
1831
1864
1832 Accepts both strings and compiled regular expressions. The resulting
1865 Accepts both strings and compiled regular expressions. The resulting
1833 attribute will be a compiled regular expression."""
1866 attribute will be a compiled regular expression."""
1834
1867
1835 info_text = 'a regular expression'
1868 info_text = 'a regular expression'
1836
1869
1837 def validate(self, obj, value):
1870 def validate(self, obj, value):
1838 try:
1871 try:
1839 return re.compile(value)
1872 return re.compile(value)
1840 except:
1873 except:
1841 self.error(obj, value)
1874 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now