##// END OF EJS Templates
Merge pull request #7708 from SylvainCorlay/allow_none...
Min RK -
r20763:454aa2cb merge
parent child Browse files
Show More
@@ -1,432 +1,432 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A mixin for :class:`~IPython.core.application.Application` classes that
3 A mixin for :class:`~IPython.core.application.Application` classes that
4 launch InteractiveShell instances, load extensions, etc.
4 launch InteractiveShell instances, load extensions, etc.
5 """
5 """
6
6
7 # Copyright (c) IPython Development Team.
7 # Copyright (c) IPython Development Team.
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9
9
10 from __future__ import absolute_import
10 from __future__ import absolute_import
11 from __future__ import print_function
11 from __future__ import print_function
12
12
13 import glob
13 import glob
14 import os
14 import os
15 import sys
15 import sys
16
16
17 from IPython.config.application import boolean_flag
17 from IPython.config.application import boolean_flag
18 from IPython.config.configurable import Configurable
18 from IPython.config.configurable import Configurable
19 from IPython.config.loader import Config
19 from IPython.config.loader import Config
20 from IPython.core import pylabtools
20 from IPython.core import pylabtools
21 from IPython.utils import py3compat
21 from IPython.utils import py3compat
22 from IPython.utils.contexts import preserve_keys
22 from IPython.utils.contexts import preserve_keys
23 from IPython.utils.path import filefind
23 from IPython.utils.path import filefind
24 from IPython.utils.traitlets import (
24 from IPython.utils.traitlets import (
25 Unicode, Instance, List, Bool, CaselessStrEnum
25 Unicode, Instance, List, Bool, CaselessStrEnum
26 )
26 )
27 from IPython.lib.inputhook import guis
27 from IPython.lib.inputhook import guis
28
28
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 # Aliases and Flags
30 # Aliases and Flags
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 gui_keys = tuple(sorted([ key for key in guis if key is not None ]))
33 gui_keys = tuple(sorted([ key for key in guis if key is not None ]))
34
34
35 backend_keys = sorted(pylabtools.backends.keys())
35 backend_keys = sorted(pylabtools.backends.keys())
36 backend_keys.insert(0, 'auto')
36 backend_keys.insert(0, 'auto')
37
37
38 shell_flags = {}
38 shell_flags = {}
39
39
40 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
40 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
41 addflag('autoindent', 'InteractiveShell.autoindent',
41 addflag('autoindent', 'InteractiveShell.autoindent',
42 'Turn on autoindenting.', 'Turn off autoindenting.'
42 'Turn on autoindenting.', 'Turn off autoindenting.'
43 )
43 )
44 addflag('automagic', 'InteractiveShell.automagic',
44 addflag('automagic', 'InteractiveShell.automagic',
45 """Turn on the auto calling of magic commands. Type %%magic at the
45 """Turn on the auto calling of magic commands. Type %%magic at the
46 IPython prompt for more information.""",
46 IPython prompt for more information.""",
47 'Turn off the auto calling of magic commands.'
47 'Turn off the auto calling of magic commands.'
48 )
48 )
49 addflag('pdb', 'InteractiveShell.pdb',
49 addflag('pdb', 'InteractiveShell.pdb',
50 "Enable auto calling the pdb debugger after every exception.",
50 "Enable auto calling the pdb debugger after every exception.",
51 "Disable auto calling the pdb debugger after every exception."
51 "Disable auto calling the pdb debugger after every exception."
52 )
52 )
53 # pydb flag doesn't do any config, as core.debugger switches on import,
53 # pydb flag doesn't do any config, as core.debugger switches on import,
54 # which is before parsing. This just allows the flag to be passed.
54 # which is before parsing. This just allows the flag to be passed.
55 shell_flags.update(dict(
55 shell_flags.update(dict(
56 pydb = ({},
56 pydb = ({},
57 """Use the third party 'pydb' package as debugger, instead of pdb.
57 """Use the third party 'pydb' package as debugger, instead of pdb.
58 Requires that pydb is installed."""
58 Requires that pydb is installed."""
59 )
59 )
60 ))
60 ))
61 addflag('pprint', 'PlainTextFormatter.pprint',
61 addflag('pprint', 'PlainTextFormatter.pprint',
62 "Enable auto pretty printing of results.",
62 "Enable auto pretty printing of results.",
63 "Disable auto pretty printing of results."
63 "Disable auto pretty printing of results."
64 )
64 )
65 addflag('color-info', 'InteractiveShell.color_info',
65 addflag('color-info', 'InteractiveShell.color_info',
66 """IPython can display information about objects via a set of functions,
66 """IPython can display information about objects via a set of functions,
67 and optionally can use colors for this, syntax highlighting
67 and optionally can use colors for this, syntax highlighting
68 source code and various other elements. This is on by default, but can cause
68 source code and various other elements. This is on by default, but can cause
69 problems with some pagers. If you see such problems, you can disable the
69 problems with some pagers. If you see such problems, you can disable the
70 colours.""",
70 colours.""",
71 "Disable using colors for info related things."
71 "Disable using colors for info related things."
72 )
72 )
73 addflag('deep-reload', 'InteractiveShell.deep_reload',
73 addflag('deep-reload', 'InteractiveShell.deep_reload',
74 """Enable deep (recursive) reloading by default. IPython can use the
74 """Enable deep (recursive) reloading by default. IPython can use the
75 deep_reload module which reloads changes in modules recursively (it
75 deep_reload module which reloads changes in modules recursively (it
76 replaces the reload() function, so you don't need to change anything to
76 replaces the reload() function, so you don't need to change anything to
77 use it). deep_reload() forces a full reload of modules whose code may
77 use it). deep_reload() forces a full reload of modules whose code may
78 have changed, which the default reload() function does not. When
78 have changed, which the default reload() function does not. When
79 deep_reload is off, IPython will use the normal reload(), but
79 deep_reload is off, IPython will use the normal reload(), but
80 deep_reload will still be available as dreload(). This feature is off
80 deep_reload will still be available as dreload(). This feature is off
81 by default [which means that you have both normal reload() and
81 by default [which means that you have both normal reload() and
82 dreload()].""",
82 dreload()].""",
83 "Disable deep (recursive) reloading by default."
83 "Disable deep (recursive) reloading by default."
84 )
84 )
85 nosep_config = Config()
85 nosep_config = Config()
86 nosep_config.InteractiveShell.separate_in = ''
86 nosep_config.InteractiveShell.separate_in = ''
87 nosep_config.InteractiveShell.separate_out = ''
87 nosep_config.InteractiveShell.separate_out = ''
88 nosep_config.InteractiveShell.separate_out2 = ''
88 nosep_config.InteractiveShell.separate_out2 = ''
89
89
90 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
90 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
91 shell_flags['pylab'] = (
91 shell_flags['pylab'] = (
92 {'InteractiveShellApp' : {'pylab' : 'auto'}},
92 {'InteractiveShellApp' : {'pylab' : 'auto'}},
93 """Pre-load matplotlib and numpy for interactive use with
93 """Pre-load matplotlib and numpy for interactive use with
94 the default matplotlib backend."""
94 the default matplotlib backend."""
95 )
95 )
96 shell_flags['matplotlib'] = (
96 shell_flags['matplotlib'] = (
97 {'InteractiveShellApp' : {'matplotlib' : 'auto'}},
97 {'InteractiveShellApp' : {'matplotlib' : 'auto'}},
98 """Configure matplotlib for interactive use with
98 """Configure matplotlib for interactive use with
99 the default matplotlib backend."""
99 the default matplotlib backend."""
100 )
100 )
101
101
102 # it's possible we don't want short aliases for *all* of these:
102 # it's possible we don't want short aliases for *all* of these:
103 shell_aliases = dict(
103 shell_aliases = dict(
104 autocall='InteractiveShell.autocall',
104 autocall='InteractiveShell.autocall',
105 colors='InteractiveShell.colors',
105 colors='InteractiveShell.colors',
106 logfile='InteractiveShell.logfile',
106 logfile='InteractiveShell.logfile',
107 logappend='InteractiveShell.logappend',
107 logappend='InteractiveShell.logappend',
108 c='InteractiveShellApp.code_to_run',
108 c='InteractiveShellApp.code_to_run',
109 m='InteractiveShellApp.module_to_run',
109 m='InteractiveShellApp.module_to_run',
110 ext='InteractiveShellApp.extra_extension',
110 ext='InteractiveShellApp.extra_extension',
111 gui='InteractiveShellApp.gui',
111 gui='InteractiveShellApp.gui',
112 pylab='InteractiveShellApp.pylab',
112 pylab='InteractiveShellApp.pylab',
113 matplotlib='InteractiveShellApp.matplotlib',
113 matplotlib='InteractiveShellApp.matplotlib',
114 )
114 )
115 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
115 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
116
116
117 #-----------------------------------------------------------------------------
117 #-----------------------------------------------------------------------------
118 # Main classes and functions
118 # Main classes and functions
119 #-----------------------------------------------------------------------------
119 #-----------------------------------------------------------------------------
120
120
121 class InteractiveShellApp(Configurable):
121 class InteractiveShellApp(Configurable):
122 """A Mixin for applications that start InteractiveShell instances.
122 """A Mixin for applications that start InteractiveShell instances.
123
123
124 Provides configurables for loading extensions and executing files
124 Provides configurables for loading extensions and executing files
125 as part of configuring a Shell environment.
125 as part of configuring a Shell environment.
126
126
127 The following methods should be called by the :meth:`initialize` method
127 The following methods should be called by the :meth:`initialize` method
128 of the subclass:
128 of the subclass:
129
129
130 - :meth:`init_path`
130 - :meth:`init_path`
131 - :meth:`init_shell` (to be implemented by the subclass)
131 - :meth:`init_shell` (to be implemented by the subclass)
132 - :meth:`init_gui_pylab`
132 - :meth:`init_gui_pylab`
133 - :meth:`init_extensions`
133 - :meth:`init_extensions`
134 - :meth:`init_code`
134 - :meth:`init_code`
135 """
135 """
136 extensions = List(Unicode, config=True,
136 extensions = List(Unicode, config=True,
137 help="A list of dotted module names of IPython extensions to load."
137 help="A list of dotted module names of IPython extensions to load."
138 )
138 )
139 extra_extension = Unicode('', config=True,
139 extra_extension = Unicode('', config=True,
140 help="dotted module name of an IPython extension to load."
140 help="dotted module name of an IPython extension to load."
141 )
141 )
142
142
143 reraise_ipython_extension_failures = Bool(
143 reraise_ipython_extension_failures = Bool(
144 False,
144 False,
145 config=True,
145 config=True,
146 help="Reraise exceptions encountered loading IPython extensions?",
146 help="Reraise exceptions encountered loading IPython extensions?",
147 )
147 )
148
148
149 # Extensions that are always loaded (not configurable)
149 # Extensions that are always loaded (not configurable)
150 default_extensions = List(Unicode, [u'storemagic'], config=False)
150 default_extensions = List(Unicode, [u'storemagic'], config=False)
151
151
152 hide_initial_ns = Bool(True, config=True,
152 hide_initial_ns = Bool(True, config=True,
153 help="""Should variables loaded at startup (by startup files, exec_lines, etc.)
153 help="""Should variables loaded at startup (by startup files, exec_lines, etc.)
154 be hidden from tools like %who?"""
154 be hidden from tools like %who?"""
155 )
155 )
156
156
157 exec_files = List(Unicode, config=True,
157 exec_files = List(Unicode, config=True,
158 help="""List of files to run at IPython startup."""
158 help="""List of files to run at IPython startup."""
159 )
159 )
160 exec_PYTHONSTARTUP = Bool(True, config=True,
160 exec_PYTHONSTARTUP = Bool(True, config=True,
161 help="""Run the file referenced by the PYTHONSTARTUP environment
161 help="""Run the file referenced by the PYTHONSTARTUP environment
162 variable at IPython startup."""
162 variable at IPython startup."""
163 )
163 )
164 file_to_run = Unicode('', config=True,
164 file_to_run = Unicode('', config=True,
165 help="""A file to be run""")
165 help="""A file to be run""")
166
166
167 exec_lines = List(Unicode, config=True,
167 exec_lines = List(Unicode, config=True,
168 help="""lines of code to run at IPython startup."""
168 help="""lines of code to run at IPython startup."""
169 )
169 )
170 code_to_run = Unicode('', config=True,
170 code_to_run = Unicode('', config=True,
171 help="Execute the given command string."
171 help="Execute the given command string."
172 )
172 )
173 module_to_run = Unicode('', config=True,
173 module_to_run = Unicode('', config=True,
174 help="Run the module as a script."
174 help="Run the module as a script."
175 )
175 )
176 gui = CaselessStrEnum(gui_keys, config=True,
176 gui = CaselessStrEnum(gui_keys, config=True, allow_none=True,
177 help="Enable GUI event loop integration with any of {0}.".format(gui_keys)
177 help="Enable GUI event loop integration with any of {0}.".format(gui_keys)
178 )
178 )
179 matplotlib = CaselessStrEnum(backend_keys,
179 matplotlib = CaselessStrEnum(backend_keys, allow_none=True,
180 config=True,
180 config=True,
181 help="""Configure matplotlib for interactive use with
181 help="""Configure matplotlib for interactive use with
182 the default matplotlib backend."""
182 the default matplotlib backend."""
183 )
183 )
184 pylab = CaselessStrEnum(backend_keys,
184 pylab = CaselessStrEnum(backend_keys, allow_none=True,
185 config=True,
185 config=True,
186 help="""Pre-load matplotlib and numpy for interactive use,
186 help="""Pre-load matplotlib and numpy for interactive use,
187 selecting a particular matplotlib backend and loop integration.
187 selecting a particular matplotlib backend and loop integration.
188 """
188 """
189 )
189 )
190 pylab_import_all = Bool(True, config=True,
190 pylab_import_all = Bool(True, config=True,
191 help="""If true, IPython will populate the user namespace with numpy, pylab, etc.
191 help="""If true, IPython will populate the user namespace with numpy, pylab, etc.
192 and an ``import *`` is done from numpy and pylab, when using pylab mode.
192 and an ``import *`` is done from numpy and pylab, when using pylab mode.
193
193
194 When False, pylab mode should not import any names into the user namespace.
194 When False, pylab mode should not import any names into the user namespace.
195 """
195 """
196 )
196 )
197 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
197 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
198
198
199 user_ns = Instance(dict, args=None, allow_none=True)
199 user_ns = Instance(dict, args=None, allow_none=True)
200 def _user_ns_changed(self, name, old, new):
200 def _user_ns_changed(self, name, old, new):
201 if self.shell is not None:
201 if self.shell is not None:
202 self.shell.user_ns = new
202 self.shell.user_ns = new
203 self.shell.init_user_ns()
203 self.shell.init_user_ns()
204
204
205 def init_path(self):
205 def init_path(self):
206 """Add current working directory, '', to sys.path"""
206 """Add current working directory, '', to sys.path"""
207 if sys.path[0] != '':
207 if sys.path[0] != '':
208 sys.path.insert(0, '')
208 sys.path.insert(0, '')
209
209
210 def init_shell(self):
210 def init_shell(self):
211 raise NotImplementedError("Override in subclasses")
211 raise NotImplementedError("Override in subclasses")
212
212
213 def init_gui_pylab(self):
213 def init_gui_pylab(self):
214 """Enable GUI event loop integration, taking pylab into account."""
214 """Enable GUI event loop integration, taking pylab into account."""
215 enable = False
215 enable = False
216 shell = self.shell
216 shell = self.shell
217 if self.pylab:
217 if self.pylab:
218 enable = lambda key: shell.enable_pylab(key, import_all=self.pylab_import_all)
218 enable = lambda key: shell.enable_pylab(key, import_all=self.pylab_import_all)
219 key = self.pylab
219 key = self.pylab
220 elif self.matplotlib:
220 elif self.matplotlib:
221 enable = shell.enable_matplotlib
221 enable = shell.enable_matplotlib
222 key = self.matplotlib
222 key = self.matplotlib
223 elif self.gui:
223 elif self.gui:
224 enable = shell.enable_gui
224 enable = shell.enable_gui
225 key = self.gui
225 key = self.gui
226
226
227 if not enable:
227 if not enable:
228 return
228 return
229
229
230 try:
230 try:
231 r = enable(key)
231 r = enable(key)
232 except ImportError:
232 except ImportError:
233 self.log.warn("Eventloop or matplotlib integration failed. Is matplotlib installed?")
233 self.log.warn("Eventloop or matplotlib integration failed. Is matplotlib installed?")
234 self.shell.showtraceback()
234 self.shell.showtraceback()
235 return
235 return
236 except Exception:
236 except Exception:
237 self.log.warn("GUI event loop or pylab initialization failed")
237 self.log.warn("GUI event loop or pylab initialization failed")
238 self.shell.showtraceback()
238 self.shell.showtraceback()
239 return
239 return
240
240
241 if isinstance(r, tuple):
241 if isinstance(r, tuple):
242 gui, backend = r[:2]
242 gui, backend = r[:2]
243 self.log.info("Enabling GUI event loop integration, "
243 self.log.info("Enabling GUI event loop integration, "
244 "eventloop=%s, matplotlib=%s", gui, backend)
244 "eventloop=%s, matplotlib=%s", gui, backend)
245 if key == "auto":
245 if key == "auto":
246 print("Using matplotlib backend: %s" % backend)
246 print("Using matplotlib backend: %s" % backend)
247 else:
247 else:
248 gui = r
248 gui = r
249 self.log.info("Enabling GUI event loop integration, "
249 self.log.info("Enabling GUI event loop integration, "
250 "eventloop=%s", gui)
250 "eventloop=%s", gui)
251
251
252 def init_extensions(self):
252 def init_extensions(self):
253 """Load all IPython extensions in IPythonApp.extensions.
253 """Load all IPython extensions in IPythonApp.extensions.
254
254
255 This uses the :meth:`ExtensionManager.load_extensions` to load all
255 This uses the :meth:`ExtensionManager.load_extensions` to load all
256 the extensions listed in ``self.extensions``.
256 the extensions listed in ``self.extensions``.
257 """
257 """
258 try:
258 try:
259 self.log.debug("Loading IPython extensions...")
259 self.log.debug("Loading IPython extensions...")
260 extensions = self.default_extensions + self.extensions
260 extensions = self.default_extensions + self.extensions
261 if self.extra_extension:
261 if self.extra_extension:
262 extensions.append(self.extra_extension)
262 extensions.append(self.extra_extension)
263 for ext in extensions:
263 for ext in extensions:
264 try:
264 try:
265 self.log.info("Loading IPython extension: %s" % ext)
265 self.log.info("Loading IPython extension: %s" % ext)
266 self.shell.extension_manager.load_extension(ext)
266 self.shell.extension_manager.load_extension(ext)
267 except:
267 except:
268 if self.reraise_ipython_extension_failures:
268 if self.reraise_ipython_extension_failures:
269 raise
269 raise
270 msg = ("Error in loading extension: {ext}\n"
270 msg = ("Error in loading extension: {ext}\n"
271 "Check your config files in {location}".format(
271 "Check your config files in {location}".format(
272 ext=ext,
272 ext=ext,
273 location=self.profile_dir.location
273 location=self.profile_dir.location
274 ))
274 ))
275 self.log.warn(msg, exc_info=True)
275 self.log.warn(msg, exc_info=True)
276 except:
276 except:
277 if self.reraise_ipython_extension_failures:
277 if self.reraise_ipython_extension_failures:
278 raise
278 raise
279 self.log.warn("Unknown error in loading extensions:", exc_info=True)
279 self.log.warn("Unknown error in loading extensions:", exc_info=True)
280
280
281 def init_code(self):
281 def init_code(self):
282 """run the pre-flight code, specified via exec_lines"""
282 """run the pre-flight code, specified via exec_lines"""
283 self._run_startup_files()
283 self._run_startup_files()
284 self._run_exec_lines()
284 self._run_exec_lines()
285 self._run_exec_files()
285 self._run_exec_files()
286
286
287 # Hide variables defined here from %who etc.
287 # Hide variables defined here from %who etc.
288 if self.hide_initial_ns:
288 if self.hide_initial_ns:
289 self.shell.user_ns_hidden.update(self.shell.user_ns)
289 self.shell.user_ns_hidden.update(self.shell.user_ns)
290
290
291 # command-line execution (ipython -i script.py, ipython -m module)
291 # command-line execution (ipython -i script.py, ipython -m module)
292 # should *not* be excluded from %whos
292 # should *not* be excluded from %whos
293 self._run_cmd_line_code()
293 self._run_cmd_line_code()
294 self._run_module()
294 self._run_module()
295
295
296 # flush output, so itwon't be attached to the first cell
296 # flush output, so itwon't be attached to the first cell
297 sys.stdout.flush()
297 sys.stdout.flush()
298 sys.stderr.flush()
298 sys.stderr.flush()
299
299
300 def _run_exec_lines(self):
300 def _run_exec_lines(self):
301 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
301 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
302 if not self.exec_lines:
302 if not self.exec_lines:
303 return
303 return
304 try:
304 try:
305 self.log.debug("Running code from IPythonApp.exec_lines...")
305 self.log.debug("Running code from IPythonApp.exec_lines...")
306 for line in self.exec_lines:
306 for line in self.exec_lines:
307 try:
307 try:
308 self.log.info("Running code in user namespace: %s" %
308 self.log.info("Running code in user namespace: %s" %
309 line)
309 line)
310 self.shell.run_cell(line, store_history=False)
310 self.shell.run_cell(line, store_history=False)
311 except:
311 except:
312 self.log.warn("Error in executing line in user "
312 self.log.warn("Error in executing line in user "
313 "namespace: %s" % line)
313 "namespace: %s" % line)
314 self.shell.showtraceback()
314 self.shell.showtraceback()
315 except:
315 except:
316 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
316 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
317 self.shell.showtraceback()
317 self.shell.showtraceback()
318
318
319 def _exec_file(self, fname, shell_futures=False):
319 def _exec_file(self, fname, shell_futures=False):
320 try:
320 try:
321 full_filename = filefind(fname, [u'.', self.ipython_dir])
321 full_filename = filefind(fname, [u'.', self.ipython_dir])
322 except IOError as e:
322 except IOError as e:
323 self.log.warn("File not found: %r"%fname)
323 self.log.warn("File not found: %r"%fname)
324 return
324 return
325 # Make sure that the running script gets a proper sys.argv as if it
325 # Make sure that the running script gets a proper sys.argv as if it
326 # were run from a system shell.
326 # were run from a system shell.
327 save_argv = sys.argv
327 save_argv = sys.argv
328 sys.argv = [full_filename] + self.extra_args[1:]
328 sys.argv = [full_filename] + self.extra_args[1:]
329 # protect sys.argv from potential unicode strings on Python 2:
329 # protect sys.argv from potential unicode strings on Python 2:
330 if not py3compat.PY3:
330 if not py3compat.PY3:
331 sys.argv = [ py3compat.cast_bytes(a) for a in sys.argv ]
331 sys.argv = [ py3compat.cast_bytes(a) for a in sys.argv ]
332 try:
332 try:
333 if os.path.isfile(full_filename):
333 if os.path.isfile(full_filename):
334 self.log.info("Running file in user namespace: %s" %
334 self.log.info("Running file in user namespace: %s" %
335 full_filename)
335 full_filename)
336 # Ensure that __file__ is always defined to match Python
336 # Ensure that __file__ is always defined to match Python
337 # behavior.
337 # behavior.
338 with preserve_keys(self.shell.user_ns, '__file__'):
338 with preserve_keys(self.shell.user_ns, '__file__'):
339 self.shell.user_ns['__file__'] = fname
339 self.shell.user_ns['__file__'] = fname
340 if full_filename.endswith('.ipy'):
340 if full_filename.endswith('.ipy'):
341 self.shell.safe_execfile_ipy(full_filename,
341 self.shell.safe_execfile_ipy(full_filename,
342 shell_futures=shell_futures)
342 shell_futures=shell_futures)
343 else:
343 else:
344 # default to python, even without extension
344 # default to python, even without extension
345 self.shell.safe_execfile(full_filename,
345 self.shell.safe_execfile(full_filename,
346 self.shell.user_ns,
346 self.shell.user_ns,
347 shell_futures=shell_futures)
347 shell_futures=shell_futures)
348 finally:
348 finally:
349 sys.argv = save_argv
349 sys.argv = save_argv
350
350
351 def _run_startup_files(self):
351 def _run_startup_files(self):
352 """Run files from profile startup directory"""
352 """Run files from profile startup directory"""
353 startup_dir = self.profile_dir.startup_dir
353 startup_dir = self.profile_dir.startup_dir
354 startup_files = []
354 startup_files = []
355
355
356 if self.exec_PYTHONSTARTUP and os.environ.get('PYTHONSTARTUP', False) and \
356 if self.exec_PYTHONSTARTUP and os.environ.get('PYTHONSTARTUP', False) and \
357 not (self.file_to_run or self.code_to_run or self.module_to_run):
357 not (self.file_to_run or self.code_to_run or self.module_to_run):
358 python_startup = os.environ['PYTHONSTARTUP']
358 python_startup = os.environ['PYTHONSTARTUP']
359 self.log.debug("Running PYTHONSTARTUP file %s...", python_startup)
359 self.log.debug("Running PYTHONSTARTUP file %s...", python_startup)
360 try:
360 try:
361 self._exec_file(python_startup)
361 self._exec_file(python_startup)
362 except:
362 except:
363 self.log.warn("Unknown error in handling PYTHONSTARTUP file %s:", python_startup)
363 self.log.warn("Unknown error in handling PYTHONSTARTUP file %s:", python_startup)
364 self.shell.showtraceback()
364 self.shell.showtraceback()
365 finally:
365 finally:
366 # Many PYTHONSTARTUP files set up the readline completions,
366 # Many PYTHONSTARTUP files set up the readline completions,
367 # but this is often at odds with IPython's own completions.
367 # but this is often at odds with IPython's own completions.
368 # Do not allow PYTHONSTARTUP to set up readline.
368 # Do not allow PYTHONSTARTUP to set up readline.
369 if self.shell.has_readline:
369 if self.shell.has_readline:
370 self.shell.set_readline_completer()
370 self.shell.set_readline_completer()
371
371
372 startup_files += glob.glob(os.path.join(startup_dir, '*.py'))
372 startup_files += glob.glob(os.path.join(startup_dir, '*.py'))
373 startup_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
373 startup_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
374 if not startup_files:
374 if not startup_files:
375 return
375 return
376
376
377 self.log.debug("Running startup files from %s...", startup_dir)
377 self.log.debug("Running startup files from %s...", startup_dir)
378 try:
378 try:
379 for fname in sorted(startup_files):
379 for fname in sorted(startup_files):
380 self._exec_file(fname)
380 self._exec_file(fname)
381 except:
381 except:
382 self.log.warn("Unknown error in handling startup files:")
382 self.log.warn("Unknown error in handling startup files:")
383 self.shell.showtraceback()
383 self.shell.showtraceback()
384
384
385 def _run_exec_files(self):
385 def _run_exec_files(self):
386 """Run files from IPythonApp.exec_files"""
386 """Run files from IPythonApp.exec_files"""
387 if not self.exec_files:
387 if not self.exec_files:
388 return
388 return
389
389
390 self.log.debug("Running files in IPythonApp.exec_files...")
390 self.log.debug("Running files in IPythonApp.exec_files...")
391 try:
391 try:
392 for fname in self.exec_files:
392 for fname in self.exec_files:
393 self._exec_file(fname)
393 self._exec_file(fname)
394 except:
394 except:
395 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
395 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
396 self.shell.showtraceback()
396 self.shell.showtraceback()
397
397
398 def _run_cmd_line_code(self):
398 def _run_cmd_line_code(self):
399 """Run code or file specified at the command-line"""
399 """Run code or file specified at the command-line"""
400 if self.code_to_run:
400 if self.code_to_run:
401 line = self.code_to_run
401 line = self.code_to_run
402 try:
402 try:
403 self.log.info("Running code given at command line (c=): %s" %
403 self.log.info("Running code given at command line (c=): %s" %
404 line)
404 line)
405 self.shell.run_cell(line, store_history=False)
405 self.shell.run_cell(line, store_history=False)
406 except:
406 except:
407 self.log.warn("Error in executing line in user namespace: %s" %
407 self.log.warn("Error in executing line in user namespace: %s" %
408 line)
408 line)
409 self.shell.showtraceback()
409 self.shell.showtraceback()
410
410
411 # Like Python itself, ignore the second if the first of these is present
411 # Like Python itself, ignore the second if the first of these is present
412 elif self.file_to_run:
412 elif self.file_to_run:
413 fname = self.file_to_run
413 fname = self.file_to_run
414 try:
414 try:
415 self._exec_file(fname, shell_futures=True)
415 self._exec_file(fname, shell_futures=True)
416 except:
416 except:
417 self.log.warn("Error in executing file in user namespace: %s" %
417 self.log.warn("Error in executing file in user namespace: %s" %
418 fname)
418 fname)
419 self.shell.showtraceback()
419 self.shell.showtraceback()
420
420
421 def _run_module(self):
421 def _run_module(self):
422 """Run module specified at the command-line."""
422 """Run module specified at the command-line."""
423 if self.module_to_run:
423 if self.module_to_run:
424 # Make sure that the module gets a proper sys.argv as if it were
424 # Make sure that the module gets a proper sys.argv as if it were
425 # run using `python -m`.
425 # run using `python -m`.
426 save_argv = sys.argv
426 save_argv = sys.argv
427 sys.argv = [sys.executable] + self.extra_args
427 sys.argv = [sys.executable] + self.extra_args
428 try:
428 try:
429 self.shell.safe_run_module(self.module_to_run,
429 self.shell.safe_run_module(self.module_to_run,
430 self.shell.user_ns)
430 self.shell.user_ns)
431 finally:
431 finally:
432 sys.argv = save_argv
432 sys.argv = save_argv
@@ -1,468 +1,468 b''
1 """A base class for contents managers."""
1 """A base class for contents managers."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from fnmatch import fnmatch
6 from fnmatch import fnmatch
7 import itertools
7 import itertools
8 import json
8 import json
9 import os
9 import os
10 import re
10 import re
11
11
12 from tornado.web import HTTPError
12 from tornado.web import HTTPError
13
13
14 from .checkpoints import Checkpoints
14 from .checkpoints import Checkpoints
15 from IPython.config.configurable import LoggingConfigurable
15 from IPython.config.configurable import LoggingConfigurable
16 from IPython.nbformat import sign, validate, ValidationError
16 from IPython.nbformat import sign, validate, ValidationError
17 from IPython.nbformat.v4 import new_notebook
17 from IPython.nbformat.v4 import new_notebook
18 from IPython.utils.importstring import import_item
18 from IPython.utils.importstring import import_item
19 from IPython.utils.traitlets import (
19 from IPython.utils.traitlets import (
20 Any,
20 Any,
21 Dict,
21 Dict,
22 Instance,
22 Instance,
23 List,
23 List,
24 TraitError,
24 TraitError,
25 Type,
25 Type,
26 Unicode,
26 Unicode,
27 )
27 )
28 from IPython.utils.py3compat import string_types
28 from IPython.utils.py3compat import string_types
29
29
30 copy_pat = re.compile(r'\-Copy\d*\.')
30 copy_pat = re.compile(r'\-Copy\d*\.')
31
31
32
32
33 class ContentsManager(LoggingConfigurable):
33 class ContentsManager(LoggingConfigurable):
34 """Base class for serving files and directories.
34 """Base class for serving files and directories.
35
35
36 This serves any text or binary file,
36 This serves any text or binary file,
37 as well as directories,
37 as well as directories,
38 with special handling for JSON notebook documents.
38 with special handling for JSON notebook documents.
39
39
40 Most APIs take a path argument,
40 Most APIs take a path argument,
41 which is always an API-style unicode path,
41 which is always an API-style unicode path,
42 and always refers to a directory.
42 and always refers to a directory.
43
43
44 - unicode, not url-escaped
44 - unicode, not url-escaped
45 - '/'-separated
45 - '/'-separated
46 - leading and trailing '/' will be stripped
46 - leading and trailing '/' will be stripped
47 - if unspecified, path defaults to '',
47 - if unspecified, path defaults to '',
48 indicating the root path.
48 indicating the root path.
49
49
50 """
50 """
51
51
52 notary = Instance(sign.NotebookNotary)
52 notary = Instance(sign.NotebookNotary)
53 def _notary_default(self):
53 def _notary_default(self):
54 return sign.NotebookNotary(parent=self)
54 return sign.NotebookNotary(parent=self)
55
55
56 hide_globs = List(Unicode, [
56 hide_globs = List(Unicode, [
57 u'__pycache__', '*.pyc', '*.pyo',
57 u'__pycache__', '*.pyc', '*.pyo',
58 '.DS_Store', '*.so', '*.dylib', '*~',
58 '.DS_Store', '*.so', '*.dylib', '*~',
59 ], config=True, help="""
59 ], config=True, help="""
60 Glob patterns to hide in file and directory listings.
60 Glob patterns to hide in file and directory listings.
61 """)
61 """)
62
62
63 untitled_notebook = Unicode("Untitled", config=True,
63 untitled_notebook = Unicode("Untitled", config=True,
64 help="The base name used when creating untitled notebooks."
64 help="The base name used when creating untitled notebooks."
65 )
65 )
66
66
67 untitled_file = Unicode("untitled", config=True,
67 untitled_file = Unicode("untitled", config=True,
68 help="The base name used when creating untitled files."
68 help="The base name used when creating untitled files."
69 )
69 )
70
70
71 untitled_directory = Unicode("Untitled Folder", config=True,
71 untitled_directory = Unicode("Untitled Folder", config=True,
72 help="The base name used when creating untitled directories."
72 help="The base name used when creating untitled directories."
73 )
73 )
74
74
75 pre_save_hook = Any(None, config=True,
75 pre_save_hook = Any(None, config=True,
76 help="""Python callable or importstring thereof
76 help="""Python callable or importstring thereof
77
77
78 To be called on a contents model prior to save.
78 To be called on a contents model prior to save.
79
79
80 This can be used to process the structure,
80 This can be used to process the structure,
81 such as removing notebook outputs or other side effects that
81 such as removing notebook outputs or other side effects that
82 should not be saved.
82 should not be saved.
83
83
84 It will be called as (all arguments passed by keyword)::
84 It will be called as (all arguments passed by keyword)::
85
85
86 hook(path=path, model=model, contents_manager=self)
86 hook(path=path, model=model, contents_manager=self)
87
87
88 - model: the model to be saved. Includes file contents.
88 - model: the model to be saved. Includes file contents.
89 Modifying this dict will affect the file that is stored.
89 Modifying this dict will affect the file that is stored.
90 - path: the API path of the save destination
90 - path: the API path of the save destination
91 - contents_manager: this ContentsManager instance
91 - contents_manager: this ContentsManager instance
92 """
92 """
93 )
93 )
94 def _pre_save_hook_changed(self, name, old, new):
94 def _pre_save_hook_changed(self, name, old, new):
95 if new and isinstance(new, string_types):
95 if new and isinstance(new, string_types):
96 self.pre_save_hook = import_item(self.pre_save_hook)
96 self.pre_save_hook = import_item(self.pre_save_hook)
97 elif new:
97 elif new:
98 if not callable(new):
98 if not callable(new):
99 raise TraitError("pre_save_hook must be callable")
99 raise TraitError("pre_save_hook must be callable")
100
100
101 def run_pre_save_hook(self, model, path, **kwargs):
101 def run_pre_save_hook(self, model, path, **kwargs):
102 """Run the pre-save hook if defined, and log errors"""
102 """Run the pre-save hook if defined, and log errors"""
103 if self.pre_save_hook:
103 if self.pre_save_hook:
104 try:
104 try:
105 self.log.debug("Running pre-save hook on %s", path)
105 self.log.debug("Running pre-save hook on %s", path)
106 self.pre_save_hook(model=model, path=path, contents_manager=self, **kwargs)
106 self.pre_save_hook(model=model, path=path, contents_manager=self, **kwargs)
107 except Exception:
107 except Exception:
108 self.log.error("Pre-save hook failed on %s", path, exc_info=True)
108 self.log.error("Pre-save hook failed on %s", path, exc_info=True)
109
109
110 checkpoints_class = Type(Checkpoints, config=True)
110 checkpoints_class = Type(Checkpoints, config=True)
111 checkpoints = Instance(Checkpoints, config=True)
111 checkpoints = Instance(Checkpoints, config=True)
112 checkpoints_kwargs = Dict(allow_none=False, config=True)
112 checkpoints_kwargs = Dict(config=True)
113
113
114 def _checkpoints_default(self):
114 def _checkpoints_default(self):
115 return self.checkpoints_class(**self.checkpoints_kwargs)
115 return self.checkpoints_class(**self.checkpoints_kwargs)
116
116
117 def _checkpoints_kwargs_default(self):
117 def _checkpoints_kwargs_default(self):
118 return dict(
118 return dict(
119 parent=self,
119 parent=self,
120 log=self.log,
120 log=self.log,
121 )
121 )
122
122
123 # ContentsManager API part 1: methods that must be
123 # ContentsManager API part 1: methods that must be
124 # implemented in subclasses.
124 # implemented in subclasses.
125
125
126 def dir_exists(self, path):
126 def dir_exists(self, path):
127 """Does the API-style path (directory) actually exist?
127 """Does the API-style path (directory) actually exist?
128
128
129 Like os.path.isdir
129 Like os.path.isdir
130
130
131 Override this method in subclasses.
131 Override this method in subclasses.
132
132
133 Parameters
133 Parameters
134 ----------
134 ----------
135 path : string
135 path : string
136 The path to check
136 The path to check
137
137
138 Returns
138 Returns
139 -------
139 -------
140 exists : bool
140 exists : bool
141 Whether the path does indeed exist.
141 Whether the path does indeed exist.
142 """
142 """
143 raise NotImplementedError
143 raise NotImplementedError
144
144
145 def is_hidden(self, path):
145 def is_hidden(self, path):
146 """Does the API style path correspond to a hidden directory or file?
146 """Does the API style path correspond to a hidden directory or file?
147
147
148 Parameters
148 Parameters
149 ----------
149 ----------
150 path : string
150 path : string
151 The path to check. This is an API path (`/` separated,
151 The path to check. This is an API path (`/` separated,
152 relative to root dir).
152 relative to root dir).
153
153
154 Returns
154 Returns
155 -------
155 -------
156 hidden : bool
156 hidden : bool
157 Whether the path is hidden.
157 Whether the path is hidden.
158
158
159 """
159 """
160 raise NotImplementedError
160 raise NotImplementedError
161
161
162 def file_exists(self, path=''):
162 def file_exists(self, path=''):
163 """Does a file exist at the given path?
163 """Does a file exist at the given path?
164
164
165 Like os.path.isfile
165 Like os.path.isfile
166
166
167 Override this method in subclasses.
167 Override this method in subclasses.
168
168
169 Parameters
169 Parameters
170 ----------
170 ----------
171 name : string
171 name : string
172 The name of the file you are checking.
172 The name of the file you are checking.
173 path : string
173 path : string
174 The relative path to the file's directory (with '/' as separator)
174 The relative path to the file's directory (with '/' as separator)
175
175
176 Returns
176 Returns
177 -------
177 -------
178 exists : bool
178 exists : bool
179 Whether the file exists.
179 Whether the file exists.
180 """
180 """
181 raise NotImplementedError('must be implemented in a subclass')
181 raise NotImplementedError('must be implemented in a subclass')
182
182
183 def exists(self, path):
183 def exists(self, path):
184 """Does a file or directory exist at the given path?
184 """Does a file or directory exist at the given path?
185
185
186 Like os.path.exists
186 Like os.path.exists
187
187
188 Parameters
188 Parameters
189 ----------
189 ----------
190 path : string
190 path : string
191 The relative path to the file's directory (with '/' as separator)
191 The relative path to the file's directory (with '/' as separator)
192
192
193 Returns
193 Returns
194 -------
194 -------
195 exists : bool
195 exists : bool
196 Whether the target exists.
196 Whether the target exists.
197 """
197 """
198 return self.file_exists(path) or self.dir_exists(path)
198 return self.file_exists(path) or self.dir_exists(path)
199
199
200 def get(self, path, content=True, type=None, format=None):
200 def get(self, path, content=True, type=None, format=None):
201 """Get the model of a file or directory with or without content."""
201 """Get the model of a file or directory with or without content."""
202 raise NotImplementedError('must be implemented in a subclass')
202 raise NotImplementedError('must be implemented in a subclass')
203
203
204 def save(self, model, path):
204 def save(self, model, path):
205 """Save the file or directory and return the model with no content.
205 """Save the file or directory and return the model with no content.
206
206
207 Save implementations should call self.run_pre_save_hook(model=model, path=path)
207 Save implementations should call self.run_pre_save_hook(model=model, path=path)
208 prior to writing any data.
208 prior to writing any data.
209 """
209 """
210 raise NotImplementedError('must be implemented in a subclass')
210 raise NotImplementedError('must be implemented in a subclass')
211
211
212 def delete_file(self, path):
212 def delete_file(self, path):
213 """Delete file or directory by path."""
213 """Delete file or directory by path."""
214 raise NotImplementedError('must be implemented in a subclass')
214 raise NotImplementedError('must be implemented in a subclass')
215
215
216 def rename_file(self, old_path, new_path):
216 def rename_file(self, old_path, new_path):
217 """Rename a file."""
217 """Rename a file."""
218 raise NotImplementedError('must be implemented in a subclass')
218 raise NotImplementedError('must be implemented in a subclass')
219
219
220 # ContentsManager API part 2: methods that have useable default
220 # ContentsManager API part 2: methods that have useable default
221 # implementations, but can be overridden in subclasses.
221 # implementations, but can be overridden in subclasses.
222
222
223 def delete(self, path):
223 def delete(self, path):
224 """Delete a file/directory and any associated checkpoints."""
224 """Delete a file/directory and any associated checkpoints."""
225 self.delete_file(path)
225 self.delete_file(path)
226 self.checkpoints.delete_all_checkpoints(path)
226 self.checkpoints.delete_all_checkpoints(path)
227
227
228 def rename(self, old_path, new_path):
228 def rename(self, old_path, new_path):
229 """Rename a file and any checkpoints associated with that file."""
229 """Rename a file and any checkpoints associated with that file."""
230 self.rename_file(old_path, new_path)
230 self.rename_file(old_path, new_path)
231 self.checkpoints.rename_all_checkpoints(old_path, new_path)
231 self.checkpoints.rename_all_checkpoints(old_path, new_path)
232
232
233 def update(self, model, path):
233 def update(self, model, path):
234 """Update the file's path
234 """Update the file's path
235
235
236 For use in PATCH requests, to enable renaming a file without
236 For use in PATCH requests, to enable renaming a file without
237 re-uploading its contents. Only used for renaming at the moment.
237 re-uploading its contents. Only used for renaming at the moment.
238 """
238 """
239 path = path.strip('/')
239 path = path.strip('/')
240 new_path = model.get('path', path).strip('/')
240 new_path = model.get('path', path).strip('/')
241 if path != new_path:
241 if path != new_path:
242 self.rename(path, new_path)
242 self.rename(path, new_path)
243 model = self.get(new_path, content=False)
243 model = self.get(new_path, content=False)
244 return model
244 return model
245
245
246 def info_string(self):
246 def info_string(self):
247 return "Serving contents"
247 return "Serving contents"
248
248
249 def get_kernel_path(self, path, model=None):
249 def get_kernel_path(self, path, model=None):
250 """Return the API path for the kernel
250 """Return the API path for the kernel
251
251
252 KernelManagers can turn this value into a filesystem path,
252 KernelManagers can turn this value into a filesystem path,
253 or ignore it altogether.
253 or ignore it altogether.
254
254
255 The default value here will start kernels in the directory of the
255 The default value here will start kernels in the directory of the
256 notebook server. FileContentsManager overrides this to use the
256 notebook server. FileContentsManager overrides this to use the
257 directory containing the notebook.
257 directory containing the notebook.
258 """
258 """
259 return ''
259 return ''
260
260
261 def increment_filename(self, filename, path='', insert=''):
261 def increment_filename(self, filename, path='', insert=''):
262 """Increment a filename until it is unique.
262 """Increment a filename until it is unique.
263
263
264 Parameters
264 Parameters
265 ----------
265 ----------
266 filename : unicode
266 filename : unicode
267 The name of a file, including extension
267 The name of a file, including extension
268 path : unicode
268 path : unicode
269 The API path of the target's directory
269 The API path of the target's directory
270
270
271 Returns
271 Returns
272 -------
272 -------
273 name : unicode
273 name : unicode
274 A filename that is unique, based on the input filename.
274 A filename that is unique, based on the input filename.
275 """
275 """
276 path = path.strip('/')
276 path = path.strip('/')
277 basename, ext = os.path.splitext(filename)
277 basename, ext = os.path.splitext(filename)
278 for i in itertools.count():
278 for i in itertools.count():
279 if i:
279 if i:
280 insert_i = '{}{}'.format(insert, i)
280 insert_i = '{}{}'.format(insert, i)
281 else:
281 else:
282 insert_i = ''
282 insert_i = ''
283 name = u'{basename}{insert}{ext}'.format(basename=basename,
283 name = u'{basename}{insert}{ext}'.format(basename=basename,
284 insert=insert_i, ext=ext)
284 insert=insert_i, ext=ext)
285 if not self.exists(u'{}/{}'.format(path, name)):
285 if not self.exists(u'{}/{}'.format(path, name)):
286 break
286 break
287 return name
287 return name
288
288
289 def validate_notebook_model(self, model):
289 def validate_notebook_model(self, model):
290 """Add failed-validation message to model"""
290 """Add failed-validation message to model"""
291 try:
291 try:
292 validate(model['content'])
292 validate(model['content'])
293 except ValidationError as e:
293 except ValidationError as e:
294 model['message'] = u'Notebook Validation failed: {}:\n{}'.format(
294 model['message'] = u'Notebook Validation failed: {}:\n{}'.format(
295 e.message, json.dumps(e.instance, indent=1, default=lambda obj: '<UNKNOWN>'),
295 e.message, json.dumps(e.instance, indent=1, default=lambda obj: '<UNKNOWN>'),
296 )
296 )
297 return model
297 return model
298
298
299 def new_untitled(self, path='', type='', ext=''):
299 def new_untitled(self, path='', type='', ext=''):
300 """Create a new untitled file or directory in path
300 """Create a new untitled file or directory in path
301
301
302 path must be a directory
302 path must be a directory
303
303
304 File extension can be specified.
304 File extension can be specified.
305
305
306 Use `new` to create files with a fully specified path (including filename).
306 Use `new` to create files with a fully specified path (including filename).
307 """
307 """
308 path = path.strip('/')
308 path = path.strip('/')
309 if not self.dir_exists(path):
309 if not self.dir_exists(path):
310 raise HTTPError(404, 'No such directory: %s' % path)
310 raise HTTPError(404, 'No such directory: %s' % path)
311
311
312 model = {}
312 model = {}
313 if type:
313 if type:
314 model['type'] = type
314 model['type'] = type
315
315
316 if ext == '.ipynb':
316 if ext == '.ipynb':
317 model.setdefault('type', 'notebook')
317 model.setdefault('type', 'notebook')
318 else:
318 else:
319 model.setdefault('type', 'file')
319 model.setdefault('type', 'file')
320
320
321 insert = ''
321 insert = ''
322 if model['type'] == 'directory':
322 if model['type'] == 'directory':
323 untitled = self.untitled_directory
323 untitled = self.untitled_directory
324 insert = ' '
324 insert = ' '
325 elif model['type'] == 'notebook':
325 elif model['type'] == 'notebook':
326 untitled = self.untitled_notebook
326 untitled = self.untitled_notebook
327 ext = '.ipynb'
327 ext = '.ipynb'
328 elif model['type'] == 'file':
328 elif model['type'] == 'file':
329 untitled = self.untitled_file
329 untitled = self.untitled_file
330 else:
330 else:
331 raise HTTPError(400, "Unexpected model type: %r" % model['type'])
331 raise HTTPError(400, "Unexpected model type: %r" % model['type'])
332
332
333 name = self.increment_filename(untitled + ext, path, insert=insert)
333 name = self.increment_filename(untitled + ext, path, insert=insert)
334 path = u'{0}/{1}'.format(path, name)
334 path = u'{0}/{1}'.format(path, name)
335 return self.new(model, path)
335 return self.new(model, path)
336
336
337 def new(self, model=None, path=''):
337 def new(self, model=None, path=''):
338 """Create a new file or directory and return its model with no content.
338 """Create a new file or directory and return its model with no content.
339
339
340 To create a new untitled entity in a directory, use `new_untitled`.
340 To create a new untitled entity in a directory, use `new_untitled`.
341 """
341 """
342 path = path.strip('/')
342 path = path.strip('/')
343 if model is None:
343 if model is None:
344 model = {}
344 model = {}
345
345
346 if path.endswith('.ipynb'):
346 if path.endswith('.ipynb'):
347 model.setdefault('type', 'notebook')
347 model.setdefault('type', 'notebook')
348 else:
348 else:
349 model.setdefault('type', 'file')
349 model.setdefault('type', 'file')
350
350
351 # no content, not a directory, so fill out new-file model
351 # no content, not a directory, so fill out new-file model
352 if 'content' not in model and model['type'] != 'directory':
352 if 'content' not in model and model['type'] != 'directory':
353 if model['type'] == 'notebook':
353 if model['type'] == 'notebook':
354 model['content'] = new_notebook()
354 model['content'] = new_notebook()
355 model['format'] = 'json'
355 model['format'] = 'json'
356 else:
356 else:
357 model['content'] = ''
357 model['content'] = ''
358 model['type'] = 'file'
358 model['type'] = 'file'
359 model['format'] = 'text'
359 model['format'] = 'text'
360
360
361 model = self.save(model, path)
361 model = self.save(model, path)
362 return model
362 return model
363
363
364 def copy(self, from_path, to_path=None):
364 def copy(self, from_path, to_path=None):
365 """Copy an existing file and return its new model.
365 """Copy an existing file and return its new model.
366
366
367 If to_path not specified, it will be the parent directory of from_path.
367 If to_path not specified, it will be the parent directory of from_path.
368 If to_path is a directory, filename will increment `from_path-Copy#.ext`.
368 If to_path is a directory, filename will increment `from_path-Copy#.ext`.
369
369
370 from_path must be a full path to a file.
370 from_path must be a full path to a file.
371 """
371 """
372 path = from_path.strip('/')
372 path = from_path.strip('/')
373 if to_path is not None:
373 if to_path is not None:
374 to_path = to_path.strip('/')
374 to_path = to_path.strip('/')
375
375
376 if '/' in path:
376 if '/' in path:
377 from_dir, from_name = path.rsplit('/', 1)
377 from_dir, from_name = path.rsplit('/', 1)
378 else:
378 else:
379 from_dir = ''
379 from_dir = ''
380 from_name = path
380 from_name = path
381
381
382 model = self.get(path)
382 model = self.get(path)
383 model.pop('path', None)
383 model.pop('path', None)
384 model.pop('name', None)
384 model.pop('name', None)
385 if model['type'] == 'directory':
385 if model['type'] == 'directory':
386 raise HTTPError(400, "Can't copy directories")
386 raise HTTPError(400, "Can't copy directories")
387
387
388 if to_path is None:
388 if to_path is None:
389 to_path = from_dir
389 to_path = from_dir
390 if self.dir_exists(to_path):
390 if self.dir_exists(to_path):
391 name = copy_pat.sub(u'.', from_name)
391 name = copy_pat.sub(u'.', from_name)
392 to_name = self.increment_filename(name, to_path, insert='-Copy')
392 to_name = self.increment_filename(name, to_path, insert='-Copy')
393 to_path = u'{0}/{1}'.format(to_path, to_name)
393 to_path = u'{0}/{1}'.format(to_path, to_name)
394
394
395 model = self.save(model, to_path)
395 model = self.save(model, to_path)
396 return model
396 return model
397
397
398 def log_info(self):
398 def log_info(self):
399 self.log.info(self.info_string())
399 self.log.info(self.info_string())
400
400
401 def trust_notebook(self, path):
401 def trust_notebook(self, path):
402 """Explicitly trust a notebook
402 """Explicitly trust a notebook
403
403
404 Parameters
404 Parameters
405 ----------
405 ----------
406 path : string
406 path : string
407 The path of a notebook
407 The path of a notebook
408 """
408 """
409 model = self.get(path)
409 model = self.get(path)
410 nb = model['content']
410 nb = model['content']
411 self.log.warn("Trusting notebook %s", path)
411 self.log.warn("Trusting notebook %s", path)
412 self.notary.mark_cells(nb, True)
412 self.notary.mark_cells(nb, True)
413 self.save(model, path)
413 self.save(model, path)
414
414
415 def check_and_sign(self, nb, path=''):
415 def check_and_sign(self, nb, path=''):
416 """Check for trusted cells, and sign the notebook.
416 """Check for trusted cells, and sign the notebook.
417
417
418 Called as a part of saving notebooks.
418 Called as a part of saving notebooks.
419
419
420 Parameters
420 Parameters
421 ----------
421 ----------
422 nb : dict
422 nb : dict
423 The notebook dict
423 The notebook dict
424 path : string
424 path : string
425 The notebook's path (for logging)
425 The notebook's path (for logging)
426 """
426 """
427 if self.notary.check_cells(nb):
427 if self.notary.check_cells(nb):
428 self.notary.sign(nb)
428 self.notary.sign(nb)
429 else:
429 else:
430 self.log.warn("Saving untrusted notebook %s", path)
430 self.log.warn("Saving untrusted notebook %s", path)
431
431
432 def mark_trusted_cells(self, nb, path=''):
432 def mark_trusted_cells(self, nb, path=''):
433 """Mark cells as trusted if the notebook signature matches.
433 """Mark cells as trusted if the notebook signature matches.
434
434
435 Called as a part of loading notebooks.
435 Called as a part of loading notebooks.
436
436
437 Parameters
437 Parameters
438 ----------
438 ----------
439 nb : dict
439 nb : dict
440 The notebook object (in current nbformat)
440 The notebook object (in current nbformat)
441 path : string
441 path : string
442 The notebook's path (for logging)
442 The notebook's path (for logging)
443 """
443 """
444 trusted = self.notary.check_signature(nb)
444 trusted = self.notary.check_signature(nb)
445 if not trusted:
445 if not trusted:
446 self.log.warn("Notebook %s is not trusted", path)
446 self.log.warn("Notebook %s is not trusted", path)
447 self.notary.mark_cells(nb, trusted)
447 self.notary.mark_cells(nb, trusted)
448
448
449 def should_list(self, name):
449 def should_list(self, name):
450 """Should this file/directory name be displayed in a listing?"""
450 """Should this file/directory name be displayed in a listing?"""
451 return not any(fnmatch(name, glob) for glob in self.hide_globs)
451 return not any(fnmatch(name, glob) for glob in self.hide_globs)
452
452
453 # Part 3: Checkpoints API
453 # Part 3: Checkpoints API
454 def create_checkpoint(self, path):
454 def create_checkpoint(self, path):
455 """Create a checkpoint."""
455 """Create a checkpoint."""
456 return self.checkpoints.create_checkpoint(self, path)
456 return self.checkpoints.create_checkpoint(self, path)
457
457
458 def restore_checkpoint(self, checkpoint_id, path):
458 def restore_checkpoint(self, checkpoint_id, path):
459 """
459 """
460 Restore a checkpoint.
460 Restore a checkpoint.
461 """
461 """
462 self.checkpoints.restore_checkpoint(self, checkpoint_id, path)
462 self.checkpoints.restore_checkpoint(self, checkpoint_id, path)
463
463
464 def list_checkpoints(self, path):
464 def list_checkpoints(self, path):
465 return self.checkpoints.list_checkpoints(path)
465 return self.checkpoints.list_checkpoints(path)
466
466
467 def delete_checkpoint(self, checkpoint_id, path):
467 def delete_checkpoint(self, checkpoint_id, path):
468 return self.checkpoints.delete_checkpoint(checkpoint_id, path)
468 return self.checkpoints.delete_checkpoint(checkpoint_id, path)
@@ -1,489 +1,489 b''
1 """Base Widget class. Allows user to create widgets in the back-end that render
1 """Base Widget class. Allows user to create widgets in the back-end that render
2 in the IPython notebook front-end.
2 in the IPython notebook front-end.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (c) 2013, the IPython Development Team.
5 # Copyright (c) 2013, the IPython Development Team.
6 #
6 #
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8 #
8 #
9 # The full license is in the file COPYING.txt, distributed with this software.
9 # The full license is in the file COPYING.txt, distributed with this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 from contextlib import contextmanager
15 from contextlib import contextmanager
16 import collections
16 import collections
17
17
18 from IPython.core.getipython import get_ipython
18 from IPython.core.getipython import get_ipython
19 from IPython.kernel.comm import Comm
19 from IPython.kernel.comm import Comm
20 from IPython.config import LoggingConfigurable
20 from IPython.config import LoggingConfigurable
21 from IPython.utils.importstring import import_item
21 from IPython.utils.importstring import import_item
22 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, \
22 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, \
23 CaselessStrEnum, Tuple, CUnicode, Int, Set
23 CaselessStrEnum, Tuple, CUnicode, Int, Set
24 from IPython.utils.py3compat import string_types
24 from IPython.utils.py3compat import string_types
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Classes
27 # Classes
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29 class CallbackDispatcher(LoggingConfigurable):
29 class CallbackDispatcher(LoggingConfigurable):
30 """A structure for registering and running callbacks"""
30 """A structure for registering and running callbacks"""
31 callbacks = List()
31 callbacks = List()
32
32
33 def __call__(self, *args, **kwargs):
33 def __call__(self, *args, **kwargs):
34 """Call all of the registered callbacks."""
34 """Call all of the registered callbacks."""
35 value = None
35 value = None
36 for callback in self.callbacks:
36 for callback in self.callbacks:
37 try:
37 try:
38 local_value = callback(*args, **kwargs)
38 local_value = callback(*args, **kwargs)
39 except Exception as e:
39 except Exception as e:
40 ip = get_ipython()
40 ip = get_ipython()
41 if ip is None:
41 if ip is None:
42 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
42 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
43 else:
43 else:
44 ip.showtraceback()
44 ip.showtraceback()
45 else:
45 else:
46 value = local_value if local_value is not None else value
46 value = local_value if local_value is not None else value
47 return value
47 return value
48
48
49 def register_callback(self, callback, remove=False):
49 def register_callback(self, callback, remove=False):
50 """(Un)Register a callback
50 """(Un)Register a callback
51
51
52 Parameters
52 Parameters
53 ----------
53 ----------
54 callback: method handle
54 callback: method handle
55 Method to be registered or unregistered.
55 Method to be registered or unregistered.
56 remove=False: bool
56 remove=False: bool
57 Whether to unregister the callback."""
57 Whether to unregister the callback."""
58
58
59 # (Un)Register the callback.
59 # (Un)Register the callback.
60 if remove and callback in self.callbacks:
60 if remove and callback in self.callbacks:
61 self.callbacks.remove(callback)
61 self.callbacks.remove(callback)
62 elif not remove and callback not in self.callbacks:
62 elif not remove and callback not in self.callbacks:
63 self.callbacks.append(callback)
63 self.callbacks.append(callback)
64
64
65 def _show_traceback(method):
65 def _show_traceback(method):
66 """decorator for showing tracebacks in IPython"""
66 """decorator for showing tracebacks in IPython"""
67 def m(self, *args, **kwargs):
67 def m(self, *args, **kwargs):
68 try:
68 try:
69 return(method(self, *args, **kwargs))
69 return(method(self, *args, **kwargs))
70 except Exception as e:
70 except Exception as e:
71 ip = get_ipython()
71 ip = get_ipython()
72 if ip is None:
72 if ip is None:
73 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
73 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
74 else:
74 else:
75 ip.showtraceback()
75 ip.showtraceback()
76 return m
76 return m
77
77
78
78
79 def register(key=None):
79 def register(key=None):
80 """Returns a decorator registering a widget class in the widget registry.
80 """Returns a decorator registering a widget class in the widget registry.
81 If no key is provided, the class name is used as a key. A key is
81 If no key is provided, the class name is used as a key. A key is
82 provided for each core IPython widget so that the frontend can use
82 provided for each core IPython widget so that the frontend can use
83 this key regardless of the language of the kernel"""
83 this key regardless of the language of the kernel"""
84 def wrap(widget):
84 def wrap(widget):
85 l = key if key is not None else widget.__module__ + widget.__name__
85 l = key if key is not None else widget.__module__ + widget.__name__
86 Widget.widget_types[l] = widget
86 Widget.widget_types[l] = widget
87 return widget
87 return widget
88 return wrap
88 return wrap
89
89
90
90
91 class Widget(LoggingConfigurable):
91 class Widget(LoggingConfigurable):
92 #-------------------------------------------------------------------------
92 #-------------------------------------------------------------------------
93 # Class attributes
93 # Class attributes
94 #-------------------------------------------------------------------------
94 #-------------------------------------------------------------------------
95 _widget_construction_callback = None
95 _widget_construction_callback = None
96 widgets = {}
96 widgets = {}
97 widget_types = {}
97 widget_types = {}
98
98
99 @staticmethod
99 @staticmethod
100 def on_widget_constructed(callback):
100 def on_widget_constructed(callback):
101 """Registers a callback to be called when a widget is constructed.
101 """Registers a callback to be called when a widget is constructed.
102
102
103 The callback must have the following signature:
103 The callback must have the following signature:
104 callback(widget)"""
104 callback(widget)"""
105 Widget._widget_construction_callback = callback
105 Widget._widget_construction_callback = callback
106
106
107 @staticmethod
107 @staticmethod
108 def _call_widget_constructed(widget):
108 def _call_widget_constructed(widget):
109 """Static method, called when a widget is constructed."""
109 """Static method, called when a widget is constructed."""
110 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
110 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
111 Widget._widget_construction_callback(widget)
111 Widget._widget_construction_callback(widget)
112
112
113 @staticmethod
113 @staticmethod
114 def handle_comm_opened(comm, msg):
114 def handle_comm_opened(comm, msg):
115 """Static method, called when a widget is constructed."""
115 """Static method, called when a widget is constructed."""
116 widget_class = import_item(msg['content']['data']['widget_class'])
116 widget_class = import_item(msg['content']['data']['widget_class'])
117 widget = widget_class(comm=comm)
117 widget = widget_class(comm=comm)
118
118
119
119
120 #-------------------------------------------------------------------------
120 #-------------------------------------------------------------------------
121 # Traits
121 # Traits
122 #-------------------------------------------------------------------------
122 #-------------------------------------------------------------------------
123 _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
123 _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
124 in which to find _model_name. If empty, look in the global registry.""")
124 in which to find _model_name. If empty, look in the global registry.""")
125 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
125 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
126 registered in the front-end to create and sync this widget with.""")
126 registered in the front-end to create and sync this widget with.""")
127 _view_module = Unicode(help="""A requirejs module in which to find _view_name.
127 _view_module = Unicode(help="""A requirejs module in which to find _view_name.
128 If empty, look in the global registry.""", sync=True)
128 If empty, look in the global registry.""", sync=True)
129 _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
129 _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
130 to use to represent the widget.""", sync=True)
130 to use to represent the widget.""", sync=True)
131 comm = Instance('IPython.kernel.comm.Comm')
131 comm = Instance('IPython.kernel.comm.Comm')
132
132
133 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
133 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
134 front-end can send before receiving an idle msg from the back-end.""")
134 front-end can send before receiving an idle msg from the back-end.""")
135
135
136 version = Int(0, sync=True, help="""Widget's version""")
136 version = Int(0, sync=True, help="""Widget's version""")
137 keys = List()
137 keys = List()
138 def _keys_default(self):
138 def _keys_default(self):
139 return [name for name in self.traits(sync=True)]
139 return [name for name in self.traits(sync=True)]
140
140
141 _property_lock = Tuple((None, None))
141 _property_lock = Tuple((None, None))
142 _send_state_lock = Int(0)
142 _send_state_lock = Int(0)
143 _states_to_send = Set(allow_none=False)
143 _states_to_send = Set()
144 _display_callbacks = Instance(CallbackDispatcher, ())
144 _display_callbacks = Instance(CallbackDispatcher, ())
145 _msg_callbacks = Instance(CallbackDispatcher, ())
145 _msg_callbacks = Instance(CallbackDispatcher, ())
146
146
147 #-------------------------------------------------------------------------
147 #-------------------------------------------------------------------------
148 # (Con/de)structor
148 # (Con/de)structor
149 #-------------------------------------------------------------------------
149 #-------------------------------------------------------------------------
150 def __init__(self, **kwargs):
150 def __init__(self, **kwargs):
151 """Public constructor"""
151 """Public constructor"""
152 self._model_id = kwargs.pop('model_id', None)
152 self._model_id = kwargs.pop('model_id', None)
153 super(Widget, self).__init__(**kwargs)
153 super(Widget, self).__init__(**kwargs)
154
154
155 Widget._call_widget_constructed(self)
155 Widget._call_widget_constructed(self)
156 self.open()
156 self.open()
157
157
158 def __del__(self):
158 def __del__(self):
159 """Object disposal"""
159 """Object disposal"""
160 self.close()
160 self.close()
161
161
162 #-------------------------------------------------------------------------
162 #-------------------------------------------------------------------------
163 # Properties
163 # Properties
164 #-------------------------------------------------------------------------
164 #-------------------------------------------------------------------------
165
165
166 def open(self):
166 def open(self):
167 """Open a comm to the frontend if one isn't already open."""
167 """Open a comm to the frontend if one isn't already open."""
168 if self.comm is None:
168 if self.comm is None:
169 args = dict(target_name='ipython.widget',
169 args = dict(target_name='ipython.widget',
170 data={'model_name': self._model_name,
170 data={'model_name': self._model_name,
171 'model_module': self._model_module})
171 'model_module': self._model_module})
172 if self._model_id is not None:
172 if self._model_id is not None:
173 args['comm_id'] = self._model_id
173 args['comm_id'] = self._model_id
174 self.comm = Comm(**args)
174 self.comm = Comm(**args)
175
175
176 def _comm_changed(self, name, new):
176 def _comm_changed(self, name, new):
177 """Called when the comm is changed."""
177 """Called when the comm is changed."""
178 if new is None:
178 if new is None:
179 return
179 return
180 self._model_id = self.model_id
180 self._model_id = self.model_id
181
181
182 self.comm.on_msg(self._handle_msg)
182 self.comm.on_msg(self._handle_msg)
183 Widget.widgets[self.model_id] = self
183 Widget.widgets[self.model_id] = self
184
184
185 # first update
185 # first update
186 self.send_state()
186 self.send_state()
187
187
188 @property
188 @property
189 def model_id(self):
189 def model_id(self):
190 """Gets the model id of this widget.
190 """Gets the model id of this widget.
191
191
192 If a Comm doesn't exist yet, a Comm will be created automagically."""
192 If a Comm doesn't exist yet, a Comm will be created automagically."""
193 return self.comm.comm_id
193 return self.comm.comm_id
194
194
195 #-------------------------------------------------------------------------
195 #-------------------------------------------------------------------------
196 # Methods
196 # Methods
197 #-------------------------------------------------------------------------
197 #-------------------------------------------------------------------------
198
198
199 def close(self):
199 def close(self):
200 """Close method.
200 """Close method.
201
201
202 Closes the underlying comm.
202 Closes the underlying comm.
203 When the comm is closed, all of the widget views are automatically
203 When the comm is closed, all of the widget views are automatically
204 removed from the front-end."""
204 removed from the front-end."""
205 if self.comm is not None:
205 if self.comm is not None:
206 Widget.widgets.pop(self.model_id, None)
206 Widget.widgets.pop(self.model_id, None)
207 self.comm.close()
207 self.comm.close()
208 self.comm = None
208 self.comm = None
209
209
210 def send_state(self, key=None):
210 def send_state(self, key=None):
211 """Sends the widget state, or a piece of it, to the front-end.
211 """Sends the widget state, or a piece of it, to the front-end.
212
212
213 Parameters
213 Parameters
214 ----------
214 ----------
215 key : unicode, or iterable (optional)
215 key : unicode, or iterable (optional)
216 A single property's name or iterable of property names to sync with the front-end.
216 A single property's name or iterable of property names to sync with the front-end.
217 """
217 """
218 self._send({
218 self._send({
219 "method" : "update",
219 "method" : "update",
220 "state" : self.get_state(key=key)
220 "state" : self.get_state(key=key)
221 })
221 })
222
222
223 def get_state(self, key=None):
223 def get_state(self, key=None):
224 """Gets the widget state, or a piece of it.
224 """Gets the widget state, or a piece of it.
225
225
226 Parameters
226 Parameters
227 ----------
227 ----------
228 key : unicode or iterable (optional)
228 key : unicode or iterable (optional)
229 A single property's name or iterable of property names to get.
229 A single property's name or iterable of property names to get.
230 """
230 """
231 if key is None:
231 if key is None:
232 keys = self.keys
232 keys = self.keys
233 elif isinstance(key, string_types):
233 elif isinstance(key, string_types):
234 keys = [key]
234 keys = [key]
235 elif isinstance(key, collections.Iterable):
235 elif isinstance(key, collections.Iterable):
236 keys = key
236 keys = key
237 else:
237 else:
238 raise ValueError("key must be a string, an iterable of keys, or None")
238 raise ValueError("key must be a string, an iterable of keys, or None")
239 state = {}
239 state = {}
240 for k in keys:
240 for k in keys:
241 f = self.trait_metadata(k, 'to_json', self._trait_to_json)
241 f = self.trait_metadata(k, 'to_json', self._trait_to_json)
242 value = getattr(self, k)
242 value = getattr(self, k)
243 state[k] = f(value)
243 state[k] = f(value)
244 return state
244 return state
245
245
246 def set_state(self, sync_data):
246 def set_state(self, sync_data):
247 """Called when a state is received from the front-end."""
247 """Called when a state is received from the front-end."""
248 for name in self.keys:
248 for name in self.keys:
249 if name in sync_data:
249 if name in sync_data:
250 json_value = sync_data[name]
250 json_value = sync_data[name]
251 from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
251 from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
252 with self._lock_property(name, json_value):
252 with self._lock_property(name, json_value):
253 setattr(self, name, from_json(json_value))
253 setattr(self, name, from_json(json_value))
254
254
255 def send(self, content):
255 def send(self, content):
256 """Sends a custom msg to the widget model in the front-end.
256 """Sends a custom msg to the widget model in the front-end.
257
257
258 Parameters
258 Parameters
259 ----------
259 ----------
260 content : dict
260 content : dict
261 Content of the message to send.
261 Content of the message to send.
262 """
262 """
263 self._send({"method": "custom", "content": content})
263 self._send({"method": "custom", "content": content})
264
264
265 def on_msg(self, callback, remove=False):
265 def on_msg(self, callback, remove=False):
266 """(Un)Register a custom msg receive callback.
266 """(Un)Register a custom msg receive callback.
267
267
268 Parameters
268 Parameters
269 ----------
269 ----------
270 callback: callable
270 callback: callable
271 callback will be passed two arguments when a message arrives::
271 callback will be passed two arguments when a message arrives::
272
272
273 callback(widget, content)
273 callback(widget, content)
274
274
275 remove: bool
275 remove: bool
276 True if the callback should be unregistered."""
276 True if the callback should be unregistered."""
277 self._msg_callbacks.register_callback(callback, remove=remove)
277 self._msg_callbacks.register_callback(callback, remove=remove)
278
278
279 def on_displayed(self, callback, remove=False):
279 def on_displayed(self, callback, remove=False):
280 """(Un)Register a widget displayed callback.
280 """(Un)Register a widget displayed callback.
281
281
282 Parameters
282 Parameters
283 ----------
283 ----------
284 callback: method handler
284 callback: method handler
285 Must have a signature of::
285 Must have a signature of::
286
286
287 callback(widget, **kwargs)
287 callback(widget, **kwargs)
288
288
289 kwargs from display are passed through without modification.
289 kwargs from display are passed through without modification.
290 remove: bool
290 remove: bool
291 True if the callback should be unregistered."""
291 True if the callback should be unregistered."""
292 self._display_callbacks.register_callback(callback, remove=remove)
292 self._display_callbacks.register_callback(callback, remove=remove)
293
293
294 #-------------------------------------------------------------------------
294 #-------------------------------------------------------------------------
295 # Support methods
295 # Support methods
296 #-------------------------------------------------------------------------
296 #-------------------------------------------------------------------------
297 @contextmanager
297 @contextmanager
298 def _lock_property(self, key, value):
298 def _lock_property(self, key, value):
299 """Lock a property-value pair.
299 """Lock a property-value pair.
300
300
301 The value should be the JSON state of the property.
301 The value should be the JSON state of the property.
302
302
303 NOTE: This, in addition to the single lock for all state changes, is
303 NOTE: This, in addition to the single lock for all state changes, is
304 flawed. In the future we may want to look into buffering state changes
304 flawed. In the future we may want to look into buffering state changes
305 back to the front-end."""
305 back to the front-end."""
306 self._property_lock = (key, value)
306 self._property_lock = (key, value)
307 try:
307 try:
308 yield
308 yield
309 finally:
309 finally:
310 self._property_lock = (None, None)
310 self._property_lock = (None, None)
311
311
312 @contextmanager
312 @contextmanager
313 def hold_sync(self):
313 def hold_sync(self):
314 """Hold syncing any state until the context manager is released"""
314 """Hold syncing any state until the context manager is released"""
315 # We increment a value so that this can be nested. Syncing will happen when
315 # We increment a value so that this can be nested. Syncing will happen when
316 # all levels have been released.
316 # all levels have been released.
317 self._send_state_lock += 1
317 self._send_state_lock += 1
318 try:
318 try:
319 yield
319 yield
320 finally:
320 finally:
321 self._send_state_lock -=1
321 self._send_state_lock -=1
322 if self._send_state_lock == 0:
322 if self._send_state_lock == 0:
323 self.send_state(self._states_to_send)
323 self.send_state(self._states_to_send)
324 self._states_to_send.clear()
324 self._states_to_send.clear()
325
325
326 def _should_send_property(self, key, value):
326 def _should_send_property(self, key, value):
327 """Check the property lock (property_lock)"""
327 """Check the property lock (property_lock)"""
328 to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
328 to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
329 if (key == self._property_lock[0]
329 if (key == self._property_lock[0]
330 and to_json(value) == self._property_lock[1]):
330 and to_json(value) == self._property_lock[1]):
331 return False
331 return False
332 elif self._send_state_lock > 0:
332 elif self._send_state_lock > 0:
333 self._states_to_send.add(key)
333 self._states_to_send.add(key)
334 return False
334 return False
335 else:
335 else:
336 return True
336 return True
337
337
338 # Event handlers
338 # Event handlers
339 @_show_traceback
339 @_show_traceback
340 def _handle_msg(self, msg):
340 def _handle_msg(self, msg):
341 """Called when a msg is received from the front-end"""
341 """Called when a msg is received from the front-end"""
342 data = msg['content']['data']
342 data = msg['content']['data']
343 method = data['method']
343 method = data['method']
344
344
345 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
345 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
346 if method == 'backbone':
346 if method == 'backbone':
347 if 'sync_data' in data:
347 if 'sync_data' in data:
348 sync_data = data['sync_data']
348 sync_data = data['sync_data']
349 self.set_state(sync_data) # handles all methods
349 self.set_state(sync_data) # handles all methods
350
350
351 # Handle a state request.
351 # Handle a state request.
352 elif method == 'request_state':
352 elif method == 'request_state':
353 self.send_state()
353 self.send_state()
354
354
355 # Handle a custom msg from the front-end.
355 # Handle a custom msg from the front-end.
356 elif method == 'custom':
356 elif method == 'custom':
357 if 'content' in data:
357 if 'content' in data:
358 self._handle_custom_msg(data['content'])
358 self._handle_custom_msg(data['content'])
359
359
360 # Catch remainder.
360 # Catch remainder.
361 else:
361 else:
362 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
362 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
363
363
364 def _handle_custom_msg(self, content):
364 def _handle_custom_msg(self, content):
365 """Called when a custom msg is received."""
365 """Called when a custom msg is received."""
366 self._msg_callbacks(self, content)
366 self._msg_callbacks(self, content)
367
367
368 def _notify_trait(self, name, old_value, new_value):
368 def _notify_trait(self, name, old_value, new_value):
369 """Called when a property has been changed."""
369 """Called when a property has been changed."""
370 # Trigger default traitlet callback machinery. This allows any user
370 # Trigger default traitlet callback machinery. This allows any user
371 # registered validation to be processed prior to allowing the widget
371 # registered validation to be processed prior to allowing the widget
372 # machinery to handle the state.
372 # machinery to handle the state.
373 LoggingConfigurable._notify_trait(self, name, old_value, new_value)
373 LoggingConfigurable._notify_trait(self, name, old_value, new_value)
374
374
375 # Send the state after the user registered callbacks for trait changes
375 # Send the state after the user registered callbacks for trait changes
376 # have all fired (allows for user to validate values).
376 # have all fired (allows for user to validate values).
377 if self.comm is not None and name in self.keys:
377 if self.comm is not None and name in self.keys:
378 # Make sure this isn't information that the front-end just sent us.
378 # Make sure this isn't information that the front-end just sent us.
379 if self._should_send_property(name, new_value):
379 if self._should_send_property(name, new_value):
380 # Send new state to front-end
380 # Send new state to front-end
381 self.send_state(key=name)
381 self.send_state(key=name)
382
382
383 def _handle_displayed(self, **kwargs):
383 def _handle_displayed(self, **kwargs):
384 """Called when a view has been displayed for this widget instance"""
384 """Called when a view has been displayed for this widget instance"""
385 self._display_callbacks(self, **kwargs)
385 self._display_callbacks(self, **kwargs)
386
386
387 def _trait_to_json(self, x):
387 def _trait_to_json(self, x):
388 """Convert a trait value to json
388 """Convert a trait value to json
389
389
390 Traverse lists/tuples and dicts and serialize their values as well.
390 Traverse lists/tuples and dicts and serialize their values as well.
391 Replace any widgets with their model_id
391 Replace any widgets with their model_id
392 """
392 """
393 if isinstance(x, dict):
393 if isinstance(x, dict):
394 return {k: self._trait_to_json(v) for k, v in x.items()}
394 return {k: self._trait_to_json(v) for k, v in x.items()}
395 elif isinstance(x, (list, tuple)):
395 elif isinstance(x, (list, tuple)):
396 return [self._trait_to_json(v) for v in x]
396 return [self._trait_to_json(v) for v in x]
397 elif isinstance(x, Widget):
397 elif isinstance(x, Widget):
398 return "IPY_MODEL_" + x.model_id
398 return "IPY_MODEL_" + x.model_id
399 else:
399 else:
400 return x # Value must be JSON-able
400 return x # Value must be JSON-able
401
401
402 def _trait_from_json(self, x):
402 def _trait_from_json(self, x):
403 """Convert json values to objects
403 """Convert json values to objects
404
404
405 Replace any strings representing valid model id values to Widget references.
405 Replace any strings representing valid model id values to Widget references.
406 """
406 """
407 if isinstance(x, dict):
407 if isinstance(x, dict):
408 return {k: self._trait_from_json(v) for k, v in x.items()}
408 return {k: self._trait_from_json(v) for k, v in x.items()}
409 elif isinstance(x, (list, tuple)):
409 elif isinstance(x, (list, tuple)):
410 return [self._trait_from_json(v) for v in x]
410 return [self._trait_from_json(v) for v in x]
411 elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
411 elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
412 # we want to support having child widgets at any level in a hierarchy
412 # we want to support having child widgets at any level in a hierarchy
413 # trusting that a widget UUID will not appear out in the wild
413 # trusting that a widget UUID will not appear out in the wild
414 return Widget.widgets[x[10:]]
414 return Widget.widgets[x[10:]]
415 else:
415 else:
416 return x
416 return x
417
417
418 def _ipython_display_(self, **kwargs):
418 def _ipython_display_(self, **kwargs):
419 """Called when `IPython.display.display` is called on the widget."""
419 """Called when `IPython.display.display` is called on the widget."""
420 # Show view.
420 # Show view.
421 if self._view_name is not None:
421 if self._view_name is not None:
422 self._send({"method": "display"})
422 self._send({"method": "display"})
423 self._handle_displayed(**kwargs)
423 self._handle_displayed(**kwargs)
424
424
425 def _send(self, msg):
425 def _send(self, msg):
426 """Sends a message to the model in the front-end."""
426 """Sends a message to the model in the front-end."""
427 self.comm.send(msg)
427 self.comm.send(msg)
428
428
429
429
430 class DOMWidget(Widget):
430 class DOMWidget(Widget):
431 visible = Bool(True, allow_none=True, help="Whether the widget is visible. False collapses the empty space, while None preserves the empty space.", sync=True)
431 visible = Bool(True, allow_none=True, help="Whether the widget is visible. False collapses the empty space, while None preserves the empty space.", sync=True)
432 _css = Tuple(sync=True, help="CSS property list: (selector, key, value)")
432 _css = Tuple(sync=True, help="CSS property list: (selector, key, value)")
433 _dom_classes = Tuple(sync=True, help="DOM classes applied to widget.$el.")
433 _dom_classes = Tuple(sync=True, help="DOM classes applied to widget.$el.")
434
434
435 width = CUnicode(sync=True)
435 width = CUnicode(sync=True)
436 height = CUnicode(sync=True)
436 height = CUnicode(sync=True)
437 # A default padding of 2.5 px makes the widgets look nice when displayed inline.
437 # A default padding of 2.5 px makes the widgets look nice when displayed inline.
438 padding = CUnicode(sync=True)
438 padding = CUnicode(sync=True)
439 margin = CUnicode(sync=True)
439 margin = CUnicode(sync=True)
440
440
441 color = Unicode(sync=True)
441 color = Unicode(sync=True)
442 background_color = Unicode(sync=True)
442 background_color = Unicode(sync=True)
443 border_color = Unicode(sync=True)
443 border_color = Unicode(sync=True)
444
444
445 border_width = CUnicode(sync=True)
445 border_width = CUnicode(sync=True)
446 border_radius = CUnicode(sync=True)
446 border_radius = CUnicode(sync=True)
447 border_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_border-style.asp
447 border_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_border-style.asp
448 'none',
448 'none',
449 'hidden',
449 'hidden',
450 'dotted',
450 'dotted',
451 'dashed',
451 'dashed',
452 'solid',
452 'solid',
453 'double',
453 'double',
454 'groove',
454 'groove',
455 'ridge',
455 'ridge',
456 'inset',
456 'inset',
457 'outset',
457 'outset',
458 'initial',
458 'initial',
459 'inherit', ''],
459 'inherit', ''],
460 default_value='', sync=True)
460 default_value='', sync=True)
461
461
462 font_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_font-style.asp
462 font_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_font-style.asp
463 'normal',
463 'normal',
464 'italic',
464 'italic',
465 'oblique',
465 'oblique',
466 'initial',
466 'initial',
467 'inherit', ''],
467 'inherit', ''],
468 default_value='', sync=True)
468 default_value='', sync=True)
469 font_weight = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_weight.asp
469 font_weight = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_weight.asp
470 'normal',
470 'normal',
471 'bold',
471 'bold',
472 'bolder',
472 'bolder',
473 'lighter',
473 'lighter',
474 'initial',
474 'initial',
475 'inherit', ''] + list(map(str, range(100,1000,100))),
475 'inherit', ''] + list(map(str, range(100,1000,100))),
476 default_value='', sync=True)
476 default_value='', sync=True)
477 font_size = CUnicode(sync=True)
477 font_size = CUnicode(sync=True)
478 font_family = Unicode(sync=True)
478 font_family = Unicode(sync=True)
479
479
480 def __init__(self, *pargs, **kwargs):
480 def __init__(self, *pargs, **kwargs):
481 super(DOMWidget, self).__init__(*pargs, **kwargs)
481 super(DOMWidget, self).__init__(*pargs, **kwargs)
482
482
483 def _validate_border(name, old, new):
483 def _validate_border(name, old, new):
484 if new is not None and new != '':
484 if new is not None and new != '':
485 if name != 'border_width' and not self.border_width:
485 if name != 'border_width' and not self.border_width:
486 self.border_width = 1
486 self.border_width = 1
487 if name != 'border_style' and self.border_style == '':
487 if name != 'border_style' and self.border_style == '':
488 self.border_style = 'solid'
488 self.border_style = 'solid'
489 self.on_trait_change(_validate_border, ['border_width', 'border_style', 'border_color'])
489 self.on_trait_change(_validate_border, ['border_width', 'border_style', 'border_color'])
@@ -1,80 +1,80 b''
1 """Box class.
1 """Box class.
2
2
3 Represents a container that can be used to group other widgets.
3 Represents a container that can be used to group other widgets.
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9 from .widget import DOMWidget, register
9 from .widget import DOMWidget, register
10 from IPython.utils.traitlets import Unicode, Tuple, TraitError, Int, CaselessStrEnum
10 from IPython.utils.traitlets import Unicode, Tuple, TraitError, Int, CaselessStrEnum
11 from IPython.utils.warn import DeprecatedClass
11 from IPython.utils.warn import DeprecatedClass
12
12
13 @register('IPython.Box')
13 @register('IPython.Box')
14 class Box(DOMWidget):
14 class Box(DOMWidget):
15 """Displays multiple widgets in a group."""
15 """Displays multiple widgets in a group."""
16 _view_name = Unicode('BoxView', sync=True)
16 _view_name = Unicode('BoxView', sync=True)
17
17
18 # Child widgets in the container.
18 # Child widgets in the container.
19 # Using a tuple here to force reassignment to update the list.
19 # Using a tuple here to force reassignment to update the list.
20 # When a proper notifying-list trait exists, that is what should be used here.
20 # When a proper notifying-list trait exists, that is what should be used here.
21 children = Tuple(sync=True, allow_none=False)
21 children = Tuple(sync=True)
22
22
23 _overflow_values = ['visible', 'hidden', 'scroll', 'auto', 'initial', 'inherit', '']
23 _overflow_values = ['visible', 'hidden', 'scroll', 'auto', 'initial', 'inherit', '']
24 overflow_x = CaselessStrEnum(
24 overflow_x = CaselessStrEnum(
25 values=_overflow_values,
25 values=_overflow_values,
26 default_value='', allow_none=False, sync=True, help="""Specifies what
26 default_value='', sync=True, help="""Specifies what
27 happens to content that is too large for the rendered region.""")
27 happens to content that is too large for the rendered region.""")
28 overflow_y = CaselessStrEnum(
28 overflow_y = CaselessStrEnum(
29 values=_overflow_values,
29 values=_overflow_values,
30 default_value='', allow_none=False, sync=True, help="""Specifies what
30 default_value='', sync=True, help="""Specifies what
31 happens to content that is too large for the rendered region.""")
31 happens to content that is too large for the rendered region.""")
32
32
33 box_style = CaselessStrEnum(
33 box_style = CaselessStrEnum(
34 values=['success', 'info', 'warning', 'danger', ''],
34 values=['success', 'info', 'warning', 'danger', ''],
35 default_value='', allow_none=True, sync=True, help="""Use a
35 default_value='', allow_none=True, sync=True, help="""Use a
36 predefined styling for the box.""")
36 predefined styling for the box.""")
37
37
38 def __init__(self, children = (), **kwargs):
38 def __init__(self, children = (), **kwargs):
39 kwargs['children'] = children
39 kwargs['children'] = children
40 super(Box, self).__init__(**kwargs)
40 super(Box, self).__init__(**kwargs)
41 self.on_displayed(Box._fire_children_displayed)
41 self.on_displayed(Box._fire_children_displayed)
42
42
43 def _fire_children_displayed(self):
43 def _fire_children_displayed(self):
44 for child in self.children:
44 for child in self.children:
45 child._handle_displayed()
45 child._handle_displayed()
46
46
47
47
48 @register('IPython.FlexBox')
48 @register('IPython.FlexBox')
49 class FlexBox(Box):
49 class FlexBox(Box):
50 """Displays multiple widgets using the flexible box model."""
50 """Displays multiple widgets using the flexible box model."""
51 _view_name = Unicode('FlexBoxView', sync=True)
51 _view_name = Unicode('FlexBoxView', sync=True)
52 orientation = CaselessStrEnum(values=['vertical', 'horizontal'], default_value='vertical', sync=True)
52 orientation = CaselessStrEnum(values=['vertical', 'horizontal'], default_value='vertical', sync=True)
53 flex = Int(0, sync=True, help="""Specify the flexible-ness of the model.""")
53 flex = Int(0, sync=True, help="""Specify the flexible-ness of the model.""")
54 def _flex_changed(self, name, old, new):
54 def _flex_changed(self, name, old, new):
55 new = min(max(0, new), 2)
55 new = min(max(0, new), 2)
56 if self.flex != new:
56 if self.flex != new:
57 self.flex = new
57 self.flex = new
58
58
59 _locations = ['start', 'center', 'end', 'baseline', 'stretch']
59 _locations = ['start', 'center', 'end', 'baseline', 'stretch']
60 pack = CaselessStrEnum(
60 pack = CaselessStrEnum(
61 values=_locations,
61 values=_locations,
62 default_value='start', allow_none=False, sync=True)
62 default_value='start', sync=True)
63 align = CaselessStrEnum(
63 align = CaselessStrEnum(
64 values=_locations,
64 values=_locations,
65 default_value='start', allow_none=False, sync=True)
65 default_value='start', sync=True)
66
66
67
67
68 def VBox(*pargs, **kwargs):
68 def VBox(*pargs, **kwargs):
69 """Displays multiple widgets vertically using the flexible box model."""
69 """Displays multiple widgets vertically using the flexible box model."""
70 kwargs['orientation'] = 'vertical'
70 kwargs['orientation'] = 'vertical'
71 return FlexBox(*pargs, **kwargs)
71 return FlexBox(*pargs, **kwargs)
72
72
73 def HBox(*pargs, **kwargs):
73 def HBox(*pargs, **kwargs):
74 """Displays multiple widgets horizontally using the flexible box model."""
74 """Displays multiple widgets horizontally using the flexible box model."""
75 kwargs['orientation'] = 'horizontal'
75 kwargs['orientation'] = 'horizontal'
76 return FlexBox(*pargs, **kwargs)
76 return FlexBox(*pargs, **kwargs)
77
77
78
78
79 # Remove in IPython 4.0
79 # Remove in IPython 4.0
80 ContainerWidget = DeprecatedClass(Box, 'ContainerWidget')
80 ContainerWidget = DeprecatedClass(Box, 'ContainerWidget')
@@ -1,298 +1,296 b''
1 """Float class.
1 """Float class.
2
2
3 Represents an unbounded float using a widget.
3 Represents an unbounded float using a widget.
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (c) 2013, the IPython Development Team.
6 # Copyright (c) 2013, the IPython Development Team.
7 #
7 #
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9 #
9 #
10 # The full license is in the file COPYING.txt, distributed with this software.
10 # The full license is in the file COPYING.txt, distributed with this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from .widget import DOMWidget, register
16 from .widget import DOMWidget, register
17 from IPython.utils.traitlets import Unicode, CFloat, Bool, CaselessStrEnum, Tuple
17 from IPython.utils.traitlets import Unicode, CFloat, Bool, CaselessStrEnum, Tuple
18 from IPython.utils.warn import DeprecatedClass
18 from IPython.utils.warn import DeprecatedClass
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Classes
21 # Classes
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 class _Float(DOMWidget):
23 class _Float(DOMWidget):
24 value = CFloat(0.0, help="Float value", sync=True)
24 value = CFloat(0.0, help="Float value", sync=True)
25 disabled = Bool(False, help="Enable or disable user changes", sync=True)
25 disabled = Bool(False, help="Enable or disable user changes", sync=True)
26 description = Unicode(help="Description of the value this widget represents", sync=True)
26 description = Unicode(help="Description of the value this widget represents", sync=True)
27
27
28 def __init__(self, value=None, **kwargs):
28 def __init__(self, value=None, **kwargs):
29 if value is not None:
29 if value is not None:
30 kwargs['value'] = value
30 kwargs['value'] = value
31 super(_Float, self).__init__(**kwargs)
31 super(_Float, self).__init__(**kwargs)
32
32
33 class _BoundedFloat(_Float):
33 class _BoundedFloat(_Float):
34 max = CFloat(100.0, help="Max value", sync=True)
34 max = CFloat(100.0, help="Max value", sync=True)
35 min = CFloat(0.0, help="Min value", sync=True)
35 min = CFloat(0.0, help="Min value", sync=True)
36 step = CFloat(0.1, help="Minimum step that the value can take (ignored by some views)", sync=True)
36 step = CFloat(0.1, help="Minimum step that the value can take (ignored by some views)", sync=True)
37
37
38 def __init__(self, *pargs, **kwargs):
38 def __init__(self, *pargs, **kwargs):
39 """Constructor"""
39 """Constructor"""
40 super(_BoundedFloat, self).__init__(*pargs, **kwargs)
40 super(_BoundedFloat, self).__init__(*pargs, **kwargs)
41 self._handle_value_changed('value', None, self.value)
41 self._handle_value_changed('value', None, self.value)
42 self._handle_max_changed('max', None, self.max)
42 self._handle_max_changed('max', None, self.max)
43 self._handle_min_changed('min', None, self.min)
43 self._handle_min_changed('min', None, self.min)
44 self.on_trait_change(self._handle_value_changed, 'value')
44 self.on_trait_change(self._handle_value_changed, 'value')
45 self.on_trait_change(self._handle_max_changed, 'max')
45 self.on_trait_change(self._handle_max_changed, 'max')
46 self.on_trait_change(self._handle_min_changed, 'min')
46 self.on_trait_change(self._handle_min_changed, 'min')
47
47
48 def _handle_value_changed(self, name, old, new):
48 def _handle_value_changed(self, name, old, new):
49 """Validate value."""
49 """Validate value."""
50 if self.min > new or new > self.max:
50 if self.min > new or new > self.max:
51 self.value = min(max(new, self.min), self.max)
51 self.value = min(max(new, self.min), self.max)
52
52
53 def _handle_max_changed(self, name, old, new):
53 def _handle_max_changed(self, name, old, new):
54 """Make sure the min is always <= the max."""
54 """Make sure the min is always <= the max."""
55 if new < self.min:
55 if new < self.min:
56 raise ValueError("setting max < min")
56 raise ValueError("setting max < min")
57 if new < self.value:
57 if new < self.value:
58 self.value = new
58 self.value = new
59
59
60 def _handle_min_changed(self, name, old, new):
60 def _handle_min_changed(self, name, old, new):
61 """Make sure the max is always >= the min."""
61 """Make sure the max is always >= the min."""
62 if new > self.max:
62 if new > self.max:
63 raise ValueError("setting min > max")
63 raise ValueError("setting min > max")
64 if new > self.value:
64 if new > self.value:
65 self.value = new
65 self.value = new
66
66
67
67
68 @register('IPython.FloatText')
68 @register('IPython.FloatText')
69 class FloatText(_Float):
69 class FloatText(_Float):
70 """ Displays a float value within a textbox. For a textbox in
70 """ Displays a float value within a textbox. For a textbox in
71 which the value must be within a specific range, use BoundedFloatText.
71 which the value must be within a specific range, use BoundedFloatText.
72
72
73 Parameters
73 Parameters
74 ----------
74 ----------
75 value : float
75 value : float
76 value displayed
76 value displayed
77 description : str
77 description : str
78 description displayed next to the textbox
78 description displayed next to the textbox
79 color : str Unicode color code (eg. '#C13535'), optional
79 color : str Unicode color code (eg. '#C13535'), optional
80 color of the value displayed
80 color of the value displayed
81 """
81 """
82 _view_name = Unicode('FloatTextView', sync=True)
82 _view_name = Unicode('FloatTextView', sync=True)
83
83
84
84
85 @register('IPython.BoundedFloatText')
85 @register('IPython.BoundedFloatText')
86 class BoundedFloatText(_BoundedFloat):
86 class BoundedFloatText(_BoundedFloat):
87 """ Displays a float value within a textbox. Value must be within the range specified.
87 """ Displays a float value within a textbox. Value must be within the range specified.
88 For a textbox in which the value doesn't need to be within a specific range, use FloatText.
88 For a textbox in which the value doesn't need to be within a specific range, use FloatText.
89
89
90 Parameters
90 Parameters
91 ----------
91 ----------
92 value : float
92 value : float
93 value displayed
93 value displayed
94 min : float
94 min : float
95 minimal value of the range of possible values displayed
95 minimal value of the range of possible values displayed
96 max : float
96 max : float
97 maximal value of the range of possible values displayed
97 maximal value of the range of possible values displayed
98 description : str
98 description : str
99 description displayed next to the textbox
99 description displayed next to the textbox
100 color : str Unicode color code (eg. '#C13535'), optional
100 color : str Unicode color code (eg. '#C13535'), optional
101 color of the value displayed
101 color of the value displayed
102 """
102 """
103 _view_name = Unicode('FloatTextView', sync=True)
103 _view_name = Unicode('FloatTextView', sync=True)
104
104
105
105
106 @register('IPython.FloatSlider')
106 @register('IPython.FloatSlider')
107 class FloatSlider(_BoundedFloat):
107 class FloatSlider(_BoundedFloat):
108 """ Slider/trackbar of floating values with the specified range.
108 """ Slider/trackbar of floating values with the specified range.
109
109
110 Parameters
110 Parameters
111 ----------
111 ----------
112 value : float
112 value : float
113 position of the slider
113 position of the slider
114 min : float
114 min : float
115 minimal position of the slider
115 minimal position of the slider
116 max : float
116 max : float
117 maximal position of the slider
117 maximal position of the slider
118 step : float
118 step : float
119 step of the trackbar
119 step of the trackbar
120 description : str
120 description : str
121 name of the slider
121 name of the slider
122 orientation : {'vertical', 'horizontal}, optional
122 orientation : {'vertical', 'horizontal}, optional
123 default is horizontal
123 default is horizontal
124 readout : {True, False}, optional
124 readout : {True, False}, optional
125 default is True, display the current value of the slider next to it
125 default is True, display the current value of the slider next to it
126 slider_color : str Unicode color code (eg. '#C13535'), optional
126 slider_color : str Unicode color code (eg. '#C13535'), optional
127 color of the slider
127 color of the slider
128 color : str Unicode color code (eg. '#C13535'), optional
128 color : str Unicode color code (eg. '#C13535'), optional
129 color of the value displayed (if readout == True)
129 color of the value displayed (if readout == True)
130 """
130 """
131 _view_name = Unicode('FloatSliderView', sync=True)
131 _view_name = Unicode('FloatSliderView', sync=True)
132 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
132 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
133 default_value='horizontal',
133 default_value='horizontal', help="Vertical or horizontal.", sync=True)
134 help="Vertical or horizontal.", allow_none=False, sync=True)
135 _range = Bool(False, help="Display a range selector", sync=True)
134 _range = Bool(False, help="Display a range selector", sync=True)
136 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
135 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
137 slider_color = Unicode(sync=True)
136 slider_color = Unicode(sync=True)
138
137
139
138
140 @register('IPython.FloatProgress')
139 @register('IPython.FloatProgress')
141 class FloatProgress(_BoundedFloat):
140 class FloatProgress(_BoundedFloat):
142 """ Displays a progress bar.
141 """ Displays a progress bar.
143
142
144 Parameters
143 Parameters
145 -----------
144 -----------
146 value : float
145 value : float
147 position within the range of the progress bar
146 position within the range of the progress bar
148 min : float
147 min : float
149 minimal position of the slider
148 minimal position of the slider
150 max : float
149 max : float
151 maximal position of the slider
150 maximal position of the slider
152 step : float
151 step : float
153 step of the progress bar
152 step of the progress bar
154 description : str
153 description : str
155 name of the progress bar
154 name of the progress bar
156 bar_style: {'success', 'info', 'warning', 'danger', ''}, optional
155 bar_style: {'success', 'info', 'warning', 'danger', ''}, optional
157 color of the progress bar, default is '' (blue)
156 color of the progress bar, default is '' (blue)
158 colors are: 'success'-green, 'info'-light blue, 'warning'-orange, 'danger'-red
157 colors are: 'success'-green, 'info'-light blue, 'warning'-orange, 'danger'-red
159 """
158 """
160 _view_name = Unicode('ProgressView', sync=True)
159 _view_name = Unicode('ProgressView', sync=True)
161
160
162 bar_style = CaselessStrEnum(
161 bar_style = CaselessStrEnum(
163 values=['success', 'info', 'warning', 'danger', ''],
162 values=['success', 'info', 'warning', 'danger', ''],
164 default_value='', allow_none=True, sync=True, help="""Use a
163 default_value='', allow_none=True, sync=True, help="""Use a
165 predefined styling for the progess bar.""")
164 predefined styling for the progess bar.""")
166
165
167 class _FloatRange(_Float):
166 class _FloatRange(_Float):
168 value = Tuple(CFloat, CFloat, default_value=(0.0, 1.0), help="Tuple of (lower, upper) bounds", sync=True)
167 value = Tuple(CFloat, CFloat, default_value=(0.0, 1.0), help="Tuple of (lower, upper) bounds", sync=True)
169 lower = CFloat(0.0, help="Lower bound", sync=False)
168 lower = CFloat(0.0, help="Lower bound", sync=False)
170 upper = CFloat(1.0, help="Upper bound", sync=False)
169 upper = CFloat(1.0, help="Upper bound", sync=False)
171
170
172 def __init__(self, *pargs, **kwargs):
171 def __init__(self, *pargs, **kwargs):
173 value_given = 'value' in kwargs
172 value_given = 'value' in kwargs
174 lower_given = 'lower' in kwargs
173 lower_given = 'lower' in kwargs
175 upper_given = 'upper' in kwargs
174 upper_given = 'upper' in kwargs
176 if value_given and (lower_given or upper_given):
175 if value_given and (lower_given or upper_given):
177 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
176 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
178 if lower_given != upper_given:
177 if lower_given != upper_given:
179 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
178 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
180
179
181 DOMWidget.__init__(self, *pargs, **kwargs)
180 DOMWidget.__init__(self, *pargs, **kwargs)
182
181
183 # ensure the traits match, preferring whichever (if any) was given in kwargs
182 # ensure the traits match, preferring whichever (if any) was given in kwargs
184 if value_given:
183 if value_given:
185 self.lower, self.upper = self.value
184 self.lower, self.upper = self.value
186 else:
185 else:
187 self.value = (self.lower, self.upper)
186 self.value = (self.lower, self.upper)
188
187
189 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
188 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
190
189
191 def _validate(self, name, old, new):
190 def _validate(self, name, old, new):
192 if name == 'value':
191 if name == 'value':
193 self.lower, self.upper = min(new), max(new)
192 self.lower, self.upper = min(new), max(new)
194 elif name == 'lower':
193 elif name == 'lower':
195 self.value = (new, self.value[1])
194 self.value = (new, self.value[1])
196 elif name == 'upper':
195 elif name == 'upper':
197 self.value = (self.value[0], new)
196 self.value = (self.value[0], new)
198
197
199 class _BoundedFloatRange(_FloatRange):
198 class _BoundedFloatRange(_FloatRange):
200 step = CFloat(1.0, help="Minimum step that the value can take (ignored by some views)", sync=True)
199 step = CFloat(1.0, help="Minimum step that the value can take (ignored by some views)", sync=True)
201 max = CFloat(100.0, help="Max value", sync=True)
200 max = CFloat(100.0, help="Max value", sync=True)
202 min = CFloat(0.0, help="Min value", sync=True)
201 min = CFloat(0.0, help="Min value", sync=True)
203
202
204 def __init__(self, *pargs, **kwargs):
203 def __init__(self, *pargs, **kwargs):
205 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
204 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
206 _FloatRange.__init__(self, *pargs, **kwargs)
205 _FloatRange.__init__(self, *pargs, **kwargs)
207
206
208 # ensure a minimal amount of sanity
207 # ensure a minimal amount of sanity
209 if self.min > self.max:
208 if self.min > self.max:
210 raise ValueError("min must be <= max")
209 raise ValueError("min must be <= max")
211
210
212 if any_value_given:
211 if any_value_given:
213 # if a value was given, clamp it within (min, max)
212 # if a value was given, clamp it within (min, max)
214 self._validate("value", None, self.value)
213 self._validate("value", None, self.value)
215 else:
214 else:
216 # otherwise, set it to 25-75% to avoid the handles overlapping
215 # otherwise, set it to 25-75% to avoid the handles overlapping
217 self.value = (0.75*self.min + 0.25*self.max,
216 self.value = (0.75*self.min + 0.25*self.max,
218 0.25*self.min + 0.75*self.max)
217 0.25*self.min + 0.75*self.max)
219 # callback already set for 'value', 'lower', 'upper'
218 # callback already set for 'value', 'lower', 'upper'
220 self.on_trait_change(self._validate, ['min', 'max'])
219 self.on_trait_change(self._validate, ['min', 'max'])
221
220
222
221
223 def _validate(self, name, old, new):
222 def _validate(self, name, old, new):
224 if name == "min":
223 if name == "min":
225 if new > self.max:
224 if new > self.max:
226 raise ValueError("setting min > max")
225 raise ValueError("setting min > max")
227 self.min = new
226 self.min = new
228 elif name == "max":
227 elif name == "max":
229 if new < self.min:
228 if new < self.min:
230 raise ValueError("setting max < min")
229 raise ValueError("setting max < min")
231 self.max = new
230 self.max = new
232
231
233 low, high = self.value
232 low, high = self.value
234 if name == "value":
233 if name == "value":
235 low, high = min(new), max(new)
234 low, high = min(new), max(new)
236 elif name == "upper":
235 elif name == "upper":
237 if new < self.lower:
236 if new < self.lower:
238 raise ValueError("setting upper < lower")
237 raise ValueError("setting upper < lower")
239 high = new
238 high = new
240 elif name == "lower":
239 elif name == "lower":
241 if new > self.upper:
240 if new > self.upper:
242 raise ValueError("setting lower > upper")
241 raise ValueError("setting lower > upper")
243 low = new
242 low = new
244
243
245 low = max(self.min, min(low, self.max))
244 low = max(self.min, min(low, self.max))
246 high = min(self.max, max(high, self.min))
245 high = min(self.max, max(high, self.min))
247
246
248 # determine the order in which we should update the
247 # determine the order in which we should update the
249 # lower, upper traits to avoid a temporary inverted overlap
248 # lower, upper traits to avoid a temporary inverted overlap
250 lower_first = high < self.lower
249 lower_first = high < self.lower
251
250
252 self.value = (low, high)
251 self.value = (low, high)
253 if lower_first:
252 if lower_first:
254 self.lower = low
253 self.lower = low
255 self.upper = high
254 self.upper = high
256 else:
255 else:
257 self.upper = high
256 self.upper = high
258 self.lower = low
257 self.lower = low
259
258
260
259
261 @register('IPython.FloatRangeSlider')
260 @register('IPython.FloatRangeSlider')
262 class FloatRangeSlider(_BoundedFloatRange):
261 class FloatRangeSlider(_BoundedFloatRange):
263 """ Slider/trackbar for displaying a floating value range (within the specified range of values).
262 """ Slider/trackbar for displaying a floating value range (within the specified range of values).
264
263
265 Parameters
264 Parameters
266 ----------
265 ----------
267 value : float tuple
266 value : float tuple
268 range of the slider displayed
267 range of the slider displayed
269 min : float
268 min : float
270 minimal position of the slider
269 minimal position of the slider
271 max : float
270 max : float
272 maximal position of the slider
271 maximal position of the slider
273 step : float
272 step : float
274 step of the trackbar
273 step of the trackbar
275 description : str
274 description : str
276 name of the slider
275 name of the slider
277 orientation : {'vertical', 'horizontal}, optional
276 orientation : {'vertical', 'horizontal}, optional
278 default is horizontal
277 default is horizontal
279 readout : {True, False}, optional
278 readout : {True, False}, optional
280 default is True, display the current value of the slider next to it
279 default is True, display the current value of the slider next to it
281 slider_color : str Unicode color code (eg. '#C13535'), optional
280 slider_color : str Unicode color code (eg. '#C13535'), optional
282 color of the slider
281 color of the slider
283 color : str Unicode color code (eg. '#C13535'), optional
282 color : str Unicode color code (eg. '#C13535'), optional
284 color of the value displayed (if readout == True)
283 color of the value displayed (if readout == True)
285 """
284 """
286 _view_name = Unicode('FloatSliderView', sync=True)
285 _view_name = Unicode('FloatSliderView', sync=True)
287 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
286 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
288 default_value='horizontal', allow_none=False,
287 default_value='horizontal', help="Vertical or horizontal.", sync=True)
289 help="Vertical or horizontal.", sync=True)
290 _range = Bool(True, help="Display a range selector", sync=True)
288 _range = Bool(True, help="Display a range selector", sync=True)
291 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
289 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
292 slider_color = Unicode(sync=True)
290 slider_color = Unicode(sync=True)
293
291
294 # Remove in IPython 4.0
292 # Remove in IPython 4.0
295 FloatTextWidget = DeprecatedClass(FloatText, 'FloatTextWidget')
293 FloatTextWidget = DeprecatedClass(FloatText, 'FloatTextWidget')
296 BoundedFloatTextWidget = DeprecatedClass(BoundedFloatText, 'BoundedFloatTextWidget')
294 BoundedFloatTextWidget = DeprecatedClass(BoundedFloatText, 'BoundedFloatTextWidget')
297 FloatSliderWidget = DeprecatedClass(FloatSlider, 'FloatSliderWidget')
295 FloatSliderWidget = DeprecatedClass(FloatSlider, 'FloatSliderWidget')
298 FloatProgressWidget = DeprecatedClass(FloatProgress, 'FloatProgressWidget')
296 FloatProgressWidget = DeprecatedClass(FloatProgress, 'FloatProgressWidget')
@@ -1,209 +1,207 b''
1 """Int class.
1 """Int class.
2
2
3 Represents an unbounded int using a widget.
3 Represents an unbounded int using a widget.
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (c) 2013, the IPython Development Team.
6 # Copyright (c) 2013, the IPython Development Team.
7 #
7 #
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9 #
9 #
10 # The full license is in the file COPYING.txt, distributed with this software.
10 # The full license is in the file COPYING.txt, distributed with this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from .widget import DOMWidget, register
16 from .widget import DOMWidget, register
17 from IPython.utils.traitlets import Unicode, CInt, Bool, CaselessStrEnum, Tuple
17 from IPython.utils.traitlets import Unicode, CInt, Bool, CaselessStrEnum, Tuple
18 from IPython.utils.warn import DeprecatedClass
18 from IPython.utils.warn import DeprecatedClass
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Classes
21 # Classes
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 class _Int(DOMWidget):
23 class _Int(DOMWidget):
24 """Base class used to create widgets that represent an int."""
24 """Base class used to create widgets that represent an int."""
25 value = CInt(0, help="Int value", sync=True)
25 value = CInt(0, help="Int value", sync=True)
26 disabled = Bool(False, help="Enable or disable user changes", sync=True)
26 disabled = Bool(False, help="Enable or disable user changes", sync=True)
27 description = Unicode(help="Description of the value this widget represents", sync=True)
27 description = Unicode(help="Description of the value this widget represents", sync=True)
28
28
29 def __init__(self, value=None, **kwargs):
29 def __init__(self, value=None, **kwargs):
30 if value is not None:
30 if value is not None:
31 kwargs['value'] = value
31 kwargs['value'] = value
32 super(_Int, self).__init__(**kwargs)
32 super(_Int, self).__init__(**kwargs)
33
33
34 class _BoundedInt(_Int):
34 class _BoundedInt(_Int):
35 """Base class used to create widgets that represent a int that is bounded
35 """Base class used to create widgets that represent a int that is bounded
36 by a minium and maximum."""
36 by a minium and maximum."""
37 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
37 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
38 max = CInt(100, help="Max value", sync=True)
38 max = CInt(100, help="Max value", sync=True)
39 min = CInt(0, help="Min value", sync=True)
39 min = CInt(0, help="Min value", sync=True)
40
40
41 def __init__(self, *pargs, **kwargs):
41 def __init__(self, *pargs, **kwargs):
42 """Constructor"""
42 """Constructor"""
43 super(_BoundedInt, self).__init__(*pargs, **kwargs)
43 super(_BoundedInt, self).__init__(*pargs, **kwargs)
44 self._handle_value_changed('value', None, self.value)
44 self._handle_value_changed('value', None, self.value)
45 self._handle_max_changed('max', None, self.max)
45 self._handle_max_changed('max', None, self.max)
46 self._handle_min_changed('min', None, self.min)
46 self._handle_min_changed('min', None, self.min)
47 self.on_trait_change(self._handle_value_changed, 'value')
47 self.on_trait_change(self._handle_value_changed, 'value')
48 self.on_trait_change(self._handle_max_changed, 'max')
48 self.on_trait_change(self._handle_max_changed, 'max')
49 self.on_trait_change(self._handle_min_changed, 'min')
49 self.on_trait_change(self._handle_min_changed, 'min')
50
50
51 def _handle_value_changed(self, name, old, new):
51 def _handle_value_changed(self, name, old, new):
52 """Validate value."""
52 """Validate value."""
53 if self.min > new or new > self.max:
53 if self.min > new or new > self.max:
54 self.value = min(max(new, self.min), self.max)
54 self.value = min(max(new, self.min), self.max)
55
55
56 def _handle_max_changed(self, name, old, new):
56 def _handle_max_changed(self, name, old, new):
57 """Make sure the min is always <= the max."""
57 """Make sure the min is always <= the max."""
58 if new < self.min:
58 if new < self.min:
59 raise ValueError("setting max < min")
59 raise ValueError("setting max < min")
60 if new < self.value:
60 if new < self.value:
61 self.value = new
61 self.value = new
62
62
63 def _handle_min_changed(self, name, old, new):
63 def _handle_min_changed(self, name, old, new):
64 """Make sure the max is always >= the min."""
64 """Make sure the max is always >= the min."""
65 if new > self.max:
65 if new > self.max:
66 raise ValueError("setting min > max")
66 raise ValueError("setting min > max")
67 if new > self.value:
67 if new > self.value:
68 self.value = new
68 self.value = new
69
69
70 @register('IPython.IntText')
70 @register('IPython.IntText')
71 class IntText(_Int):
71 class IntText(_Int):
72 """Textbox widget that represents a int."""
72 """Textbox widget that represents a int."""
73 _view_name = Unicode('IntTextView', sync=True)
73 _view_name = Unicode('IntTextView', sync=True)
74
74
75
75
76 @register('IPython.BoundedIntText')
76 @register('IPython.BoundedIntText')
77 class BoundedIntText(_BoundedInt):
77 class BoundedIntText(_BoundedInt):
78 """Textbox widget that represents a int bounded by a minimum and maximum value."""
78 """Textbox widget that represents a int bounded by a minimum and maximum value."""
79 _view_name = Unicode('IntTextView', sync=True)
79 _view_name = Unicode('IntTextView', sync=True)
80
80
81
81
82 @register('IPython.IntSlider')
82 @register('IPython.IntSlider')
83 class IntSlider(_BoundedInt):
83 class IntSlider(_BoundedInt):
84 """Slider widget that represents a int bounded by a minimum and maximum value."""
84 """Slider widget that represents a int bounded by a minimum and maximum value."""
85 _view_name = Unicode('IntSliderView', sync=True)
85 _view_name = Unicode('IntSliderView', sync=True)
86 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
86 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
87 default_value='horizontal', allow_none=False,
87 default_value='horizontal', help="Vertical or horizontal.", sync=True)
88 help="Vertical or horizontal.", sync=True)
89 _range = Bool(False, help="Display a range selector", sync=True)
88 _range = Bool(False, help="Display a range selector", sync=True)
90 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
89 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
91 slider_color = Unicode(sync=True)
90 slider_color = Unicode(sync=True)
92
91
93
92
94 @register('IPython.IntProgress')
93 @register('IPython.IntProgress')
95 class IntProgress(_BoundedInt):
94 class IntProgress(_BoundedInt):
96 """Progress bar that represents a int bounded by a minimum and maximum value."""
95 """Progress bar that represents a int bounded by a minimum and maximum value."""
97 _view_name = Unicode('ProgressView', sync=True)
96 _view_name = Unicode('ProgressView', sync=True)
98
97
99 bar_style = CaselessStrEnum(
98 bar_style = CaselessStrEnum(
100 values=['success', 'info', 'warning', 'danger', ''],
99 values=['success', 'info', 'warning', 'danger', ''],
101 default_value='', allow_none=True, sync=True, help="""Use a
100 default_value='', allow_none=True, sync=True, help="""Use a
102 predefined styling for the progess bar.""")
101 predefined styling for the progess bar.""")
103
102
104 class _IntRange(_Int):
103 class _IntRange(_Int):
105 value = Tuple(CInt, CInt, default_value=(0, 1), help="Tuple of (lower, upper) bounds", sync=True)
104 value = Tuple(CInt, CInt, default_value=(0, 1), help="Tuple of (lower, upper) bounds", sync=True)
106 lower = CInt(0, help="Lower bound", sync=False)
105 lower = CInt(0, help="Lower bound", sync=False)
107 upper = CInt(1, help="Upper bound", sync=False)
106 upper = CInt(1, help="Upper bound", sync=False)
108
107
109 def __init__(self, *pargs, **kwargs):
108 def __init__(self, *pargs, **kwargs):
110 value_given = 'value' in kwargs
109 value_given = 'value' in kwargs
111 lower_given = 'lower' in kwargs
110 lower_given = 'lower' in kwargs
112 upper_given = 'upper' in kwargs
111 upper_given = 'upper' in kwargs
113 if value_given and (lower_given or upper_given):
112 if value_given and (lower_given or upper_given):
114 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
113 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
115 if lower_given != upper_given:
114 if lower_given != upper_given:
116 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
115 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
117
116
118 super(_IntRange, self).__init__(*pargs, **kwargs)
117 super(_IntRange, self).__init__(*pargs, **kwargs)
119
118
120 # ensure the traits match, preferring whichever (if any) was given in kwargs
119 # ensure the traits match, preferring whichever (if any) was given in kwargs
121 if value_given:
120 if value_given:
122 self.lower, self.upper = self.value
121 self.lower, self.upper = self.value
123 else:
122 else:
124 self.value = (self.lower, self.upper)
123 self.value = (self.lower, self.upper)
125
124
126 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
125 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
127
126
128 def _validate(self, name, old, new):
127 def _validate(self, name, old, new):
129 if name == 'value':
128 if name == 'value':
130 self.lower, self.upper = min(new), max(new)
129 self.lower, self.upper = min(new), max(new)
131 elif name == 'lower':
130 elif name == 'lower':
132 self.value = (new, self.value[1])
131 self.value = (new, self.value[1])
133 elif name == 'upper':
132 elif name == 'upper':
134 self.value = (self.value[0], new)
133 self.value = (self.value[0], new)
135
134
136 class _BoundedIntRange(_IntRange):
135 class _BoundedIntRange(_IntRange):
137 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
136 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
138 max = CInt(100, help="Max value", sync=True)
137 max = CInt(100, help="Max value", sync=True)
139 min = CInt(0, help="Min value", sync=True)
138 min = CInt(0, help="Min value", sync=True)
140
139
141 def __init__(self, *pargs, **kwargs):
140 def __init__(self, *pargs, **kwargs):
142 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
141 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
143 _IntRange.__init__(self, *pargs, **kwargs)
142 _IntRange.__init__(self, *pargs, **kwargs)
144
143
145 # ensure a minimal amount of sanity
144 # ensure a minimal amount of sanity
146 if self.min > self.max:
145 if self.min > self.max:
147 raise ValueError("min must be <= max")
146 raise ValueError("min must be <= max")
148
147
149 if any_value_given:
148 if any_value_given:
150 # if a value was given, clamp it within (min, max)
149 # if a value was given, clamp it within (min, max)
151 self._validate("value", None, self.value)
150 self._validate("value", None, self.value)
152 else:
151 else:
153 # otherwise, set it to 25-75% to avoid the handles overlapping
152 # otherwise, set it to 25-75% to avoid the handles overlapping
154 self.value = (0.75*self.min + 0.25*self.max,
153 self.value = (0.75*self.min + 0.25*self.max,
155 0.25*self.min + 0.75*self.max)
154 0.25*self.min + 0.75*self.max)
156 # callback already set for 'value', 'lower', 'upper'
155 # callback already set for 'value', 'lower', 'upper'
157 self.on_trait_change(self._validate, ['min', 'max'])
156 self.on_trait_change(self._validate, ['min', 'max'])
158
157
159 def _validate(self, name, old, new):
158 def _validate(self, name, old, new):
160 if name == "min":
159 if name == "min":
161 if new > self.max:
160 if new > self.max:
162 raise ValueError("setting min > max")
161 raise ValueError("setting min > max")
163 elif name == "max":
162 elif name == "max":
164 if new < self.min:
163 if new < self.min:
165 raise ValueError("setting max < min")
164 raise ValueError("setting max < min")
166
165
167 low, high = self.value
166 low, high = self.value
168 if name == "value":
167 if name == "value":
169 low, high = min(new), max(new)
168 low, high = min(new), max(new)
170 elif name == "upper":
169 elif name == "upper":
171 if new < self.lower:
170 if new < self.lower:
172 raise ValueError("setting upper < lower")
171 raise ValueError("setting upper < lower")
173 high = new
172 high = new
174 elif name == "lower":
173 elif name == "lower":
175 if new > self.upper:
174 if new > self.upper:
176 raise ValueError("setting lower > upper")
175 raise ValueError("setting lower > upper")
177 low = new
176 low = new
178
177
179 low = max(self.min, min(low, self.max))
178 low = max(self.min, min(low, self.max))
180 high = min(self.max, max(high, self.min))
179 high = min(self.max, max(high, self.min))
181
180
182 # determine the order in which we should update the
181 # determine the order in which we should update the
183 # lower, upper traits to avoid a temporary inverted overlap
182 # lower, upper traits to avoid a temporary inverted overlap
184 lower_first = high < self.lower
183 lower_first = high < self.lower
185
184
186 self.value = (low, high)
185 self.value = (low, high)
187 if lower_first:
186 if lower_first:
188 self.lower = low
187 self.lower = low
189 self.upper = high
188 self.upper = high
190 else:
189 else:
191 self.upper = high
190 self.upper = high
192 self.lower = low
191 self.lower = low
193
192
194 @register('IPython.IntRangeSlider')
193 @register('IPython.IntRangeSlider')
195 class IntRangeSlider(_BoundedIntRange):
194 class IntRangeSlider(_BoundedIntRange):
196 """Slider widget that represents a pair of ints between a minimum and maximum value."""
195 """Slider widget that represents a pair of ints between a minimum and maximum value."""
197 _view_name = Unicode('IntSliderView', sync=True)
196 _view_name = Unicode('IntSliderView', sync=True)
198 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
197 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
199 default_value='horizontal', allow_none=False,
198 default_value='horizontal', help="Vertical or horizontal.", sync=True)
200 help="Vertical or horizontal.", sync=True)
201 _range = Bool(True, help="Display a range selector", sync=True)
199 _range = Bool(True, help="Display a range selector", sync=True)
202 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
200 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
203 slider_color = Unicode(sync=True)
201 slider_color = Unicode(sync=True)
204
202
205 # Remove in IPython 4.0
203 # Remove in IPython 4.0
206 IntTextWidget = DeprecatedClass(IntText, 'IntTextWidget')
204 IntTextWidget = DeprecatedClass(IntText, 'IntTextWidget')
207 BoundedIntTextWidget = DeprecatedClass(BoundedIntText, 'BoundedIntTextWidget')
205 BoundedIntTextWidget = DeprecatedClass(BoundedIntText, 'BoundedIntTextWidget')
208 IntSliderWidget = DeprecatedClass(IntSlider, 'IntSliderWidget')
206 IntSliderWidget = DeprecatedClass(IntSlider, 'IntSliderWidget')
209 IntProgressWidget = DeprecatedClass(IntProgress, 'IntProgressWidget')
207 IntProgressWidget = DeprecatedClass(IntProgress, 'IntProgressWidget')
@@ -1,496 +1,496 b''
1 """Test suite for our zeromq-based message specification."""
1 """Test suite for our zeromq-based message specification."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import re
6 import re
7 import sys
7 import sys
8 from distutils.version import LooseVersion as V
8 from distutils.version import LooseVersion as V
9 try:
9 try:
10 from queue import Empty # Py 3
10 from queue import Empty # Py 3
11 except ImportError:
11 except ImportError:
12 from Queue import Empty # Py 2
12 from Queue import Empty # Py 2
13
13
14 import nose.tools as nt
14 import nose.tools as nt
15
15
16 from IPython.utils.traitlets import (
16 from IPython.utils.traitlets import (
17 HasTraits, TraitError, Bool, Unicode, Dict, Integer, List, Enum,
17 HasTraits, TraitError, Bool, Unicode, Dict, Integer, List, Enum,
18 )
18 )
19 from IPython.utils.py3compat import string_types, iteritems
19 from IPython.utils.py3compat import string_types, iteritems
20
20
21 from .utils import TIMEOUT, start_global_kernel, flush_channels, execute
21 from .utils import TIMEOUT, start_global_kernel, flush_channels, execute
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Globals
24 # Globals
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 KC = None
26 KC = None
27
27
28 def setup():
28 def setup():
29 global KC
29 global KC
30 KC = start_global_kernel()
30 KC = start_global_kernel()
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Message Spec References
33 # Message Spec References
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 class Reference(HasTraits):
36 class Reference(HasTraits):
37
37
38 """
38 """
39 Base class for message spec specification testing.
39 Base class for message spec specification testing.
40
40
41 This class is the core of the message specification test. The
41 This class is the core of the message specification test. The
42 idea is that child classes implement trait attributes for each
42 idea is that child classes implement trait attributes for each
43 message keys, so that message keys can be tested against these
43 message keys, so that message keys can be tested against these
44 traits using :meth:`check` method.
44 traits using :meth:`check` method.
45
45
46 """
46 """
47
47
48 def check(self, d):
48 def check(self, d):
49 """validate a dict against our traits"""
49 """validate a dict against our traits"""
50 for key in self.trait_names():
50 for key in self.trait_names():
51 nt.assert_in(key, d)
51 nt.assert_in(key, d)
52 # FIXME: always allow None, probably not a good idea
52 # FIXME: always allow None, probably not a good idea
53 if d[key] is None:
53 if d[key] is None:
54 continue
54 continue
55 try:
55 try:
56 setattr(self, key, d[key])
56 setattr(self, key, d[key])
57 except TraitError as e:
57 except TraitError as e:
58 assert False, str(e)
58 assert False, str(e)
59
59
60
60
61 class Version(Unicode):
61 class Version(Unicode):
62 def __init__(self, *args, **kwargs):
62 def __init__(self, *args, **kwargs):
63 self.min = kwargs.pop('min', None)
63 self.min = kwargs.pop('min', None)
64 self.max = kwargs.pop('max', None)
64 self.max = kwargs.pop('max', None)
65 kwargs['default_value'] = self.min
65 kwargs['default_value'] = self.min
66 super(Version, self).__init__(*args, **kwargs)
66 super(Version, self).__init__(*args, **kwargs)
67
67
68 def validate(self, obj, value):
68 def validate(self, obj, value):
69 if self.min and V(value) < V(self.min):
69 if self.min and V(value) < V(self.min):
70 raise TraitError("bad version: %s < %s" % (value, self.min))
70 raise TraitError("bad version: %s < %s" % (value, self.min))
71 if self.max and (V(value) > V(self.max)):
71 if self.max and (V(value) > V(self.max)):
72 raise TraitError("bad version: %s > %s" % (value, self.max))
72 raise TraitError("bad version: %s > %s" % (value, self.max))
73
73
74
74
75 class RMessage(Reference):
75 class RMessage(Reference):
76 msg_id = Unicode()
76 msg_id = Unicode()
77 msg_type = Unicode()
77 msg_type = Unicode()
78 header = Dict()
78 header = Dict()
79 parent_header = Dict()
79 parent_header = Dict()
80 content = Dict()
80 content = Dict()
81
81
82 def check(self, d):
82 def check(self, d):
83 super(RMessage, self).check(d)
83 super(RMessage, self).check(d)
84 RHeader().check(self.header)
84 RHeader().check(self.header)
85 if self.parent_header:
85 if self.parent_header:
86 RHeader().check(self.parent_header)
86 RHeader().check(self.parent_header)
87
87
88 class RHeader(Reference):
88 class RHeader(Reference):
89 msg_id = Unicode()
89 msg_id = Unicode()
90 msg_type = Unicode()
90 msg_type = Unicode()
91 session = Unicode()
91 session = Unicode()
92 username = Unicode()
92 username = Unicode()
93 version = Version(min='5.0')
93 version = Version(min='5.0')
94
94
95 mime_pat = re.compile(r'^[\w\-\+\.]+/[\w\-\+\.]+$')
95 mime_pat = re.compile(r'^[\w\-\+\.]+/[\w\-\+\.]+$')
96
96
97 class MimeBundle(Reference):
97 class MimeBundle(Reference):
98 metadata = Dict()
98 metadata = Dict()
99 data = Dict()
99 data = Dict()
100 def _data_changed(self, name, old, new):
100 def _data_changed(self, name, old, new):
101 for k,v in iteritems(new):
101 for k,v in iteritems(new):
102 assert mime_pat.match(k)
102 assert mime_pat.match(k)
103 nt.assert_is_instance(v, string_types)
103 nt.assert_is_instance(v, string_types)
104
104
105 # shell replies
105 # shell replies
106
106
107 class ExecuteReply(Reference):
107 class ExecuteReply(Reference):
108 execution_count = Integer()
108 execution_count = Integer()
109 status = Enum((u'ok', u'error'))
109 status = Enum((u'ok', u'error'), default_value=u'ok')
110
110
111 def check(self, d):
111 def check(self, d):
112 Reference.check(self, d)
112 Reference.check(self, d)
113 if d['status'] == 'ok':
113 if d['status'] == 'ok':
114 ExecuteReplyOkay().check(d)
114 ExecuteReplyOkay().check(d)
115 elif d['status'] == 'error':
115 elif d['status'] == 'error':
116 ExecuteReplyError().check(d)
116 ExecuteReplyError().check(d)
117
117
118
118
119 class ExecuteReplyOkay(Reference):
119 class ExecuteReplyOkay(Reference):
120 payload = List(Dict)
120 payload = List(Dict)
121 user_expressions = Dict()
121 user_expressions = Dict()
122
122
123
123
124 class ExecuteReplyError(Reference):
124 class ExecuteReplyError(Reference):
125 ename = Unicode()
125 ename = Unicode()
126 evalue = Unicode()
126 evalue = Unicode()
127 traceback = List(Unicode)
127 traceback = List(Unicode)
128
128
129
129
130 class InspectReply(MimeBundle):
130 class InspectReply(MimeBundle):
131 found = Bool()
131 found = Bool()
132
132
133
133
134 class ArgSpec(Reference):
134 class ArgSpec(Reference):
135 args = List(Unicode)
135 args = List(Unicode)
136 varargs = Unicode()
136 varargs = Unicode()
137 varkw = Unicode()
137 varkw = Unicode()
138 defaults = List()
138 defaults = List()
139
139
140
140
141 class Status(Reference):
141 class Status(Reference):
142 execution_state = Enum((u'busy', u'idle', u'starting'))
142 execution_state = Enum((u'busy', u'idle', u'starting'), default_value=u'busy')
143
143
144
144
145 class CompleteReply(Reference):
145 class CompleteReply(Reference):
146 matches = List(Unicode)
146 matches = List(Unicode)
147 cursor_start = Integer()
147 cursor_start = Integer()
148 cursor_end = Integer()
148 cursor_end = Integer()
149 status = Unicode()
149 status = Unicode()
150
150
151 class LanguageInfo(Reference):
151 class LanguageInfo(Reference):
152 name = Unicode('python')
152 name = Unicode('python')
153 version = Unicode(sys.version.split()[0])
153 version = Unicode(sys.version.split()[0])
154
154
155 class KernelInfoReply(Reference):
155 class KernelInfoReply(Reference):
156 protocol_version = Version(min='5.0')
156 protocol_version = Version(min='5.0')
157 implementation = Unicode('ipython')
157 implementation = Unicode('ipython')
158 implementation_version = Version(min='2.1')
158 implementation_version = Version(min='2.1')
159 language_info = Dict()
159 language_info = Dict()
160 banner = Unicode()
160 banner = Unicode()
161
161
162 def check(self, d):
162 def check(self, d):
163 Reference.check(self, d)
163 Reference.check(self, d)
164 LanguageInfo().check(d['language_info'])
164 LanguageInfo().check(d['language_info'])
165
165
166
166
167 class IsCompleteReply(Reference):
167 class IsCompleteReply(Reference):
168 status = Enum((u'complete', u'incomplete', u'invalid', u'unknown'))
168 status = Enum((u'complete', u'incomplete', u'invalid', u'unknown'), default_value=u'complete')
169
169
170 def check(self, d):
170 def check(self, d):
171 Reference.check(self, d)
171 Reference.check(self, d)
172 if d['status'] == 'incomplete':
172 if d['status'] == 'incomplete':
173 IsCompleteReplyIncomplete().check(d)
173 IsCompleteReplyIncomplete().check(d)
174
174
175 class IsCompleteReplyIncomplete(Reference):
175 class IsCompleteReplyIncomplete(Reference):
176 indent = Unicode()
176 indent = Unicode()
177
177
178
178
179 # IOPub messages
179 # IOPub messages
180
180
181 class ExecuteInput(Reference):
181 class ExecuteInput(Reference):
182 code = Unicode()
182 code = Unicode()
183 execution_count = Integer()
183 execution_count = Integer()
184
184
185
185
186 Error = ExecuteReplyError
186 Error = ExecuteReplyError
187
187
188
188
189 class Stream(Reference):
189 class Stream(Reference):
190 name = Enum((u'stdout', u'stderr'))
190 name = Enum((u'stdout', u'stderr'), default_value=u'stdout')
191 text = Unicode()
191 text = Unicode()
192
192
193
193
194 class DisplayData(MimeBundle):
194 class DisplayData(MimeBundle):
195 pass
195 pass
196
196
197
197
198 class ExecuteResult(MimeBundle):
198 class ExecuteResult(MimeBundle):
199 execution_count = Integer()
199 execution_count = Integer()
200
200
201 class HistoryReply(Reference):
201 class HistoryReply(Reference):
202 history = List(List())
202 history = List(List())
203
203
204
204
205 references = {
205 references = {
206 'execute_reply' : ExecuteReply(),
206 'execute_reply' : ExecuteReply(),
207 'inspect_reply' : InspectReply(),
207 'inspect_reply' : InspectReply(),
208 'status' : Status(),
208 'status' : Status(),
209 'complete_reply' : CompleteReply(),
209 'complete_reply' : CompleteReply(),
210 'kernel_info_reply': KernelInfoReply(),
210 'kernel_info_reply': KernelInfoReply(),
211 'is_complete_reply': IsCompleteReply(),
211 'is_complete_reply': IsCompleteReply(),
212 'execute_input' : ExecuteInput(),
212 'execute_input' : ExecuteInput(),
213 'execute_result' : ExecuteResult(),
213 'execute_result' : ExecuteResult(),
214 'history_reply' : HistoryReply(),
214 'history_reply' : HistoryReply(),
215 'error' : Error(),
215 'error' : Error(),
216 'stream' : Stream(),
216 'stream' : Stream(),
217 'display_data' : DisplayData(),
217 'display_data' : DisplayData(),
218 'header' : RHeader(),
218 'header' : RHeader(),
219 }
219 }
220 """
220 """
221 Specifications of `content` part of the reply messages.
221 Specifications of `content` part of the reply messages.
222 """
222 """
223
223
224
224
225 def validate_message(msg, msg_type=None, parent=None):
225 def validate_message(msg, msg_type=None, parent=None):
226 """validate a message
226 """validate a message
227
227
228 This is a generator, and must be iterated through to actually
228 This is a generator, and must be iterated through to actually
229 trigger each test.
229 trigger each test.
230
230
231 If msg_type and/or parent are given, the msg_type and/or parent msg_id
231 If msg_type and/or parent are given, the msg_type and/or parent msg_id
232 are compared with the given values.
232 are compared with the given values.
233 """
233 """
234 RMessage().check(msg)
234 RMessage().check(msg)
235 if msg_type:
235 if msg_type:
236 nt.assert_equal(msg['msg_type'], msg_type)
236 nt.assert_equal(msg['msg_type'], msg_type)
237 if parent:
237 if parent:
238 nt.assert_equal(msg['parent_header']['msg_id'], parent)
238 nt.assert_equal(msg['parent_header']['msg_id'], parent)
239 content = msg['content']
239 content = msg['content']
240 ref = references[msg['msg_type']]
240 ref = references[msg['msg_type']]
241 ref.check(content)
241 ref.check(content)
242
242
243
243
244 #-----------------------------------------------------------------------------
244 #-----------------------------------------------------------------------------
245 # Tests
245 # Tests
246 #-----------------------------------------------------------------------------
246 #-----------------------------------------------------------------------------
247
247
248 # Shell channel
248 # Shell channel
249
249
250 def test_execute():
250 def test_execute():
251 flush_channels()
251 flush_channels()
252
252
253 msg_id = KC.execute(code='x=1')
253 msg_id = KC.execute(code='x=1')
254 reply = KC.get_shell_msg(timeout=TIMEOUT)
254 reply = KC.get_shell_msg(timeout=TIMEOUT)
255 validate_message(reply, 'execute_reply', msg_id)
255 validate_message(reply, 'execute_reply', msg_id)
256
256
257
257
258 def test_execute_silent():
258 def test_execute_silent():
259 flush_channels()
259 flush_channels()
260 msg_id, reply = execute(code='x=1', silent=True)
260 msg_id, reply = execute(code='x=1', silent=True)
261
261
262 # flush status=idle
262 # flush status=idle
263 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
263 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
264 validate_message(status, 'status', msg_id)
264 validate_message(status, 'status', msg_id)
265 nt.assert_equal(status['content']['execution_state'], 'idle')
265 nt.assert_equal(status['content']['execution_state'], 'idle')
266
266
267 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
267 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
268 count = reply['execution_count']
268 count = reply['execution_count']
269
269
270 msg_id, reply = execute(code='x=2', silent=True)
270 msg_id, reply = execute(code='x=2', silent=True)
271
271
272 # flush status=idle
272 # flush status=idle
273 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
273 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
274 validate_message(status, 'status', msg_id)
274 validate_message(status, 'status', msg_id)
275 nt.assert_equal(status['content']['execution_state'], 'idle')
275 nt.assert_equal(status['content']['execution_state'], 'idle')
276
276
277 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
277 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
278 count_2 = reply['execution_count']
278 count_2 = reply['execution_count']
279 nt.assert_equal(count_2, count)
279 nt.assert_equal(count_2, count)
280
280
281
281
282 def test_execute_error():
282 def test_execute_error():
283 flush_channels()
283 flush_channels()
284
284
285 msg_id, reply = execute(code='1/0')
285 msg_id, reply = execute(code='1/0')
286 nt.assert_equal(reply['status'], 'error')
286 nt.assert_equal(reply['status'], 'error')
287 nt.assert_equal(reply['ename'], 'ZeroDivisionError')
287 nt.assert_equal(reply['ename'], 'ZeroDivisionError')
288
288
289 error = KC.iopub_channel.get_msg(timeout=TIMEOUT)
289 error = KC.iopub_channel.get_msg(timeout=TIMEOUT)
290 validate_message(error, 'error', msg_id)
290 validate_message(error, 'error', msg_id)
291
291
292
292
293 def test_execute_inc():
293 def test_execute_inc():
294 """execute request should increment execution_count"""
294 """execute request should increment execution_count"""
295 flush_channels()
295 flush_channels()
296
296
297 msg_id, reply = execute(code='x=1')
297 msg_id, reply = execute(code='x=1')
298 count = reply['execution_count']
298 count = reply['execution_count']
299
299
300 flush_channels()
300 flush_channels()
301
301
302 msg_id, reply = execute(code='x=2')
302 msg_id, reply = execute(code='x=2')
303 count_2 = reply['execution_count']
303 count_2 = reply['execution_count']
304 nt.assert_equal(count_2, count+1)
304 nt.assert_equal(count_2, count+1)
305
305
306 def test_execute_stop_on_error():
306 def test_execute_stop_on_error():
307 """execute request should not abort execution queue with stop_on_error False"""
307 """execute request should not abort execution queue with stop_on_error False"""
308 flush_channels()
308 flush_channels()
309
309
310 fail = '\n'.join([
310 fail = '\n'.join([
311 # sleep to ensure subsequent message is waiting in the queue to be aborted
311 # sleep to ensure subsequent message is waiting in the queue to be aborted
312 'import time',
312 'import time',
313 'time.sleep(0.5)',
313 'time.sleep(0.5)',
314 'raise ValueError',
314 'raise ValueError',
315 ])
315 ])
316 KC.execute(code=fail)
316 KC.execute(code=fail)
317 msg_id = KC.execute(code='print("Hello")')
317 msg_id = KC.execute(code='print("Hello")')
318 KC.get_shell_msg(timeout=TIMEOUT)
318 KC.get_shell_msg(timeout=TIMEOUT)
319 reply = KC.get_shell_msg(timeout=TIMEOUT)
319 reply = KC.get_shell_msg(timeout=TIMEOUT)
320 nt.assert_equal(reply['content']['status'], 'aborted')
320 nt.assert_equal(reply['content']['status'], 'aborted')
321
321
322 flush_channels()
322 flush_channels()
323
323
324 KC.execute(code=fail, stop_on_error=False)
324 KC.execute(code=fail, stop_on_error=False)
325 msg_id = KC.execute(code='print("Hello")')
325 msg_id = KC.execute(code='print("Hello")')
326 KC.get_shell_msg(timeout=TIMEOUT)
326 KC.get_shell_msg(timeout=TIMEOUT)
327 reply = KC.get_shell_msg(timeout=TIMEOUT)
327 reply = KC.get_shell_msg(timeout=TIMEOUT)
328 nt.assert_equal(reply['content']['status'], 'ok')
328 nt.assert_equal(reply['content']['status'], 'ok')
329
329
330
330
331 def test_user_expressions():
331 def test_user_expressions():
332 flush_channels()
332 flush_channels()
333
333
334 msg_id, reply = execute(code='x=1', user_expressions=dict(foo='x+1'))
334 msg_id, reply = execute(code='x=1', user_expressions=dict(foo='x+1'))
335 user_expressions = reply['user_expressions']
335 user_expressions = reply['user_expressions']
336 nt.assert_equal(user_expressions, {u'foo': {
336 nt.assert_equal(user_expressions, {u'foo': {
337 u'status': u'ok',
337 u'status': u'ok',
338 u'data': {u'text/plain': u'2'},
338 u'data': {u'text/plain': u'2'},
339 u'metadata': {},
339 u'metadata': {},
340 }})
340 }})
341
341
342
342
343 def test_user_expressions_fail():
343 def test_user_expressions_fail():
344 flush_channels()
344 flush_channels()
345
345
346 msg_id, reply = execute(code='x=0', user_expressions=dict(foo='nosuchname'))
346 msg_id, reply = execute(code='x=0', user_expressions=dict(foo='nosuchname'))
347 user_expressions = reply['user_expressions']
347 user_expressions = reply['user_expressions']
348 foo = user_expressions['foo']
348 foo = user_expressions['foo']
349 nt.assert_equal(foo['status'], 'error')
349 nt.assert_equal(foo['status'], 'error')
350 nt.assert_equal(foo['ename'], 'NameError')
350 nt.assert_equal(foo['ename'], 'NameError')
351
351
352
352
353 def test_oinfo():
353 def test_oinfo():
354 flush_channels()
354 flush_channels()
355
355
356 msg_id = KC.inspect('a')
356 msg_id = KC.inspect('a')
357 reply = KC.get_shell_msg(timeout=TIMEOUT)
357 reply = KC.get_shell_msg(timeout=TIMEOUT)
358 validate_message(reply, 'inspect_reply', msg_id)
358 validate_message(reply, 'inspect_reply', msg_id)
359
359
360
360
361 def test_oinfo_found():
361 def test_oinfo_found():
362 flush_channels()
362 flush_channels()
363
363
364 msg_id, reply = execute(code='a=5')
364 msg_id, reply = execute(code='a=5')
365
365
366 msg_id = KC.inspect('a')
366 msg_id = KC.inspect('a')
367 reply = KC.get_shell_msg(timeout=TIMEOUT)
367 reply = KC.get_shell_msg(timeout=TIMEOUT)
368 validate_message(reply, 'inspect_reply', msg_id)
368 validate_message(reply, 'inspect_reply', msg_id)
369 content = reply['content']
369 content = reply['content']
370 assert content['found']
370 assert content['found']
371 text = content['data']['text/plain']
371 text = content['data']['text/plain']
372 nt.assert_in('Type:', text)
372 nt.assert_in('Type:', text)
373 nt.assert_in('Docstring:', text)
373 nt.assert_in('Docstring:', text)
374
374
375
375
376 def test_oinfo_detail():
376 def test_oinfo_detail():
377 flush_channels()
377 flush_channels()
378
378
379 msg_id, reply = execute(code='ip=get_ipython()')
379 msg_id, reply = execute(code='ip=get_ipython()')
380
380
381 msg_id = KC.inspect('ip.object_inspect', cursor_pos=10, detail_level=1)
381 msg_id = KC.inspect('ip.object_inspect', cursor_pos=10, detail_level=1)
382 reply = KC.get_shell_msg(timeout=TIMEOUT)
382 reply = KC.get_shell_msg(timeout=TIMEOUT)
383 validate_message(reply, 'inspect_reply', msg_id)
383 validate_message(reply, 'inspect_reply', msg_id)
384 content = reply['content']
384 content = reply['content']
385 assert content['found']
385 assert content['found']
386 text = content['data']['text/plain']
386 text = content['data']['text/plain']
387 nt.assert_in('Signature:', text)
387 nt.assert_in('Signature:', text)
388 nt.assert_in('Source:', text)
388 nt.assert_in('Source:', text)
389
389
390
390
391 def test_oinfo_not_found():
391 def test_oinfo_not_found():
392 flush_channels()
392 flush_channels()
393
393
394 msg_id = KC.inspect('dne')
394 msg_id = KC.inspect('dne')
395 reply = KC.get_shell_msg(timeout=TIMEOUT)
395 reply = KC.get_shell_msg(timeout=TIMEOUT)
396 validate_message(reply, 'inspect_reply', msg_id)
396 validate_message(reply, 'inspect_reply', msg_id)
397 content = reply['content']
397 content = reply['content']
398 nt.assert_false(content['found'])
398 nt.assert_false(content['found'])
399
399
400
400
401 def test_complete():
401 def test_complete():
402 flush_channels()
402 flush_channels()
403
403
404 msg_id, reply = execute(code="alpha = albert = 5")
404 msg_id, reply = execute(code="alpha = albert = 5")
405
405
406 msg_id = KC.complete('al', 2)
406 msg_id = KC.complete('al', 2)
407 reply = KC.get_shell_msg(timeout=TIMEOUT)
407 reply = KC.get_shell_msg(timeout=TIMEOUT)
408 validate_message(reply, 'complete_reply', msg_id)
408 validate_message(reply, 'complete_reply', msg_id)
409 matches = reply['content']['matches']
409 matches = reply['content']['matches']
410 for name in ('alpha', 'albert'):
410 for name in ('alpha', 'albert'):
411 nt.assert_in(name, matches)
411 nt.assert_in(name, matches)
412
412
413
413
414 def test_kernel_info_request():
414 def test_kernel_info_request():
415 flush_channels()
415 flush_channels()
416
416
417 msg_id = KC.kernel_info()
417 msg_id = KC.kernel_info()
418 reply = KC.get_shell_msg(timeout=TIMEOUT)
418 reply = KC.get_shell_msg(timeout=TIMEOUT)
419 validate_message(reply, 'kernel_info_reply', msg_id)
419 validate_message(reply, 'kernel_info_reply', msg_id)
420
420
421
421
422 def test_single_payload():
422 def test_single_payload():
423 flush_channels()
423 flush_channels()
424 msg_id, reply = execute(code="for i in range(3):\n"+
424 msg_id, reply = execute(code="for i in range(3):\n"+
425 " x=range?\n")
425 " x=range?\n")
426 payload = reply['payload']
426 payload = reply['payload']
427 next_input_pls = [pl for pl in payload if pl["source"] == "set_next_input"]
427 next_input_pls = [pl for pl in payload if pl["source"] == "set_next_input"]
428 nt.assert_equal(len(next_input_pls), 1)
428 nt.assert_equal(len(next_input_pls), 1)
429
429
430 def test_is_complete():
430 def test_is_complete():
431 flush_channels()
431 flush_channels()
432
432
433 msg_id = KC.is_complete("a = 1")
433 msg_id = KC.is_complete("a = 1")
434 reply = KC.get_shell_msg(timeout=TIMEOUT)
434 reply = KC.get_shell_msg(timeout=TIMEOUT)
435 validate_message(reply, 'is_complete_reply', msg_id)
435 validate_message(reply, 'is_complete_reply', msg_id)
436
436
437 def test_history_range():
437 def test_history_range():
438 flush_channels()
438 flush_channels()
439
439
440 msg_id_exec = KC.execute(code='x=1', store_history = True)
440 msg_id_exec = KC.execute(code='x=1', store_history = True)
441 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
441 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
442
442
443 msg_id = KC.history(hist_access_type = 'range', raw = True, output = True, start = 1, stop = 2, session = 0)
443 msg_id = KC.history(hist_access_type = 'range', raw = True, output = True, start = 1, stop = 2, session = 0)
444 reply = KC.get_shell_msg(timeout=TIMEOUT)
444 reply = KC.get_shell_msg(timeout=TIMEOUT)
445 validate_message(reply, 'history_reply', msg_id)
445 validate_message(reply, 'history_reply', msg_id)
446 content = reply['content']
446 content = reply['content']
447 nt.assert_equal(len(content['history']), 1)
447 nt.assert_equal(len(content['history']), 1)
448
448
449 def test_history_tail():
449 def test_history_tail():
450 flush_channels()
450 flush_channels()
451
451
452 msg_id_exec = KC.execute(code='x=1', store_history = True)
452 msg_id_exec = KC.execute(code='x=1', store_history = True)
453 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
453 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
454
454
455 msg_id = KC.history(hist_access_type = 'tail', raw = True, output = True, n = 1, session = 0)
455 msg_id = KC.history(hist_access_type = 'tail', raw = True, output = True, n = 1, session = 0)
456 reply = KC.get_shell_msg(timeout=TIMEOUT)
456 reply = KC.get_shell_msg(timeout=TIMEOUT)
457 validate_message(reply, 'history_reply', msg_id)
457 validate_message(reply, 'history_reply', msg_id)
458 content = reply['content']
458 content = reply['content']
459 nt.assert_equal(len(content['history']), 1)
459 nt.assert_equal(len(content['history']), 1)
460
460
461 def test_history_search():
461 def test_history_search():
462 flush_channels()
462 flush_channels()
463
463
464 msg_id_exec = KC.execute(code='x=1', store_history = True)
464 msg_id_exec = KC.execute(code='x=1', store_history = True)
465 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
465 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
466
466
467 msg_id = KC.history(hist_access_type = 'search', raw = True, output = True, n = 1, pattern = '*', session = 0)
467 msg_id = KC.history(hist_access_type = 'search', raw = True, output = True, n = 1, pattern = '*', session = 0)
468 reply = KC.get_shell_msg(timeout=TIMEOUT)
468 reply = KC.get_shell_msg(timeout=TIMEOUT)
469 validate_message(reply, 'history_reply', msg_id)
469 validate_message(reply, 'history_reply', msg_id)
470 content = reply['content']
470 content = reply['content']
471 nt.assert_equal(len(content['history']), 1)
471 nt.assert_equal(len(content['history']), 1)
472
472
473 # IOPub channel
473 # IOPub channel
474
474
475
475
476 def test_stream():
476 def test_stream():
477 flush_channels()
477 flush_channels()
478
478
479 msg_id, reply = execute("print('hi')")
479 msg_id, reply = execute("print('hi')")
480
480
481 stdout = KC.iopub_channel.get_msg(timeout=TIMEOUT)
481 stdout = KC.iopub_channel.get_msg(timeout=TIMEOUT)
482 validate_message(stdout, 'stream', msg_id)
482 validate_message(stdout, 'stream', msg_id)
483 content = stdout['content']
483 content = stdout['content']
484 nt.assert_equal(content['text'], u'hi\n')
484 nt.assert_equal(content['text'], u'hi\n')
485
485
486
486
487 def test_display_data():
487 def test_display_data():
488 flush_channels()
488 flush_channels()
489
489
490 msg_id, reply = execute("from IPython.core.display import display; display(1)")
490 msg_id, reply = execute("from IPython.core.display import display; display(1)")
491
491
492 display = KC.iopub_channel.get_msg(timeout=TIMEOUT)
492 display = KC.iopub_channel.get_msg(timeout=TIMEOUT)
493 validate_message(display, 'display_data', parent=msg_id)
493 validate_message(display, 'display_data', parent=msg_id)
494 data = display['content']['data']
494 data = display['content']['data']
495 nt.assert_equal(data['text/plain'], u'1')
495 nt.assert_equal(data['text/plain'], u'1')
496
496
@@ -1,849 +1,849 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 # Copyright (c) IPython Development Team.
8 # Copyright (c) IPython Development Team.
9 # Distributed under the terms of the Modified BSD License.
9 # Distributed under the terms of the Modified BSD License.
10
10
11 import logging
11 import logging
12 import sys
12 import sys
13 import time
13 import time
14
14
15 from collections import deque
15 from collections import deque
16 from datetime import datetime
16 from datetime import datetime
17 from random import randint, random
17 from random import randint, random
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import numpy
21 import numpy
22 except ImportError:
22 except ImportError:
23 numpy = None
23 numpy = None
24
24
25 import zmq
25 import zmq
26 from zmq.eventloop import ioloop, zmqstream
26 from zmq.eventloop import ioloop, zmqstream
27
27
28 # local imports
28 # local imports
29 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
30 from IPython.config.application import Application
30 from IPython.config.application import Application
31 from IPython.config.loader import Config
31 from IPython.config.loader import Config
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 from IPython.utils.py3compat import cast_bytes
33 from IPython.utils.py3compat import cast_bytes
34
34
35 from IPython.parallel import error, util
35 from IPython.parallel import error, util
36 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.factory import SessionFactory
37 from IPython.parallel.util import connect_logger, local_logger
37 from IPython.parallel.util import connect_logger, local_logger
38
38
39 from .dependency import Dependency
39 from .dependency import Dependency
40
40
41 @decorator
41 @decorator
42 def logged(f,self,*args,**kwargs):
42 def logged(f,self,*args,**kwargs):
43 # print ("#--------------------")
43 # print ("#--------------------")
44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
45 # print ("#--")
45 # print ("#--")
46 return f(self,*args, **kwargs)
46 return f(self,*args, **kwargs)
47
47
48 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
49 # Chooser functions
49 # Chooser functions
50 #----------------------------------------------------------------------
50 #----------------------------------------------------------------------
51
51
52 def plainrandom(loads):
52 def plainrandom(loads):
53 """Plain random pick."""
53 """Plain random pick."""
54 n = len(loads)
54 n = len(loads)
55 return randint(0,n-1)
55 return randint(0,n-1)
56
56
57 def lru(loads):
57 def lru(loads):
58 """Always pick the front of the line.
58 """Always pick the front of the line.
59
59
60 The content of `loads` is ignored.
60 The content of `loads` is ignored.
61
61
62 Assumes LRU ordering of loads, with oldest first.
62 Assumes LRU ordering of loads, with oldest first.
63 """
63 """
64 return 0
64 return 0
65
65
66 def twobin(loads):
66 def twobin(loads):
67 """Pick two at random, use the LRU of the two.
67 """Pick two at random, use the LRU of the two.
68
68
69 The content of loads is ignored.
69 The content of loads is ignored.
70
70
71 Assumes LRU ordering of loads, with oldest first.
71 Assumes LRU ordering of loads, with oldest first.
72 """
72 """
73 n = len(loads)
73 n = len(loads)
74 a = randint(0,n-1)
74 a = randint(0,n-1)
75 b = randint(0,n-1)
75 b = randint(0,n-1)
76 return min(a,b)
76 return min(a,b)
77
77
78 def weighted(loads):
78 def weighted(loads):
79 """Pick two at random using inverse load as weight.
79 """Pick two at random using inverse load as weight.
80
80
81 Return the less loaded of the two.
81 Return the less loaded of the two.
82 """
82 """
83 # weight 0 a million times more than 1:
83 # weight 0 a million times more than 1:
84 weights = 1./(1e-6+numpy.array(loads))
84 weights = 1./(1e-6+numpy.array(loads))
85 sums = weights.cumsum()
85 sums = weights.cumsum()
86 t = sums[-1]
86 t = sums[-1]
87 x = random()*t
87 x = random()*t
88 y = random()*t
88 y = random()*t
89 idx = 0
89 idx = 0
90 idy = 0
90 idy = 0
91 while sums[idx] < x:
91 while sums[idx] < x:
92 idx += 1
92 idx += 1
93 while sums[idy] < y:
93 while sums[idy] < y:
94 idy += 1
94 idy += 1
95 if weights[idy] > weights[idx]:
95 if weights[idy] > weights[idx]:
96 return idy
96 return idy
97 else:
97 else:
98 return idx
98 return idx
99
99
100 def leastload(loads):
100 def leastload(loads):
101 """Always choose the lowest load.
101 """Always choose the lowest load.
102
102
103 If the lowest load occurs more than once, the first
103 If the lowest load occurs more than once, the first
104 occurance will be used. If loads has LRU ordering, this means
104 occurance will be used. If loads has LRU ordering, this means
105 the LRU of those with the lowest load is chosen.
105 the LRU of those with the lowest load is chosen.
106 """
106 """
107 return loads.index(min(loads))
107 return loads.index(min(loads))
108
108
109 #---------------------------------------------------------------------
109 #---------------------------------------------------------------------
110 # Classes
110 # Classes
111 #---------------------------------------------------------------------
111 #---------------------------------------------------------------------
112
112
113
113
114 # store empty default dependency:
114 # store empty default dependency:
115 MET = Dependency([])
115 MET = Dependency([])
116
116
117
117
118 class Job(object):
118 class Job(object):
119 """Simple container for a job"""
119 """Simple container for a job"""
120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
121 targets, after, follow, timeout):
121 targets, after, follow, timeout):
122 self.msg_id = msg_id
122 self.msg_id = msg_id
123 self.raw_msg = raw_msg
123 self.raw_msg = raw_msg
124 self.idents = idents
124 self.idents = idents
125 self.msg = msg
125 self.msg = msg
126 self.header = header
126 self.header = header
127 self.metadata = metadata
127 self.metadata = metadata
128 self.targets = targets
128 self.targets = targets
129 self.after = after
129 self.after = after
130 self.follow = follow
130 self.follow = follow
131 self.timeout = timeout
131 self.timeout = timeout
132
132
133 self.removed = False # used for lazy-delete from sorted queue
133 self.removed = False # used for lazy-delete from sorted queue
134 self.timestamp = time.time()
134 self.timestamp = time.time()
135 self.timeout_id = 0
135 self.timeout_id = 0
136 self.blacklist = set()
136 self.blacklist = set()
137
137
138 def __lt__(self, other):
138 def __lt__(self, other):
139 return self.timestamp < other.timestamp
139 return self.timestamp < other.timestamp
140
140
141 def __cmp__(self, other):
141 def __cmp__(self, other):
142 return cmp(self.timestamp, other.timestamp)
142 return cmp(self.timestamp, other.timestamp)
143
143
144 @property
144 @property
145 def dependents(self):
145 def dependents(self):
146 return self.follow.union(self.after)
146 return self.follow.union(self.after)
147
147
148
148
149 class TaskScheduler(SessionFactory):
149 class TaskScheduler(SessionFactory):
150 """Python TaskScheduler object.
150 """Python TaskScheduler object.
151
151
152 This is the simplest object that supports msg_id based
152 This is the simplest object that supports msg_id based
153 DAG dependencies. *Only* task msg_ids are checked, not
153 DAG dependencies. *Only* task msg_ids are checked, not
154 msg_ids of jobs submitted via the MUX queue.
154 msg_ids of jobs submitted via the MUX queue.
155
155
156 """
156 """
157
157
158 hwm = Integer(1, config=True,
158 hwm = Integer(1, config=True,
159 help="""specify the High Water Mark (HWM) for the downstream
159 help="""specify the High Water Mark (HWM) for the downstream
160 socket in the Task scheduler. This is the maximum number
160 socket in the Task scheduler. This is the maximum number
161 of allowed outstanding tasks on each engine.
161 of allowed outstanding tasks on each engine.
162
162
163 The default (1) means that only one task can be outstanding on each
163 The default (1) means that only one task can be outstanding on each
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 engines continue to be assigned tasks while they are working,
165 engines continue to be assigned tasks while they are working,
166 effectively hiding network latency behind computation, but can result
166 effectively hiding network latency behind computation, but can result
167 in an imbalance of work when submitting many heterogenous tasks all at
167 in an imbalance of work when submitting many heterogenous tasks all at
168 once. Any positive value greater than one is a compromise between the
168 once. Any positive value greater than one is a compromise between the
169 two.
169 two.
170
170
171 """
171 """
172 )
172 )
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 'leastload', config=True, allow_none=False,
174 'leastload', config=True,
175 help="""select the task scheduler scheme [default: Python LRU]
175 help="""select the task scheduler scheme [default: Python LRU]
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 )
177 )
178 def _scheme_name_changed(self, old, new):
178 def _scheme_name_changed(self, old, new):
179 self.log.debug("Using scheme %r"%new)
179 self.log.debug("Using scheme %r"%new)
180 self.scheme = globals()[new]
180 self.scheme = globals()[new]
181
181
182 # input arguments:
182 # input arguments:
183 scheme = Instance(FunctionType) # function for determining the destination
183 scheme = Instance(FunctionType) # function for determining the destination
184 def _scheme_default(self):
184 def _scheme_default(self):
185 return leastload
185 return leastload
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191
191
192 # internals:
192 # internals:
193 queue = Instance(deque) # sorted list of Jobs
193 queue = Instance(deque) # sorted list of Jobs
194 def _queue_default(self):
194 def _queue_default(self):
195 return deque()
195 return deque()
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 pending = Dict() # dict by engine_uuid of submitted tasks
200 pending = Dict() # dict by engine_uuid of submitted tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 clients = Dict() # dict by msg_id for who submitted the task
204 clients = Dict() # dict by msg_id for who submitted the task
205 targets = List() # list of target IDENTs
205 targets = List() # list of target IDENTs
206 loads = List() # list of engine loads
206 loads = List() # list of engine loads
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 all_completed = Set() # set of all completed tasks
208 all_completed = Set() # set of all completed tasks
209 all_failed = Set() # set of all failed tasks
209 all_failed = Set() # set of all failed tasks
210 all_done = Set() # set of all finished tasks=union(completed,failed)
210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 all_ids = Set() # set of all submitted task IDs
211 all_ids = Set() # set of all submitted task IDs
212
212
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 # but ensure Bytes
214 # but ensure Bytes
215 def _ident_default(self):
215 def _ident_default(self):
216 return self.session.bsession
216 return self.session.bsession
217
217
218 def start(self):
218 def start(self):
219 self.query_stream.on_recv(self.dispatch_query_reply)
219 self.query_stream.on_recv(self.dispatch_query_reply)
220 self.session.send(self.query_stream, "connection_request", {})
220 self.session.send(self.query_stream, "connection_request", {})
221
221
222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
224
224
225 self._notification_handlers = dict(
225 self._notification_handlers = dict(
226 registration_notification = self._register_engine,
226 registration_notification = self._register_engine,
227 unregistration_notification = self._unregister_engine
227 unregistration_notification = self._unregister_engine
228 )
228 )
229 self.notifier_stream.on_recv(self.dispatch_notification)
229 self.notifier_stream.on_recv(self.dispatch_notification)
230 self.log.info("Scheduler started [%s]" % self.scheme_name)
230 self.log.info("Scheduler started [%s]" % self.scheme_name)
231
231
232 def resume_receiving(self):
232 def resume_receiving(self):
233 """Resume accepting jobs."""
233 """Resume accepting jobs."""
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235
235
236 def stop_receiving(self):
236 def stop_receiving(self):
237 """Stop accepting jobs while there are no engines.
237 """Stop accepting jobs while there are no engines.
238 Leave them in the ZMQ queue."""
238 Leave them in the ZMQ queue."""
239 self.client_stream.on_recv(None)
239 self.client_stream.on_recv(None)
240
240
241 #-----------------------------------------------------------------------
241 #-----------------------------------------------------------------------
242 # [Un]Registration Handling
242 # [Un]Registration Handling
243 #-----------------------------------------------------------------------
243 #-----------------------------------------------------------------------
244
244
245
245
246 def dispatch_query_reply(self, msg):
246 def dispatch_query_reply(self, msg):
247 """handle reply to our initial connection request"""
247 """handle reply to our initial connection request"""
248 try:
248 try:
249 idents,msg = self.session.feed_identities(msg)
249 idents,msg = self.session.feed_identities(msg)
250 except ValueError:
250 except ValueError:
251 self.log.warn("task::Invalid Message: %r",msg)
251 self.log.warn("task::Invalid Message: %r",msg)
252 return
252 return
253 try:
253 try:
254 msg = self.session.deserialize(msg)
254 msg = self.session.deserialize(msg)
255 except ValueError:
255 except ValueError:
256 self.log.warn("task::Unauthorized message from: %r"%idents)
256 self.log.warn("task::Unauthorized message from: %r"%idents)
257 return
257 return
258
258
259 content = msg['content']
259 content = msg['content']
260 for uuid in content.get('engines', {}).values():
260 for uuid in content.get('engines', {}).values():
261 self._register_engine(cast_bytes(uuid))
261 self._register_engine(cast_bytes(uuid))
262
262
263
263
264 @util.log_errors
264 @util.log_errors
265 def dispatch_notification(self, msg):
265 def dispatch_notification(self, msg):
266 """dispatch register/unregister events."""
266 """dispatch register/unregister events."""
267 try:
267 try:
268 idents,msg = self.session.feed_identities(msg)
268 idents,msg = self.session.feed_identities(msg)
269 except ValueError:
269 except ValueError:
270 self.log.warn("task::Invalid Message: %r",msg)
270 self.log.warn("task::Invalid Message: %r",msg)
271 return
271 return
272 try:
272 try:
273 msg = self.session.deserialize(msg)
273 msg = self.session.deserialize(msg)
274 except ValueError:
274 except ValueError:
275 self.log.warn("task::Unauthorized message from: %r"%idents)
275 self.log.warn("task::Unauthorized message from: %r"%idents)
276 return
276 return
277
277
278 msg_type = msg['header']['msg_type']
278 msg_type = msg['header']['msg_type']
279
279
280 handler = self._notification_handlers.get(msg_type, None)
280 handler = self._notification_handlers.get(msg_type, None)
281 if handler is None:
281 if handler is None:
282 self.log.error("Unhandled message type: %r"%msg_type)
282 self.log.error("Unhandled message type: %r"%msg_type)
283 else:
283 else:
284 try:
284 try:
285 handler(cast_bytes(msg['content']['uuid']))
285 handler(cast_bytes(msg['content']['uuid']))
286 except Exception:
286 except Exception:
287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288
288
289 def _register_engine(self, uid):
289 def _register_engine(self, uid):
290 """New engine with ident `uid` became available."""
290 """New engine with ident `uid` became available."""
291 # head of the line:
291 # head of the line:
292 self.targets.insert(0,uid)
292 self.targets.insert(0,uid)
293 self.loads.insert(0,0)
293 self.loads.insert(0,0)
294
294
295 # initialize sets
295 # initialize sets
296 self.completed[uid] = set()
296 self.completed[uid] = set()
297 self.failed[uid] = set()
297 self.failed[uid] = set()
298 self.pending[uid] = {}
298 self.pending[uid] = {}
299
299
300 # rescan the graph:
300 # rescan the graph:
301 self.update_graph(None)
301 self.update_graph(None)
302
302
303 def _unregister_engine(self, uid):
303 def _unregister_engine(self, uid):
304 """Existing engine with ident `uid` became unavailable."""
304 """Existing engine with ident `uid` became unavailable."""
305 if len(self.targets) == 1:
305 if len(self.targets) == 1:
306 # this was our only engine
306 # this was our only engine
307 pass
307 pass
308
308
309 # handle any potentially finished tasks:
309 # handle any potentially finished tasks:
310 self.engine_stream.flush()
310 self.engine_stream.flush()
311
311
312 # don't pop destinations, because they might be used later
312 # don't pop destinations, because they might be used later
313 # map(self.destinations.pop, self.completed.pop(uid))
313 # map(self.destinations.pop, self.completed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
315
315
316 # prevent this engine from receiving work
316 # prevent this engine from receiving work
317 idx = self.targets.index(uid)
317 idx = self.targets.index(uid)
318 self.targets.pop(idx)
318 self.targets.pop(idx)
319 self.loads.pop(idx)
319 self.loads.pop(idx)
320
320
321 # wait 5 seconds before cleaning up pending jobs, since the results might
321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 # still be incoming
322 # still be incoming
323 if self.pending[uid]:
323 if self.pending[uid]:
324 self.loop.add_timeout(self.loop.time() + 5,
324 self.loop.add_timeout(self.loop.time() + 5,
325 lambda : self.handle_stranded_tasks(uid),
325 lambda : self.handle_stranded_tasks(uid),
326 )
326 )
327 else:
327 else:
328 self.completed.pop(uid)
328 self.completed.pop(uid)
329 self.failed.pop(uid)
329 self.failed.pop(uid)
330
330
331
331
332 def handle_stranded_tasks(self, engine):
332 def handle_stranded_tasks(self, engine):
333 """Deal with jobs resident in an engine that died."""
333 """Deal with jobs resident in an engine that died."""
334 lost = self.pending[engine]
334 lost = self.pending[engine]
335 for msg_id in lost.keys():
335 for msg_id in lost.keys():
336 if msg_id not in self.pending[engine]:
336 if msg_id not in self.pending[engine]:
337 # prevent double-handling of messages
337 # prevent double-handling of messages
338 continue
338 continue
339
339
340 raw_msg = lost[msg_id].raw_msg
340 raw_msg = lost[msg_id].raw_msg
341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
342 parent = self.session.unpack(msg[1].bytes)
342 parent = self.session.unpack(msg[1].bytes)
343 idents = [engine, idents[0]]
343 idents = [engine, idents[0]]
344
344
345 # build fake error reply
345 # build fake error reply
346 try:
346 try:
347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
348 except:
348 except:
349 content = error.wrap_exception()
349 content = error.wrap_exception()
350 # build fake metadata
350 # build fake metadata
351 md = dict(
351 md = dict(
352 status=u'error',
352 status=u'error',
353 engine=engine.decode('ascii'),
353 engine=engine.decode('ascii'),
354 date=datetime.now(),
354 date=datetime.now(),
355 )
355 )
356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
358 # and dispatch it
358 # and dispatch it
359 self.dispatch_result(raw_reply)
359 self.dispatch_result(raw_reply)
360
360
361 # finally scrub completed/failed lists
361 # finally scrub completed/failed lists
362 self.completed.pop(engine)
362 self.completed.pop(engine)
363 self.failed.pop(engine)
363 self.failed.pop(engine)
364
364
365
365
366 #-----------------------------------------------------------------------
366 #-----------------------------------------------------------------------
367 # Job Submission
367 # Job Submission
368 #-----------------------------------------------------------------------
368 #-----------------------------------------------------------------------
369
369
370
370
371 @util.log_errors
371 @util.log_errors
372 def dispatch_submission(self, raw_msg):
372 def dispatch_submission(self, raw_msg):
373 """Dispatch job submission to appropriate handlers."""
373 """Dispatch job submission to appropriate handlers."""
374 # ensure targets up to date:
374 # ensure targets up to date:
375 self.notifier_stream.flush()
375 self.notifier_stream.flush()
376 try:
376 try:
377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
378 msg = self.session.deserialize(msg, content=False, copy=False)
378 msg = self.session.deserialize(msg, content=False, copy=False)
379 except Exception:
379 except Exception:
380 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
380 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
381 return
381 return
382
382
383
383
384 # send to monitor
384 # send to monitor
385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
386
386
387 header = msg['header']
387 header = msg['header']
388 md = msg['metadata']
388 md = msg['metadata']
389 msg_id = header['msg_id']
389 msg_id = header['msg_id']
390 self.all_ids.add(msg_id)
390 self.all_ids.add(msg_id)
391
391
392 # get targets as a set of bytes objects
392 # get targets as a set of bytes objects
393 # from a list of unicode objects
393 # from a list of unicode objects
394 targets = md.get('targets', [])
394 targets = md.get('targets', [])
395 targets = set(map(cast_bytes, targets))
395 targets = set(map(cast_bytes, targets))
396
396
397 retries = md.get('retries', 0)
397 retries = md.get('retries', 0)
398 self.retries[msg_id] = retries
398 self.retries[msg_id] = retries
399
399
400 # time dependencies
400 # time dependencies
401 after = md.get('after', None)
401 after = md.get('after', None)
402 if after:
402 if after:
403 after = Dependency(after)
403 after = Dependency(after)
404 if after.all:
404 if after.all:
405 if after.success:
405 if after.success:
406 after = Dependency(after.difference(self.all_completed),
406 after = Dependency(after.difference(self.all_completed),
407 success=after.success,
407 success=after.success,
408 failure=after.failure,
408 failure=after.failure,
409 all=after.all,
409 all=after.all,
410 )
410 )
411 if after.failure:
411 if after.failure:
412 after = Dependency(after.difference(self.all_failed),
412 after = Dependency(after.difference(self.all_failed),
413 success=after.success,
413 success=after.success,
414 failure=after.failure,
414 failure=after.failure,
415 all=after.all,
415 all=after.all,
416 )
416 )
417 if after.check(self.all_completed, self.all_failed):
417 if after.check(self.all_completed, self.all_failed):
418 # recast as empty set, if `after` already met,
418 # recast as empty set, if `after` already met,
419 # to prevent unnecessary set comparisons
419 # to prevent unnecessary set comparisons
420 after = MET
420 after = MET
421 else:
421 else:
422 after = MET
422 after = MET
423
423
424 # location dependencies
424 # location dependencies
425 follow = Dependency(md.get('follow', []))
425 follow = Dependency(md.get('follow', []))
426
426
427 timeout = md.get('timeout', None)
427 timeout = md.get('timeout', None)
428 if timeout:
428 if timeout:
429 timeout = float(timeout)
429 timeout = float(timeout)
430
430
431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
432 header=header, targets=targets, after=after, follow=follow,
432 header=header, targets=targets, after=after, follow=follow,
433 timeout=timeout, metadata=md,
433 timeout=timeout, metadata=md,
434 )
434 )
435 # validate and reduce dependencies:
435 # validate and reduce dependencies:
436 for dep in after,follow:
436 for dep in after,follow:
437 if not dep: # empty dependency
437 if not dep: # empty dependency
438 continue
438 continue
439 # check valid:
439 # check valid:
440 if msg_id in dep or dep.difference(self.all_ids):
440 if msg_id in dep or dep.difference(self.all_ids):
441 self.queue_map[msg_id] = job
441 self.queue_map[msg_id] = job
442 return self.fail_unreachable(msg_id, error.InvalidDependency)
442 return self.fail_unreachable(msg_id, error.InvalidDependency)
443 # check if unreachable:
443 # check if unreachable:
444 if dep.unreachable(self.all_completed, self.all_failed):
444 if dep.unreachable(self.all_completed, self.all_failed):
445 self.queue_map[msg_id] = job
445 self.queue_map[msg_id] = job
446 return self.fail_unreachable(msg_id)
446 return self.fail_unreachable(msg_id)
447
447
448 if after.check(self.all_completed, self.all_failed):
448 if after.check(self.all_completed, self.all_failed):
449 # time deps already met, try to run
449 # time deps already met, try to run
450 if not self.maybe_run(job):
450 if not self.maybe_run(job):
451 # can't run yet
451 # can't run yet
452 if msg_id not in self.all_failed:
452 if msg_id not in self.all_failed:
453 # could have failed as unreachable
453 # could have failed as unreachable
454 self.save_unmet(job)
454 self.save_unmet(job)
455 else:
455 else:
456 self.save_unmet(job)
456 self.save_unmet(job)
457
457
458 def job_timeout(self, job, timeout_id):
458 def job_timeout(self, job, timeout_id):
459 """callback for a job's timeout.
459 """callback for a job's timeout.
460
460
461 The job may or may not have been run at this point.
461 The job may or may not have been run at this point.
462 """
462 """
463 if job.timeout_id != timeout_id:
463 if job.timeout_id != timeout_id:
464 # not the most recent call
464 # not the most recent call
465 return
465 return
466 now = time.time()
466 now = time.time()
467 if job.timeout >= (now + 1):
467 if job.timeout >= (now + 1):
468 self.log.warn("task %s timeout fired prematurely: %s > %s",
468 self.log.warn("task %s timeout fired prematurely: %s > %s",
469 job.msg_id, job.timeout, now
469 job.msg_id, job.timeout, now
470 )
470 )
471 if job.msg_id in self.queue_map:
471 if job.msg_id in self.queue_map:
472 # still waiting, but ran out of time
472 # still waiting, but ran out of time
473 self.log.info("task %r timed out", job.msg_id)
473 self.log.info("task %r timed out", job.msg_id)
474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
475
475
476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
477 """a task has become unreachable, send a reply with an ImpossibleDependency
477 """a task has become unreachable, send a reply with an ImpossibleDependency
478 error."""
478 error."""
479 if msg_id not in self.queue_map:
479 if msg_id not in self.queue_map:
480 self.log.error("task %r already failed!", msg_id)
480 self.log.error("task %r already failed!", msg_id)
481 return
481 return
482 job = self.queue_map.pop(msg_id)
482 job = self.queue_map.pop(msg_id)
483 # lazy-delete from the queue
483 # lazy-delete from the queue
484 job.removed = True
484 job.removed = True
485 for mid in job.dependents:
485 for mid in job.dependents:
486 if mid in self.graph:
486 if mid in self.graph:
487 self.graph[mid].remove(msg_id)
487 self.graph[mid].remove(msg_id)
488
488
489 try:
489 try:
490 raise why()
490 raise why()
491 except:
491 except:
492 content = error.wrap_exception()
492 content = error.wrap_exception()
493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
494
494
495 self.all_done.add(msg_id)
495 self.all_done.add(msg_id)
496 self.all_failed.add(msg_id)
496 self.all_failed.add(msg_id)
497
497
498 msg = self.session.send(self.client_stream, 'apply_reply', content,
498 msg = self.session.send(self.client_stream, 'apply_reply', content,
499 parent=job.header, ident=job.idents)
499 parent=job.header, ident=job.idents)
500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
501
501
502 self.update_graph(msg_id, success=False)
502 self.update_graph(msg_id, success=False)
503
503
504 def available_engines(self):
504 def available_engines(self):
505 """return a list of available engine indices based on HWM"""
505 """return a list of available engine indices based on HWM"""
506 if not self.hwm:
506 if not self.hwm:
507 return list(range(len(self.targets)))
507 return list(range(len(self.targets)))
508 available = []
508 available = []
509 for idx in range(len(self.targets)):
509 for idx in range(len(self.targets)):
510 if self.loads[idx] < self.hwm:
510 if self.loads[idx] < self.hwm:
511 available.append(idx)
511 available.append(idx)
512 return available
512 return available
513
513
514 def maybe_run(self, job):
514 def maybe_run(self, job):
515 """check location dependencies, and run if they are met."""
515 """check location dependencies, and run if they are met."""
516 msg_id = job.msg_id
516 msg_id = job.msg_id
517 self.log.debug("Attempting to assign task %s", msg_id)
517 self.log.debug("Attempting to assign task %s", msg_id)
518 available = self.available_engines()
518 available = self.available_engines()
519 if not available:
519 if not available:
520 # no engines, definitely can't run
520 # no engines, definitely can't run
521 return False
521 return False
522
522
523 if job.follow or job.targets or job.blacklist or self.hwm:
523 if job.follow or job.targets or job.blacklist or self.hwm:
524 # we need a can_run filter
524 # we need a can_run filter
525 def can_run(idx):
525 def can_run(idx):
526 # check hwm
526 # check hwm
527 if self.hwm and self.loads[idx] == self.hwm:
527 if self.hwm and self.loads[idx] == self.hwm:
528 return False
528 return False
529 target = self.targets[idx]
529 target = self.targets[idx]
530 # check blacklist
530 # check blacklist
531 if target in job.blacklist:
531 if target in job.blacklist:
532 return False
532 return False
533 # check targets
533 # check targets
534 if job.targets and target not in job.targets:
534 if job.targets and target not in job.targets:
535 return False
535 return False
536 # check follow
536 # check follow
537 return job.follow.check(self.completed[target], self.failed[target])
537 return job.follow.check(self.completed[target], self.failed[target])
538
538
539 indices = list(filter(can_run, available))
539 indices = list(filter(can_run, available))
540
540
541 if not indices:
541 if not indices:
542 # couldn't run
542 # couldn't run
543 if job.follow.all:
543 if job.follow.all:
544 # check follow for impossibility
544 # check follow for impossibility
545 dests = set()
545 dests = set()
546 relevant = set()
546 relevant = set()
547 if job.follow.success:
547 if job.follow.success:
548 relevant = self.all_completed
548 relevant = self.all_completed
549 if job.follow.failure:
549 if job.follow.failure:
550 relevant = relevant.union(self.all_failed)
550 relevant = relevant.union(self.all_failed)
551 for m in job.follow.intersection(relevant):
551 for m in job.follow.intersection(relevant):
552 dests.add(self.destinations[m])
552 dests.add(self.destinations[m])
553 if len(dests) > 1:
553 if len(dests) > 1:
554 self.queue_map[msg_id] = job
554 self.queue_map[msg_id] = job
555 self.fail_unreachable(msg_id)
555 self.fail_unreachable(msg_id)
556 return False
556 return False
557 if job.targets:
557 if job.targets:
558 # check blacklist+targets for impossibility
558 # check blacklist+targets for impossibility
559 job.targets.difference_update(job.blacklist)
559 job.targets.difference_update(job.blacklist)
560 if not job.targets or not job.targets.intersection(self.targets):
560 if not job.targets or not job.targets.intersection(self.targets):
561 self.queue_map[msg_id] = job
561 self.queue_map[msg_id] = job
562 self.fail_unreachable(msg_id)
562 self.fail_unreachable(msg_id)
563 return False
563 return False
564 return False
564 return False
565 else:
565 else:
566 indices = None
566 indices = None
567
567
568 self.submit_task(job, indices)
568 self.submit_task(job, indices)
569 return True
569 return True
570
570
571 def save_unmet(self, job):
571 def save_unmet(self, job):
572 """Save a message for later submission when its dependencies are met."""
572 """Save a message for later submission when its dependencies are met."""
573 msg_id = job.msg_id
573 msg_id = job.msg_id
574 self.log.debug("Adding task %s to the queue", msg_id)
574 self.log.debug("Adding task %s to the queue", msg_id)
575 self.queue_map[msg_id] = job
575 self.queue_map[msg_id] = job
576 self.queue.append(job)
576 self.queue.append(job)
577 # track the ids in follow or after, but not those already finished
577 # track the ids in follow or after, but not those already finished
578 for dep_id in job.after.union(job.follow).difference(self.all_done):
578 for dep_id in job.after.union(job.follow).difference(self.all_done):
579 if dep_id not in self.graph:
579 if dep_id not in self.graph:
580 self.graph[dep_id] = set()
580 self.graph[dep_id] = set()
581 self.graph[dep_id].add(msg_id)
581 self.graph[dep_id].add(msg_id)
582
582
583 # schedule timeout callback
583 # schedule timeout callback
584 if job.timeout:
584 if job.timeout:
585 timeout_id = job.timeout_id = job.timeout_id + 1
585 timeout_id = job.timeout_id = job.timeout_id + 1
586 self.loop.add_timeout(time.time() + job.timeout,
586 self.loop.add_timeout(time.time() + job.timeout,
587 lambda : self.job_timeout(job, timeout_id)
587 lambda : self.job_timeout(job, timeout_id)
588 )
588 )
589
589
590
590
591 def submit_task(self, job, indices=None):
591 def submit_task(self, job, indices=None):
592 """Submit a task to any of a subset of our targets."""
592 """Submit a task to any of a subset of our targets."""
593 if indices:
593 if indices:
594 loads = [self.loads[i] for i in indices]
594 loads = [self.loads[i] for i in indices]
595 else:
595 else:
596 loads = self.loads
596 loads = self.loads
597 idx = self.scheme(loads)
597 idx = self.scheme(loads)
598 if indices:
598 if indices:
599 idx = indices[idx]
599 idx = indices[idx]
600 target = self.targets[idx]
600 target = self.targets[idx]
601 # print (target, map(str, msg[:3]))
601 # print (target, map(str, msg[:3]))
602 # send job to the engine
602 # send job to the engine
603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 # update load
605 # update load
606 self.add_job(idx)
606 self.add_job(idx)
607 self.pending[target][job.msg_id] = job
607 self.pending[target][job.msg_id] = job
608 # notify Hub
608 # notify Hub
609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 self.session.send(self.mon_stream, 'task_destination', content=content,
610 self.session.send(self.mon_stream, 'task_destination', content=content,
611 ident=[b'tracktask',self.ident])
611 ident=[b'tracktask',self.ident])
612
612
613
613
614 #-----------------------------------------------------------------------
614 #-----------------------------------------------------------------------
615 # Result Handling
615 # Result Handling
616 #-----------------------------------------------------------------------
616 #-----------------------------------------------------------------------
617
617
618
618
619 @util.log_errors
619 @util.log_errors
620 def dispatch_result(self, raw_msg):
620 def dispatch_result(self, raw_msg):
621 """dispatch method for result replies"""
621 """dispatch method for result replies"""
622 try:
622 try:
623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 msg = self.session.deserialize(msg, content=False, copy=False)
624 msg = self.session.deserialize(msg, content=False, copy=False)
625 engine = idents[0]
625 engine = idents[0]
626 try:
626 try:
627 idx = self.targets.index(engine)
627 idx = self.targets.index(engine)
628 except ValueError:
628 except ValueError:
629 pass # skip load-update for dead engines
629 pass # skip load-update for dead engines
630 else:
630 else:
631 self.finish_job(idx)
631 self.finish_job(idx)
632 except Exception:
632 except Exception:
633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 return
634 return
635
635
636 md = msg['metadata']
636 md = msg['metadata']
637 parent = msg['parent_header']
637 parent = msg['parent_header']
638 if md.get('dependencies_met', True):
638 if md.get('dependencies_met', True):
639 success = (md['status'] == 'ok')
639 success = (md['status'] == 'ok')
640 msg_id = parent['msg_id']
640 msg_id = parent['msg_id']
641 retries = self.retries[msg_id]
641 retries = self.retries[msg_id]
642 if not success and retries > 0:
642 if not success and retries > 0:
643 # failed
643 # failed
644 self.retries[msg_id] = retries - 1
644 self.retries[msg_id] = retries - 1
645 self.handle_unmet_dependency(idents, parent)
645 self.handle_unmet_dependency(idents, parent)
646 else:
646 else:
647 del self.retries[msg_id]
647 del self.retries[msg_id]
648 # relay to client and update graph
648 # relay to client and update graph
649 self.handle_result(idents, parent, raw_msg, success)
649 self.handle_result(idents, parent, raw_msg, success)
650 # send to Hub monitor
650 # send to Hub monitor
651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 else:
652 else:
653 self.handle_unmet_dependency(idents, parent)
653 self.handle_unmet_dependency(idents, parent)
654
654
655 def handle_result(self, idents, parent, raw_msg, success=True):
655 def handle_result(self, idents, parent, raw_msg, success=True):
656 """handle a real task result, either success or failure"""
656 """handle a real task result, either success or failure"""
657 # first, relay result to client
657 # first, relay result to client
658 engine = idents[0]
658 engine = idents[0]
659 client = idents[1]
659 client = idents[1]
660 # swap_ids for ROUTER-ROUTER mirror
660 # swap_ids for ROUTER-ROUTER mirror
661 raw_msg[:2] = [client,engine]
661 raw_msg[:2] = [client,engine]
662 # print (map(str, raw_msg[:4]))
662 # print (map(str, raw_msg[:4]))
663 self.client_stream.send_multipart(raw_msg, copy=False)
663 self.client_stream.send_multipart(raw_msg, copy=False)
664 # now, update our data structures
664 # now, update our data structures
665 msg_id = parent['msg_id']
665 msg_id = parent['msg_id']
666 self.pending[engine].pop(msg_id)
666 self.pending[engine].pop(msg_id)
667 if success:
667 if success:
668 self.completed[engine].add(msg_id)
668 self.completed[engine].add(msg_id)
669 self.all_completed.add(msg_id)
669 self.all_completed.add(msg_id)
670 else:
670 else:
671 self.failed[engine].add(msg_id)
671 self.failed[engine].add(msg_id)
672 self.all_failed.add(msg_id)
672 self.all_failed.add(msg_id)
673 self.all_done.add(msg_id)
673 self.all_done.add(msg_id)
674 self.destinations[msg_id] = engine
674 self.destinations[msg_id] = engine
675
675
676 self.update_graph(msg_id, success)
676 self.update_graph(msg_id, success)
677
677
678 def handle_unmet_dependency(self, idents, parent):
678 def handle_unmet_dependency(self, idents, parent):
679 """handle an unmet dependency"""
679 """handle an unmet dependency"""
680 engine = idents[0]
680 engine = idents[0]
681 msg_id = parent['msg_id']
681 msg_id = parent['msg_id']
682
682
683 job = self.pending[engine].pop(msg_id)
683 job = self.pending[engine].pop(msg_id)
684 job.blacklist.add(engine)
684 job.blacklist.add(engine)
685
685
686 if job.blacklist == job.targets:
686 if job.blacklist == job.targets:
687 self.queue_map[msg_id] = job
687 self.queue_map[msg_id] = job
688 self.fail_unreachable(msg_id)
688 self.fail_unreachable(msg_id)
689 elif not self.maybe_run(job):
689 elif not self.maybe_run(job):
690 # resubmit failed
690 # resubmit failed
691 if msg_id not in self.all_failed:
691 if msg_id not in self.all_failed:
692 # put it back in our dependency tree
692 # put it back in our dependency tree
693 self.save_unmet(job)
693 self.save_unmet(job)
694
694
695 if self.hwm:
695 if self.hwm:
696 try:
696 try:
697 idx = self.targets.index(engine)
697 idx = self.targets.index(engine)
698 except ValueError:
698 except ValueError:
699 pass # skip load-update for dead engines
699 pass # skip load-update for dead engines
700 else:
700 else:
701 if self.loads[idx] == self.hwm-1:
701 if self.loads[idx] == self.hwm-1:
702 self.update_graph(None)
702 self.update_graph(None)
703
703
704 def update_graph(self, dep_id=None, success=True):
704 def update_graph(self, dep_id=None, success=True):
705 """dep_id just finished. Update our dependency
705 """dep_id just finished. Update our dependency
706 graph and submit any jobs that just became runnable.
706 graph and submit any jobs that just became runnable.
707
707
708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
709 """
709 """
710 # print ("\n\n***********")
710 # print ("\n\n***********")
711 # pprint (dep_id)
711 # pprint (dep_id)
712 # pprint (self.graph)
712 # pprint (self.graph)
713 # pprint (self.queue_map)
713 # pprint (self.queue_map)
714 # pprint (self.all_completed)
714 # pprint (self.all_completed)
715 # pprint (self.all_failed)
715 # pprint (self.all_failed)
716 # print ("\n\n***********\n\n")
716 # print ("\n\n***********\n\n")
717 # update any jobs that depended on the dependency
717 # update any jobs that depended on the dependency
718 msg_ids = self.graph.pop(dep_id, [])
718 msg_ids = self.graph.pop(dep_id, [])
719
719
720 # recheck *all* jobs if
720 # recheck *all* jobs if
721 # a) we have HWM and an engine just become no longer full
721 # a) we have HWM and an engine just become no longer full
722 # or b) dep_id was given as None
722 # or b) dep_id was given as None
723
723
724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
725 jobs = self.queue
725 jobs = self.queue
726 using_queue = True
726 using_queue = True
727 else:
727 else:
728 using_queue = False
728 using_queue = False
729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730
730
731 to_restore = []
731 to_restore = []
732 while jobs:
732 while jobs:
733 job = jobs.popleft()
733 job = jobs.popleft()
734 if job.removed:
734 if job.removed:
735 continue
735 continue
736 msg_id = job.msg_id
736 msg_id = job.msg_id
737
737
738 put_it_back = True
738 put_it_back = True
739
739
740 if job.after.unreachable(self.all_completed, self.all_failed)\
740 if job.after.unreachable(self.all_completed, self.all_failed)\
741 or job.follow.unreachable(self.all_completed, self.all_failed):
741 or job.follow.unreachable(self.all_completed, self.all_failed):
742 self.fail_unreachable(msg_id)
742 self.fail_unreachable(msg_id)
743 put_it_back = False
743 put_it_back = False
744
744
745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 if self.maybe_run(job):
746 if self.maybe_run(job):
747 put_it_back = False
747 put_it_back = False
748 self.queue_map.pop(msg_id)
748 self.queue_map.pop(msg_id)
749 for mid in job.dependents:
749 for mid in job.dependents:
750 if mid in self.graph:
750 if mid in self.graph:
751 self.graph[mid].remove(msg_id)
751 self.graph[mid].remove(msg_id)
752
752
753 # abort the loop if we just filled up all of our engines.
753 # abort the loop if we just filled up all of our engines.
754 # avoids an O(N) operation in situation of full queue,
754 # avoids an O(N) operation in situation of full queue,
755 # where graph update is triggered as soon as an engine becomes
755 # where graph update is triggered as soon as an engine becomes
756 # non-full, and all tasks after the first are checked,
756 # non-full, and all tasks after the first are checked,
757 # even though they can't run.
757 # even though they can't run.
758 if not self.available_engines():
758 if not self.available_engines():
759 break
759 break
760
760
761 if using_queue and put_it_back:
761 if using_queue and put_it_back:
762 # popped a job from the queue but it neither ran nor failed,
762 # popped a job from the queue but it neither ran nor failed,
763 # so we need to put it back when we are done
763 # so we need to put it back when we are done
764 # make sure to_restore preserves the same ordering
764 # make sure to_restore preserves the same ordering
765 to_restore.append(job)
765 to_restore.append(job)
766
766
767 # put back any tasks we popped but didn't run
767 # put back any tasks we popped but didn't run
768 if using_queue:
768 if using_queue:
769 self.queue.extendleft(to_restore)
769 self.queue.extendleft(to_restore)
770
770
771 #----------------------------------------------------------------------
771 #----------------------------------------------------------------------
772 # methods to be overridden by subclasses
772 # methods to be overridden by subclasses
773 #----------------------------------------------------------------------
773 #----------------------------------------------------------------------
774
774
775 def add_job(self, idx):
775 def add_job(self, idx):
776 """Called after self.targets[idx] just got the job with header.
776 """Called after self.targets[idx] just got the job with header.
777 Override with subclasses. The default ordering is simple LRU.
777 Override with subclasses. The default ordering is simple LRU.
778 The default loads are the number of outstanding jobs."""
778 The default loads are the number of outstanding jobs."""
779 self.loads[idx] += 1
779 self.loads[idx] += 1
780 for lis in (self.targets, self.loads):
780 for lis in (self.targets, self.loads):
781 lis.append(lis.pop(idx))
781 lis.append(lis.pop(idx))
782
782
783
783
784 def finish_job(self, idx):
784 def finish_job(self, idx):
785 """Called after self.targets[idx] just finished a job.
785 """Called after self.targets[idx] just finished a job.
786 Override with subclasses."""
786 Override with subclasses."""
787 self.loads[idx] -= 1
787 self.loads[idx] -= 1
788
788
789
789
790
790
791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 logname='root', log_url=None, loglevel=logging.DEBUG,
792 logname='root', log_url=None, loglevel=logging.DEBUG,
793 identity=b'task', in_thread=False):
793 identity=b'task', in_thread=False):
794
794
795 ZMQStream = zmqstream.ZMQStream
795 ZMQStream = zmqstream.ZMQStream
796
796
797 if config:
797 if config:
798 # unwrap dict back into Config
798 # unwrap dict back into Config
799 config = Config(config)
799 config = Config(config)
800
800
801 if in_thread:
801 if in_thread:
802 # use instance() to get the same Context/Loop as our parent
802 # use instance() to get the same Context/Loop as our parent
803 ctx = zmq.Context.instance()
803 ctx = zmq.Context.instance()
804 loop = ioloop.IOLoop.instance()
804 loop = ioloop.IOLoop.instance()
805 else:
805 else:
806 # in a process, don't use instance()
806 # in a process, don't use instance()
807 # for safety with multiprocessing
807 # for safety with multiprocessing
808 ctx = zmq.Context()
808 ctx = zmq.Context()
809 loop = ioloop.IOLoop()
809 loop = ioloop.IOLoop()
810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 util.set_hwm(ins, 0)
811 util.set_hwm(ins, 0)
812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
813 ins.bind(in_addr)
813 ins.bind(in_addr)
814
814
815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
816 util.set_hwm(outs, 0)
816 util.set_hwm(outs, 0)
817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
818 outs.bind(out_addr)
818 outs.bind(out_addr)
819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
820 util.set_hwm(mons, 0)
820 util.set_hwm(mons, 0)
821 mons.connect(mon_addr)
821 mons.connect(mon_addr)
822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
823 nots.setsockopt(zmq.SUBSCRIBE, b'')
823 nots.setsockopt(zmq.SUBSCRIBE, b'')
824 nots.connect(not_addr)
824 nots.connect(not_addr)
825
825
826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
827 querys.connect(reg_addr)
827 querys.connect(reg_addr)
828
828
829 # setup logging.
829 # setup logging.
830 if in_thread:
830 if in_thread:
831 log = Application.instance().log
831 log = Application.instance().log
832 else:
832 else:
833 if log_url:
833 if log_url:
834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
835 else:
835 else:
836 log = local_logger(logname, loglevel)
836 log = local_logger(logname, loglevel)
837
837
838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
839 mon_stream=mons, notifier_stream=nots,
839 mon_stream=mons, notifier_stream=nots,
840 query_stream=querys,
840 query_stream=querys,
841 loop=loop, log=log,
841 loop=loop, log=log,
842 config=config)
842 config=config)
843 scheduler.start()
843 scheduler.start()
844 if not in_thread:
844 if not in_thread:
845 try:
845 try:
846 loop.start()
846 loop.start()
847 except KeyboardInterrupt:
847 except KeyboardInterrupt:
848 scheduler.log.critical("Interrupted, exiting...")
848 scheduler.log.critical("Interrupted, exiting...")
849
849
@@ -1,578 +1,578 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """terminal client to the IPython kernel"""
2 """terminal client to the IPython kernel"""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 from __future__ import print_function
7 from __future__ import print_function
8
8
9 import base64
9 import base64
10 import bdb
10 import bdb
11 import signal
11 import signal
12 import os
12 import os
13 import sys
13 import sys
14 import time
14 import time
15 import subprocess
15 import subprocess
16 from getpass import getpass
16 from getpass import getpass
17 from io import BytesIO
17 from io import BytesIO
18
18
19 try:
19 try:
20 from queue import Empty # Py 3
20 from queue import Empty # Py 3
21 except ImportError:
21 except ImportError:
22 from Queue import Empty # Py 2
22 from Queue import Empty # Py 2
23
23
24 from IPython.core import page
24 from IPython.core import page
25 from IPython.core import release
25 from IPython.core import release
26 from IPython.terminal.console.zmqhistory import ZMQHistoryManager
26 from IPython.terminal.console.zmqhistory import ZMQHistoryManager
27 from IPython.utils.warn import warn, error
27 from IPython.utils.warn import warn, error
28 from IPython.utils import io
28 from IPython.utils import io
29 from IPython.utils.py3compat import string_types, input
29 from IPython.utils.py3compat import string_types, input
30 from IPython.utils.traitlets import List, Enum, Any, Instance, Unicode, Float, Bool
30 from IPython.utils.traitlets import List, Enum, Any, Instance, Unicode, Float, Bool
31 from IPython.utils.tempdir import NamedFileInTemporaryDirectory
31 from IPython.utils.tempdir import NamedFileInTemporaryDirectory
32
32
33 from IPython.terminal.interactiveshell import TerminalInteractiveShell
33 from IPython.terminal.interactiveshell import TerminalInteractiveShell
34 from IPython.terminal.console.completer import ZMQCompleter
34 from IPython.terminal.console.completer import ZMQCompleter
35
35
36 class ZMQTerminalInteractiveShell(TerminalInteractiveShell):
36 class ZMQTerminalInteractiveShell(TerminalInteractiveShell):
37 """A subclass of TerminalInteractiveShell that uses the 0MQ kernel"""
37 """A subclass of TerminalInteractiveShell that uses the 0MQ kernel"""
38 _executing = False
38 _executing = False
39 _execution_state = Unicode('')
39 _execution_state = Unicode('')
40 _pending_clearoutput = False
40 _pending_clearoutput = False
41 kernel_banner = Unicode('')
41 kernel_banner = Unicode('')
42 kernel_timeout = Float(60, config=True,
42 kernel_timeout = Float(60, config=True,
43 help="""Timeout for giving up on a kernel (in seconds).
43 help="""Timeout for giving up on a kernel (in seconds).
44
44
45 On first connect and restart, the console tests whether the
45 On first connect and restart, the console tests whether the
46 kernel is running and responsive by sending kernel_info_requests.
46 kernel is running and responsive by sending kernel_info_requests.
47 This sets the timeout in seconds for how long the kernel can take
47 This sets the timeout in seconds for how long the kernel can take
48 before being presumed dead.
48 before being presumed dead.
49 """
49 """
50 )
50 )
51
51
52 image_handler = Enum(('PIL', 'stream', 'tempfile', 'callable'),
52 image_handler = Enum(('PIL', 'stream', 'tempfile', 'callable'),
53 config=True, help=
53 config=True, allow_none=True, help=
54 """
54 """
55 Handler for image type output. This is useful, for example,
55 Handler for image type output. This is useful, for example,
56 when connecting to the kernel in which pylab inline backend is
56 when connecting to the kernel in which pylab inline backend is
57 activated. There are four handlers defined. 'PIL': Use
57 activated. There are four handlers defined. 'PIL': Use
58 Python Imaging Library to popup image; 'stream': Use an
58 Python Imaging Library to popup image; 'stream': Use an
59 external program to show the image. Image will be fed into
59 external program to show the image. Image will be fed into
60 the STDIN of the program. You will need to configure
60 the STDIN of the program. You will need to configure
61 `stream_image_handler`; 'tempfile': Use an external program to
61 `stream_image_handler`; 'tempfile': Use an external program to
62 show the image. Image will be saved in a temporally file and
62 show the image. Image will be saved in a temporally file and
63 the program is called with the temporally file. You will need
63 the program is called with the temporally file. You will need
64 to configure `tempfile_image_handler`; 'callable': You can set
64 to configure `tempfile_image_handler`; 'callable': You can set
65 any Python callable which is called with the image data. You
65 any Python callable which is called with the image data. You
66 will need to configure `callable_image_handler`.
66 will need to configure `callable_image_handler`.
67 """
67 """
68 )
68 )
69
69
70 stream_image_handler = List(config=True, help=
70 stream_image_handler = List(config=True, help=
71 """
71 """
72 Command to invoke an image viewer program when you are using
72 Command to invoke an image viewer program when you are using
73 'stream' image handler. This option is a list of string where
73 'stream' image handler. This option is a list of string where
74 the first element is the command itself and reminders are the
74 the first element is the command itself and reminders are the
75 options for the command. Raw image data is given as STDIN to
75 options for the command. Raw image data is given as STDIN to
76 the program.
76 the program.
77 """
77 """
78 )
78 )
79
79
80 tempfile_image_handler = List(config=True, help=
80 tempfile_image_handler = List(config=True, help=
81 """
81 """
82 Command to invoke an image viewer program when you are using
82 Command to invoke an image viewer program when you are using
83 'tempfile' image handler. This option is a list of string
83 'tempfile' image handler. This option is a list of string
84 where the first element is the command itself and reminders
84 where the first element is the command itself and reminders
85 are the options for the command. You can use {file} and
85 are the options for the command. You can use {file} and
86 {format} in the string to represent the location of the
86 {format} in the string to represent the location of the
87 generated image file and image format.
87 generated image file and image format.
88 """
88 """
89 )
89 )
90
90
91 callable_image_handler = Any(config=True, help=
91 callable_image_handler = Any(config=True, help=
92 """
92 """
93 Callable object called via 'callable' image handler with one
93 Callable object called via 'callable' image handler with one
94 argument, `data`, which is `msg["content"]["data"]` where
94 argument, `data`, which is `msg["content"]["data"]` where
95 `msg` is the message from iopub channel. For exmaple, you can
95 `msg` is the message from iopub channel. For exmaple, you can
96 find base64 encoded PNG data as `data['image/png']`.
96 find base64 encoded PNG data as `data['image/png']`.
97 """
97 """
98 )
98 )
99
99
100 mime_preference = List(
100 mime_preference = List(
101 default_value=['image/png', 'image/jpeg', 'image/svg+xml'],
101 default_value=['image/png', 'image/jpeg', 'image/svg+xml'],
102 config=True, allow_none=False, help=
102 config=True, help=
103 """
103 """
104 Preferred object representation MIME type in order. First
104 Preferred object representation MIME type in order. First
105 matched MIME type will be used.
105 matched MIME type will be used.
106 """
106 """
107 )
107 )
108
108
109 manager = Instance('IPython.kernel.KernelManager')
109 manager = Instance('IPython.kernel.KernelManager')
110 client = Instance('IPython.kernel.KernelClient')
110 client = Instance('IPython.kernel.KernelClient')
111 def _client_changed(self, name, old, new):
111 def _client_changed(self, name, old, new):
112 self.session_id = new.session.session
112 self.session_id = new.session.session
113 session_id = Unicode()
113 session_id = Unicode()
114
114
115 def init_completer(self):
115 def init_completer(self):
116 """Initialize the completion machinery.
116 """Initialize the completion machinery.
117
117
118 This creates completion machinery that can be used by client code,
118 This creates completion machinery that can be used by client code,
119 either interactively in-process (typically triggered by the readline
119 either interactively in-process (typically triggered by the readline
120 library), programmatically (such as in test suites) or out-of-process
120 library), programmatically (such as in test suites) or out-of-process
121 (typically over the network by remote frontends).
121 (typically over the network by remote frontends).
122 """
122 """
123 from IPython.core.completerlib import (module_completer,
123 from IPython.core.completerlib import (module_completer,
124 magic_run_completer, cd_completer)
124 magic_run_completer, cd_completer)
125
125
126 self.Completer = ZMQCompleter(self, self.client, config=self.config)
126 self.Completer = ZMQCompleter(self, self.client, config=self.config)
127
127
128
128
129 self.set_hook('complete_command', module_completer, str_key = 'import')
129 self.set_hook('complete_command', module_completer, str_key = 'import')
130 self.set_hook('complete_command', module_completer, str_key = 'from')
130 self.set_hook('complete_command', module_completer, str_key = 'from')
131 self.set_hook('complete_command', magic_run_completer, str_key = '%run')
131 self.set_hook('complete_command', magic_run_completer, str_key = '%run')
132 self.set_hook('complete_command', cd_completer, str_key = '%cd')
132 self.set_hook('complete_command', cd_completer, str_key = '%cd')
133
133
134 # Only configure readline if we truly are using readline. IPython can
134 # Only configure readline if we truly are using readline. IPython can
135 # do tab-completion over the network, in GUIs, etc, where readline
135 # do tab-completion over the network, in GUIs, etc, where readline
136 # itself may be absent
136 # itself may be absent
137 if self.has_readline:
137 if self.has_readline:
138 self.set_readline_completer()
138 self.set_readline_completer()
139
139
140 def run_cell(self, cell, store_history=True):
140 def run_cell(self, cell, store_history=True):
141 """Run a complete IPython cell.
141 """Run a complete IPython cell.
142
142
143 Parameters
143 Parameters
144 ----------
144 ----------
145 cell : str
145 cell : str
146 The code (including IPython code such as %magic functions) to run.
146 The code (including IPython code such as %magic functions) to run.
147 store_history : bool
147 store_history : bool
148 If True, the raw and translated cell will be stored in IPython's
148 If True, the raw and translated cell will be stored in IPython's
149 history. For user code calling back into IPython's machinery, this
149 history. For user code calling back into IPython's machinery, this
150 should be set to False.
150 should be set to False.
151 """
151 """
152 if (not cell) or cell.isspace():
152 if (not cell) or cell.isspace():
153 # pressing enter flushes any pending display
153 # pressing enter flushes any pending display
154 self.handle_iopub()
154 self.handle_iopub()
155 return
155 return
156
156
157 # flush stale replies, which could have been ignored, due to missed heartbeats
157 # flush stale replies, which could have been ignored, due to missed heartbeats
158 while self.client.shell_channel.msg_ready():
158 while self.client.shell_channel.msg_ready():
159 self.client.shell_channel.get_msg()
159 self.client.shell_channel.get_msg()
160 # execute takes 'hidden', which is the inverse of store_hist
160 # execute takes 'hidden', which is the inverse of store_hist
161 msg_id = self.client.execute(cell, not store_history)
161 msg_id = self.client.execute(cell, not store_history)
162
162
163 # first thing is wait for any side effects (output, stdin, etc.)
163 # first thing is wait for any side effects (output, stdin, etc.)
164 self._executing = True
164 self._executing = True
165 self._execution_state = "busy"
165 self._execution_state = "busy"
166 while self._execution_state != 'idle' and self.client.is_alive():
166 while self._execution_state != 'idle' and self.client.is_alive():
167 try:
167 try:
168 self.handle_input_request(msg_id, timeout=0.05)
168 self.handle_input_request(msg_id, timeout=0.05)
169 except Empty:
169 except Empty:
170 # display intermediate print statements, etc.
170 # display intermediate print statements, etc.
171 self.handle_iopub(msg_id)
171 self.handle_iopub(msg_id)
172
172
173 # after all of that is done, wait for the execute reply
173 # after all of that is done, wait for the execute reply
174 while self.client.is_alive():
174 while self.client.is_alive():
175 try:
175 try:
176 self.handle_execute_reply(msg_id, timeout=0.05)
176 self.handle_execute_reply(msg_id, timeout=0.05)
177 except Empty:
177 except Empty:
178 pass
178 pass
179 else:
179 else:
180 break
180 break
181 self._executing = False
181 self._executing = False
182
182
183 #-----------------
183 #-----------------
184 # message handlers
184 # message handlers
185 #-----------------
185 #-----------------
186
186
187 def handle_execute_reply(self, msg_id, timeout=None):
187 def handle_execute_reply(self, msg_id, timeout=None):
188 msg = self.client.shell_channel.get_msg(block=False, timeout=timeout)
188 msg = self.client.shell_channel.get_msg(block=False, timeout=timeout)
189 if msg["parent_header"].get("msg_id", None) == msg_id:
189 if msg["parent_header"].get("msg_id", None) == msg_id:
190
190
191 self.handle_iopub(msg_id)
191 self.handle_iopub(msg_id)
192
192
193 content = msg["content"]
193 content = msg["content"]
194 status = content['status']
194 status = content['status']
195
195
196 if status == 'aborted':
196 if status == 'aborted':
197 self.write('Aborted\n')
197 self.write('Aborted\n')
198 return
198 return
199 elif status == 'ok':
199 elif status == 'ok':
200 # handle payloads
200 # handle payloads
201 for item in content["payload"]:
201 for item in content["payload"]:
202 source = item['source']
202 source = item['source']
203 if source == 'page':
203 if source == 'page':
204 page.page(item['data']['text/plain'])
204 page.page(item['data']['text/plain'])
205 elif source == 'set_next_input':
205 elif source == 'set_next_input':
206 self.set_next_input(item['text'])
206 self.set_next_input(item['text'])
207 elif source == 'ask_exit':
207 elif source == 'ask_exit':
208 self.ask_exit()
208 self.ask_exit()
209
209
210 elif status == 'error':
210 elif status == 'error':
211 for frame in content["traceback"]:
211 for frame in content["traceback"]:
212 print(frame, file=io.stderr)
212 print(frame, file=io.stderr)
213
213
214 self.execution_count = int(content["execution_count"] + 1)
214 self.execution_count = int(content["execution_count"] + 1)
215
215
216 include_other_output = Bool(False, config=True,
216 include_other_output = Bool(False, config=True,
217 help="""Whether to include output from clients
217 help="""Whether to include output from clients
218 other than this one sharing the same kernel.
218 other than this one sharing the same kernel.
219
219
220 Outputs are not displayed until enter is pressed.
220 Outputs are not displayed until enter is pressed.
221 """
221 """
222 )
222 )
223 other_output_prefix = Unicode("[remote] ", config=True,
223 other_output_prefix = Unicode("[remote] ", config=True,
224 help="""Prefix to add to outputs coming from clients other than this one.
224 help="""Prefix to add to outputs coming from clients other than this one.
225
225
226 Only relevant if include_other_output is True.
226 Only relevant if include_other_output is True.
227 """
227 """
228 )
228 )
229
229
230 def from_here(self, msg):
230 def from_here(self, msg):
231 """Return whether a message is from this session"""
231 """Return whether a message is from this session"""
232 return msg['parent_header'].get("session", self.session_id) == self.session_id
232 return msg['parent_header'].get("session", self.session_id) == self.session_id
233
233
234 def include_output(self, msg):
234 def include_output(self, msg):
235 """Return whether we should include a given output message"""
235 """Return whether we should include a given output message"""
236 from_here = self.from_here(msg)
236 from_here = self.from_here(msg)
237 if msg['msg_type'] == 'execute_input':
237 if msg['msg_type'] == 'execute_input':
238 # only echo inputs not from here
238 # only echo inputs not from here
239 return self.include_other_output and not from_here
239 return self.include_other_output and not from_here
240
240
241 if self.include_other_output:
241 if self.include_other_output:
242 return True
242 return True
243 else:
243 else:
244 return from_here
244 return from_here
245
245
246 def handle_iopub(self, msg_id=''):
246 def handle_iopub(self, msg_id=''):
247 """Process messages on the IOPub channel
247 """Process messages on the IOPub channel
248
248
249 This method consumes and processes messages on the IOPub channel,
249 This method consumes and processes messages on the IOPub channel,
250 such as stdout, stderr, execute_result and status.
250 such as stdout, stderr, execute_result and status.
251
251
252 It only displays output that is caused by this session.
252 It only displays output that is caused by this session.
253 """
253 """
254 while self.client.iopub_channel.msg_ready():
254 while self.client.iopub_channel.msg_ready():
255 sub_msg = self.client.iopub_channel.get_msg()
255 sub_msg = self.client.iopub_channel.get_msg()
256 msg_type = sub_msg['header']['msg_type']
256 msg_type = sub_msg['header']['msg_type']
257 parent = sub_msg["parent_header"]
257 parent = sub_msg["parent_header"]
258
258
259 if self.include_output(sub_msg):
259 if self.include_output(sub_msg):
260 if msg_type == 'status':
260 if msg_type == 'status':
261 self._execution_state = sub_msg["content"]["execution_state"]
261 self._execution_state = sub_msg["content"]["execution_state"]
262 elif msg_type == 'stream':
262 elif msg_type == 'stream':
263 if sub_msg["content"]["name"] == "stdout":
263 if sub_msg["content"]["name"] == "stdout":
264 if self._pending_clearoutput:
264 if self._pending_clearoutput:
265 print("\r", file=io.stdout, end="")
265 print("\r", file=io.stdout, end="")
266 self._pending_clearoutput = False
266 self._pending_clearoutput = False
267 print(sub_msg["content"]["text"], file=io.stdout, end="")
267 print(sub_msg["content"]["text"], file=io.stdout, end="")
268 io.stdout.flush()
268 io.stdout.flush()
269 elif sub_msg["content"]["name"] == "stderr":
269 elif sub_msg["content"]["name"] == "stderr":
270 if self._pending_clearoutput:
270 if self._pending_clearoutput:
271 print("\r", file=io.stderr, end="")
271 print("\r", file=io.stderr, end="")
272 self._pending_clearoutput = False
272 self._pending_clearoutput = False
273 print(sub_msg["content"]["text"], file=io.stderr, end="")
273 print(sub_msg["content"]["text"], file=io.stderr, end="")
274 io.stderr.flush()
274 io.stderr.flush()
275
275
276 elif msg_type == 'execute_result':
276 elif msg_type == 'execute_result':
277 if self._pending_clearoutput:
277 if self._pending_clearoutput:
278 print("\r", file=io.stdout, end="")
278 print("\r", file=io.stdout, end="")
279 self._pending_clearoutput = False
279 self._pending_clearoutput = False
280 self.execution_count = int(sub_msg["content"]["execution_count"])
280 self.execution_count = int(sub_msg["content"]["execution_count"])
281 if not self.from_here(sub_msg):
281 if not self.from_here(sub_msg):
282 sys.stdout.write(self.other_output_prefix)
282 sys.stdout.write(self.other_output_prefix)
283 format_dict = sub_msg["content"]["data"]
283 format_dict = sub_msg["content"]["data"]
284 self.handle_rich_data(format_dict)
284 self.handle_rich_data(format_dict)
285
285
286 # taken from DisplayHook.__call__:
286 # taken from DisplayHook.__call__:
287 hook = self.displayhook
287 hook = self.displayhook
288 hook.start_displayhook()
288 hook.start_displayhook()
289 hook.write_output_prompt()
289 hook.write_output_prompt()
290 hook.write_format_data(format_dict)
290 hook.write_format_data(format_dict)
291 hook.log_output(format_dict)
291 hook.log_output(format_dict)
292 hook.finish_displayhook()
292 hook.finish_displayhook()
293
293
294 elif msg_type == 'display_data':
294 elif msg_type == 'display_data':
295 data = sub_msg["content"]["data"]
295 data = sub_msg["content"]["data"]
296 handled = self.handle_rich_data(data)
296 handled = self.handle_rich_data(data)
297 if not handled:
297 if not handled:
298 if not self.from_here(sub_msg):
298 if not self.from_here(sub_msg):
299 sys.stdout.write(self.other_output_prefix)
299 sys.stdout.write(self.other_output_prefix)
300 # if it was an image, we handled it by now
300 # if it was an image, we handled it by now
301 if 'text/plain' in data:
301 if 'text/plain' in data:
302 print(data['text/plain'])
302 print(data['text/plain'])
303
303
304 elif msg_type == 'execute_input':
304 elif msg_type == 'execute_input':
305 content = sub_msg['content']
305 content = sub_msg['content']
306 self.execution_count = content['execution_count']
306 self.execution_count = content['execution_count']
307 if not self.from_here(sub_msg):
307 if not self.from_here(sub_msg):
308 sys.stdout.write(self.other_output_prefix)
308 sys.stdout.write(self.other_output_prefix)
309 sys.stdout.write(self.prompt_manager.render('in'))
309 sys.stdout.write(self.prompt_manager.render('in'))
310 sys.stdout.write(content['code'])
310 sys.stdout.write(content['code'])
311
311
312 elif msg_type == 'clear_output':
312 elif msg_type == 'clear_output':
313 if sub_msg["content"]["wait"]:
313 if sub_msg["content"]["wait"]:
314 self._pending_clearoutput = True
314 self._pending_clearoutput = True
315 else:
315 else:
316 print("\r", file=io.stdout, end="")
316 print("\r", file=io.stdout, end="")
317
317
318 _imagemime = {
318 _imagemime = {
319 'image/png': 'png',
319 'image/png': 'png',
320 'image/jpeg': 'jpeg',
320 'image/jpeg': 'jpeg',
321 'image/svg+xml': 'svg',
321 'image/svg+xml': 'svg',
322 }
322 }
323
323
324 def handle_rich_data(self, data):
324 def handle_rich_data(self, data):
325 for mime in self.mime_preference:
325 for mime in self.mime_preference:
326 if mime in data and mime in self._imagemime:
326 if mime in data and mime in self._imagemime:
327 self.handle_image(data, mime)
327 self.handle_image(data, mime)
328 return True
328 return True
329
329
330 def handle_image(self, data, mime):
330 def handle_image(self, data, mime):
331 handler = getattr(
331 handler = getattr(
332 self, 'handle_image_{0}'.format(self.image_handler), None)
332 self, 'handle_image_{0}'.format(self.image_handler), None)
333 if handler:
333 if handler:
334 handler(data, mime)
334 handler(data, mime)
335
335
336 def handle_image_PIL(self, data, mime):
336 def handle_image_PIL(self, data, mime):
337 if mime not in ('image/png', 'image/jpeg'):
337 if mime not in ('image/png', 'image/jpeg'):
338 return
338 return
339 import PIL.Image
339 import PIL.Image
340 raw = base64.decodestring(data[mime].encode('ascii'))
340 raw = base64.decodestring(data[mime].encode('ascii'))
341 img = PIL.Image.open(BytesIO(raw))
341 img = PIL.Image.open(BytesIO(raw))
342 img.show()
342 img.show()
343
343
344 def handle_image_stream(self, data, mime):
344 def handle_image_stream(self, data, mime):
345 raw = base64.decodestring(data[mime].encode('ascii'))
345 raw = base64.decodestring(data[mime].encode('ascii'))
346 imageformat = self._imagemime[mime]
346 imageformat = self._imagemime[mime]
347 fmt = dict(format=imageformat)
347 fmt = dict(format=imageformat)
348 args = [s.format(**fmt) for s in self.stream_image_handler]
348 args = [s.format(**fmt) for s in self.stream_image_handler]
349 with open(os.devnull, 'w') as devnull:
349 with open(os.devnull, 'w') as devnull:
350 proc = subprocess.Popen(
350 proc = subprocess.Popen(
351 args, stdin=subprocess.PIPE,
351 args, stdin=subprocess.PIPE,
352 stdout=devnull, stderr=devnull)
352 stdout=devnull, stderr=devnull)
353 proc.communicate(raw)
353 proc.communicate(raw)
354
354
355 def handle_image_tempfile(self, data, mime):
355 def handle_image_tempfile(self, data, mime):
356 raw = base64.decodestring(data[mime].encode('ascii'))
356 raw = base64.decodestring(data[mime].encode('ascii'))
357 imageformat = self._imagemime[mime]
357 imageformat = self._imagemime[mime]
358 filename = 'tmp.{0}'.format(imageformat)
358 filename = 'tmp.{0}'.format(imageformat)
359 with NamedFileInTemporaryDirectory(filename) as f, \
359 with NamedFileInTemporaryDirectory(filename) as f, \
360 open(os.devnull, 'w') as devnull:
360 open(os.devnull, 'w') as devnull:
361 f.write(raw)
361 f.write(raw)
362 f.flush()
362 f.flush()
363 fmt = dict(file=f.name, format=imageformat)
363 fmt = dict(file=f.name, format=imageformat)
364 args = [s.format(**fmt) for s in self.tempfile_image_handler]
364 args = [s.format(**fmt) for s in self.tempfile_image_handler]
365 subprocess.call(args, stdout=devnull, stderr=devnull)
365 subprocess.call(args, stdout=devnull, stderr=devnull)
366
366
367 def handle_image_callable(self, data, mime):
367 def handle_image_callable(self, data, mime):
368 self.callable_image_handler(data)
368 self.callable_image_handler(data)
369
369
370 def handle_input_request(self, msg_id, timeout=0.1):
370 def handle_input_request(self, msg_id, timeout=0.1):
371 """ Method to capture raw_input
371 """ Method to capture raw_input
372 """
372 """
373 req = self.client.stdin_channel.get_msg(timeout=timeout)
373 req = self.client.stdin_channel.get_msg(timeout=timeout)
374 # in case any iopub came while we were waiting:
374 # in case any iopub came while we were waiting:
375 self.handle_iopub(msg_id)
375 self.handle_iopub(msg_id)
376 if msg_id == req["parent_header"].get("msg_id"):
376 if msg_id == req["parent_header"].get("msg_id"):
377 # wrap SIGINT handler
377 # wrap SIGINT handler
378 real_handler = signal.getsignal(signal.SIGINT)
378 real_handler = signal.getsignal(signal.SIGINT)
379 def double_int(sig,frame):
379 def double_int(sig,frame):
380 # call real handler (forwards sigint to kernel),
380 # call real handler (forwards sigint to kernel),
381 # then raise local interrupt, stopping local raw_input
381 # then raise local interrupt, stopping local raw_input
382 real_handler(sig,frame)
382 real_handler(sig,frame)
383 raise KeyboardInterrupt
383 raise KeyboardInterrupt
384 signal.signal(signal.SIGINT, double_int)
384 signal.signal(signal.SIGINT, double_int)
385 content = req['content']
385 content = req['content']
386 read = getpass if content.get('password', False) else input
386 read = getpass if content.get('password', False) else input
387 try:
387 try:
388 raw_data = read(content["prompt"])
388 raw_data = read(content["prompt"])
389 except EOFError:
389 except EOFError:
390 # turn EOFError into EOF character
390 # turn EOFError into EOF character
391 raw_data = '\x04'
391 raw_data = '\x04'
392 except KeyboardInterrupt:
392 except KeyboardInterrupt:
393 sys.stdout.write('\n')
393 sys.stdout.write('\n')
394 return
394 return
395 finally:
395 finally:
396 # restore SIGINT handler
396 # restore SIGINT handler
397 signal.signal(signal.SIGINT, real_handler)
397 signal.signal(signal.SIGINT, real_handler)
398
398
399 # only send stdin reply if there *was not* another request
399 # only send stdin reply if there *was not* another request
400 # or execution finished while we were reading.
400 # or execution finished while we were reading.
401 if not (self.client.stdin_channel.msg_ready() or self.client.shell_channel.msg_ready()):
401 if not (self.client.stdin_channel.msg_ready() or self.client.shell_channel.msg_ready()):
402 self.client.input(raw_data)
402 self.client.input(raw_data)
403
403
404 def mainloop(self, display_banner=False):
404 def mainloop(self, display_banner=False):
405 while True:
405 while True:
406 try:
406 try:
407 self.interact(display_banner=display_banner)
407 self.interact(display_banner=display_banner)
408 #self.interact_with_readline()
408 #self.interact_with_readline()
409 # XXX for testing of a readline-decoupled repl loop, call
409 # XXX for testing of a readline-decoupled repl loop, call
410 # interact_with_readline above
410 # interact_with_readline above
411 break
411 break
412 except KeyboardInterrupt:
412 except KeyboardInterrupt:
413 # this should not be necessary, but KeyboardInterrupt
413 # this should not be necessary, but KeyboardInterrupt
414 # handling seems rather unpredictable...
414 # handling seems rather unpredictable...
415 self.write("\nKeyboardInterrupt in interact()\n")
415 self.write("\nKeyboardInterrupt in interact()\n")
416
416
417 self.client.shutdown()
417 self.client.shutdown()
418
418
419 def _banner1_default(self):
419 def _banner1_default(self):
420 return "IPython Console {version}\n".format(version=release.version)
420 return "IPython Console {version}\n".format(version=release.version)
421
421
422 def compute_banner(self):
422 def compute_banner(self):
423 super(ZMQTerminalInteractiveShell, self).compute_banner()
423 super(ZMQTerminalInteractiveShell, self).compute_banner()
424 if self.client and not self.kernel_banner:
424 if self.client and not self.kernel_banner:
425 msg_id = self.client.kernel_info()
425 msg_id = self.client.kernel_info()
426 while True:
426 while True:
427 try:
427 try:
428 reply = self.client.get_shell_msg(timeout=1)
428 reply = self.client.get_shell_msg(timeout=1)
429 except Empty:
429 except Empty:
430 break
430 break
431 else:
431 else:
432 if reply['parent_header'].get('msg_id') == msg_id:
432 if reply['parent_header'].get('msg_id') == msg_id:
433 self.kernel_banner = reply['content'].get('banner', '')
433 self.kernel_banner = reply['content'].get('banner', '')
434 break
434 break
435 self.banner += self.kernel_banner
435 self.banner += self.kernel_banner
436
436
437 def wait_for_kernel(self, timeout=None):
437 def wait_for_kernel(self, timeout=None):
438 """method to wait for a kernel to be ready"""
438 """method to wait for a kernel to be ready"""
439 tic = time.time()
439 tic = time.time()
440 self.client.hb_channel.unpause()
440 self.client.hb_channel.unpause()
441 while True:
441 while True:
442 msg_id = self.client.kernel_info()
442 msg_id = self.client.kernel_info()
443 reply = None
443 reply = None
444 while True:
444 while True:
445 try:
445 try:
446 reply = self.client.get_shell_msg(timeout=1)
446 reply = self.client.get_shell_msg(timeout=1)
447 except Empty:
447 except Empty:
448 break
448 break
449 else:
449 else:
450 if reply['parent_header'].get('msg_id') == msg_id:
450 if reply['parent_header'].get('msg_id') == msg_id:
451 return True
451 return True
452 if timeout is not None \
452 if timeout is not None \
453 and (time.time() - tic) > timeout \
453 and (time.time() - tic) > timeout \
454 and not self.client.hb_channel.is_beating():
454 and not self.client.hb_channel.is_beating():
455 # heart failed
455 # heart failed
456 return False
456 return False
457 return True
457 return True
458
458
459 def interact(self, display_banner=None):
459 def interact(self, display_banner=None):
460 """Closely emulate the interactive Python console."""
460 """Closely emulate the interactive Python console."""
461
461
462 # batch run -> do not interact
462 # batch run -> do not interact
463 if self.exit_now:
463 if self.exit_now:
464 return
464 return
465
465
466 if display_banner is None:
466 if display_banner is None:
467 display_banner = self.display_banner
467 display_banner = self.display_banner
468
468
469 if isinstance(display_banner, string_types):
469 if isinstance(display_banner, string_types):
470 self.show_banner(display_banner)
470 self.show_banner(display_banner)
471 elif display_banner:
471 elif display_banner:
472 self.show_banner()
472 self.show_banner()
473
473
474 more = False
474 more = False
475
475
476 # run a non-empty no-op, so that we don't get a prompt until
476 # run a non-empty no-op, so that we don't get a prompt until
477 # we know the kernel is ready. This keeps the connection
477 # we know the kernel is ready. This keeps the connection
478 # message above the first prompt.
478 # message above the first prompt.
479 if not self.wait_for_kernel(self.kernel_timeout):
479 if not self.wait_for_kernel(self.kernel_timeout):
480 error("Kernel did not respond\n")
480 error("Kernel did not respond\n")
481 return
481 return
482
482
483 if self.has_readline:
483 if self.has_readline:
484 self.readline_startup_hook(self.pre_readline)
484 self.readline_startup_hook(self.pre_readline)
485 hlen_b4_cell = self.readline.get_current_history_length()
485 hlen_b4_cell = self.readline.get_current_history_length()
486 else:
486 else:
487 hlen_b4_cell = 0
487 hlen_b4_cell = 0
488 # exit_now is set by a call to %Exit or %Quit, through the
488 # exit_now is set by a call to %Exit or %Quit, through the
489 # ask_exit callback.
489 # ask_exit callback.
490
490
491 while not self.exit_now:
491 while not self.exit_now:
492 if not self.client.is_alive():
492 if not self.client.is_alive():
493 # kernel died, prompt for action or exit
493 # kernel died, prompt for action or exit
494
494
495 action = "restart" if self.manager else "wait for restart"
495 action = "restart" if self.manager else "wait for restart"
496 ans = self.ask_yes_no("kernel died, %s ([y]/n)?" % action, default='y')
496 ans = self.ask_yes_no("kernel died, %s ([y]/n)?" % action, default='y')
497 if ans:
497 if ans:
498 if self.manager:
498 if self.manager:
499 self.manager.restart_kernel(True)
499 self.manager.restart_kernel(True)
500 self.wait_for_kernel(self.kernel_timeout)
500 self.wait_for_kernel(self.kernel_timeout)
501 else:
501 else:
502 self.exit_now = True
502 self.exit_now = True
503 continue
503 continue
504 try:
504 try:
505 # protect prompt block from KeyboardInterrupt
505 # protect prompt block from KeyboardInterrupt
506 # when sitting on ctrl-C
506 # when sitting on ctrl-C
507 self.hooks.pre_prompt_hook()
507 self.hooks.pre_prompt_hook()
508 if more:
508 if more:
509 try:
509 try:
510 prompt = self.prompt_manager.render('in2')
510 prompt = self.prompt_manager.render('in2')
511 except Exception:
511 except Exception:
512 self.showtraceback()
512 self.showtraceback()
513 if self.autoindent:
513 if self.autoindent:
514 self.rl_do_indent = True
514 self.rl_do_indent = True
515
515
516 else:
516 else:
517 try:
517 try:
518 prompt = self.separate_in + self.prompt_manager.render('in')
518 prompt = self.separate_in + self.prompt_manager.render('in')
519 except Exception:
519 except Exception:
520 self.showtraceback()
520 self.showtraceback()
521
521
522 line = self.raw_input(prompt)
522 line = self.raw_input(prompt)
523 if self.exit_now:
523 if self.exit_now:
524 # quick exit on sys.std[in|out] close
524 # quick exit on sys.std[in|out] close
525 break
525 break
526 if self.autoindent:
526 if self.autoindent:
527 self.rl_do_indent = False
527 self.rl_do_indent = False
528
528
529 except KeyboardInterrupt:
529 except KeyboardInterrupt:
530 #double-guard against keyboardinterrupts during kbdint handling
530 #double-guard against keyboardinterrupts during kbdint handling
531 try:
531 try:
532 self.write('\n' + self.get_exception_only())
532 self.write('\n' + self.get_exception_only())
533 source_raw = self.input_splitter.raw_reset()
533 source_raw = self.input_splitter.raw_reset()
534 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
534 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
535 more = False
535 more = False
536 except KeyboardInterrupt:
536 except KeyboardInterrupt:
537 pass
537 pass
538 except EOFError:
538 except EOFError:
539 if self.autoindent:
539 if self.autoindent:
540 self.rl_do_indent = False
540 self.rl_do_indent = False
541 if self.has_readline:
541 if self.has_readline:
542 self.readline_startup_hook(None)
542 self.readline_startup_hook(None)
543 self.write('\n')
543 self.write('\n')
544 self.exit()
544 self.exit()
545 except bdb.BdbQuit:
545 except bdb.BdbQuit:
546 warn('The Python debugger has exited with a BdbQuit exception.\n'
546 warn('The Python debugger has exited with a BdbQuit exception.\n'
547 'Because of how pdb handles the stack, it is impossible\n'
547 'Because of how pdb handles the stack, it is impossible\n'
548 'for IPython to properly format this particular exception.\n'
548 'for IPython to properly format this particular exception.\n'
549 'IPython will resume normal operation.')
549 'IPython will resume normal operation.')
550 except:
550 except:
551 # exceptions here are VERY RARE, but they can be triggered
551 # exceptions here are VERY RARE, but they can be triggered
552 # asynchronously by signal handlers, for example.
552 # asynchronously by signal handlers, for example.
553 self.showtraceback()
553 self.showtraceback()
554 else:
554 else:
555 try:
555 try:
556 self.input_splitter.push(line)
556 self.input_splitter.push(line)
557 more = self.input_splitter.push_accepts_more()
557 more = self.input_splitter.push_accepts_more()
558 except SyntaxError:
558 except SyntaxError:
559 # Run the code directly - run_cell takes care of displaying
559 # Run the code directly - run_cell takes care of displaying
560 # the exception.
560 # the exception.
561 more = False
561 more = False
562 if (self.SyntaxTB.last_syntax_error and
562 if (self.SyntaxTB.last_syntax_error and
563 self.autoedit_syntax):
563 self.autoedit_syntax):
564 self.edit_syntax_error()
564 self.edit_syntax_error()
565 if not more:
565 if not more:
566 source_raw = self.input_splitter.raw_reset()
566 source_raw = self.input_splitter.raw_reset()
567 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
567 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
568 self.run_cell(source_raw)
568 self.run_cell(source_raw)
569
569
570
570
571 # Turn off the exit flag, so the mainloop can be restarted if desired
571 # Turn off the exit flag, so the mainloop can be restarted if desired
572 self.exit_now = False
572 self.exit_now = False
573
573
574 def init_history(self):
574 def init_history(self):
575 """Sets up the command history. """
575 """Sets up the command history. """
576 self.history_manager = ZMQHistoryManager(client=self.client)
576 self.history_manager = ZMQHistoryManager(client=self.client)
577 self.configurables.append(self.history_manager)
577 self.configurables.append(self.history_manager)
578
578
@@ -1,1468 +1,1468 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.traitlets."""
2 """Tests for IPython.utils.traitlets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6 #
6 #
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
8 # also under the terms of the Modified BSD License.
8 # also under the terms of the Modified BSD License.
9
9
10 import pickle
10 import pickle
11 import re
11 import re
12 import sys
12 import sys
13 from unittest import TestCase
13 from unittest import TestCase
14
14
15 import nose.tools as nt
15 import nose.tools as nt
16 from nose import SkipTest
16 from nose import SkipTest
17
17
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
19 HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
21 Union, Undefined, Type, This, Instance, TCPAddress, List, Tuple,
21 Union, Undefined, Type, This, Instance, TCPAddress, List, Tuple,
22 ObjectName, DottedObjectName, CRegExp, link, directional_link,
22 ObjectName, DottedObjectName, CRegExp, link, directional_link,
23 EventfulList, EventfulDict, ForwardDeclaredType, ForwardDeclaredInstance,
23 EventfulList, EventfulDict, ForwardDeclaredType, ForwardDeclaredInstance,
24 )
24 )
25 from IPython.utils import py3compat
25 from IPython.utils import py3compat
26 from IPython.testing.decorators import skipif
26 from IPython.testing.decorators import skipif
27
27
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29 # Helper classes for testing
29 # Helper classes for testing
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31
31
32
32
33 class HasTraitsStub(HasTraits):
33 class HasTraitsStub(HasTraits):
34
34
35 def _notify_trait(self, name, old, new):
35 def _notify_trait(self, name, old, new):
36 self._notify_name = name
36 self._notify_name = name
37 self._notify_old = old
37 self._notify_old = old
38 self._notify_new = new
38 self._notify_new = new
39
39
40
40
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42 # Test classes
42 # Test classes
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45
45
46 class TestTraitType(TestCase):
46 class TestTraitType(TestCase):
47
47
48 def test_get_undefined(self):
48 def test_get_undefined(self):
49 class A(HasTraits):
49 class A(HasTraits):
50 a = TraitType
50 a = TraitType
51 a = A()
51 a = A()
52 self.assertEqual(a.a, Undefined)
52 self.assertEqual(a.a, Undefined)
53
53
54 def test_set(self):
54 def test_set(self):
55 class A(HasTraitsStub):
55 class A(HasTraitsStub):
56 a = TraitType
56 a = TraitType
57
57
58 a = A()
58 a = A()
59 a.a = 10
59 a.a = 10
60 self.assertEqual(a.a, 10)
60 self.assertEqual(a.a, 10)
61 self.assertEqual(a._notify_name, 'a')
61 self.assertEqual(a._notify_name, 'a')
62 self.assertEqual(a._notify_old, Undefined)
62 self.assertEqual(a._notify_old, Undefined)
63 self.assertEqual(a._notify_new, 10)
63 self.assertEqual(a._notify_new, 10)
64
64
65 def test_validate(self):
65 def test_validate(self):
66 class MyTT(TraitType):
66 class MyTT(TraitType):
67 def validate(self, inst, value):
67 def validate(self, inst, value):
68 return -1
68 return -1
69 class A(HasTraitsStub):
69 class A(HasTraitsStub):
70 tt = MyTT
70 tt = MyTT
71
71
72 a = A()
72 a = A()
73 a.tt = 10
73 a.tt = 10
74 self.assertEqual(a.tt, -1)
74 self.assertEqual(a.tt, -1)
75
75
76 def test_default_validate(self):
76 def test_default_validate(self):
77 class MyIntTT(TraitType):
77 class MyIntTT(TraitType):
78 def validate(self, obj, value):
78 def validate(self, obj, value):
79 if isinstance(value, int):
79 if isinstance(value, int):
80 return value
80 return value
81 self.error(obj, value)
81 self.error(obj, value)
82 class A(HasTraits):
82 class A(HasTraits):
83 tt = MyIntTT(10)
83 tt = MyIntTT(10)
84 a = A()
84 a = A()
85 self.assertEqual(a.tt, 10)
85 self.assertEqual(a.tt, 10)
86
86
87 # Defaults are validated when the HasTraits is instantiated
87 # Defaults are validated when the HasTraits is instantiated
88 class B(HasTraits):
88 class B(HasTraits):
89 tt = MyIntTT('bad default')
89 tt = MyIntTT('bad default')
90 self.assertRaises(TraitError, B)
90 self.assertRaises(TraitError, B)
91
91
92 def test_info(self):
92 def test_info(self):
93 class A(HasTraits):
93 class A(HasTraits):
94 tt = TraitType
94 tt = TraitType
95 a = A()
95 a = A()
96 self.assertEqual(A.tt.info(), 'any value')
96 self.assertEqual(A.tt.info(), 'any value')
97
97
98 def test_error(self):
98 def test_error(self):
99 class A(HasTraits):
99 class A(HasTraits):
100 tt = TraitType
100 tt = TraitType
101 a = A()
101 a = A()
102 self.assertRaises(TraitError, A.tt.error, a, 10)
102 self.assertRaises(TraitError, A.tt.error, a, 10)
103
103
104 def test_dynamic_initializer(self):
104 def test_dynamic_initializer(self):
105 class A(HasTraits):
105 class A(HasTraits):
106 x = Int(10)
106 x = Int(10)
107 def _x_default(self):
107 def _x_default(self):
108 return 11
108 return 11
109 class B(A):
109 class B(A):
110 x = Int(20)
110 x = Int(20)
111 class C(A):
111 class C(A):
112 def _x_default(self):
112 def _x_default(self):
113 return 21
113 return 21
114
114
115 a = A()
115 a = A()
116 self.assertEqual(a._trait_values, {})
116 self.assertEqual(a._trait_values, {})
117 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
117 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
118 self.assertEqual(a.x, 11)
118 self.assertEqual(a.x, 11)
119 self.assertEqual(a._trait_values, {'x': 11})
119 self.assertEqual(a._trait_values, {'x': 11})
120 b = B()
120 b = B()
121 self.assertEqual(b._trait_values, {'x': 20})
121 self.assertEqual(b._trait_values, {'x': 20})
122 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
122 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
123 self.assertEqual(b.x, 20)
123 self.assertEqual(b.x, 20)
124 c = C()
124 c = C()
125 self.assertEqual(c._trait_values, {})
125 self.assertEqual(c._trait_values, {})
126 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
126 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
127 self.assertEqual(c.x, 21)
127 self.assertEqual(c.x, 21)
128 self.assertEqual(c._trait_values, {'x': 21})
128 self.assertEqual(c._trait_values, {'x': 21})
129 # Ensure that the base class remains unmolested when the _default
129 # Ensure that the base class remains unmolested when the _default
130 # initializer gets overridden in a subclass.
130 # initializer gets overridden in a subclass.
131 a = A()
131 a = A()
132 c = C()
132 c = C()
133 self.assertEqual(a._trait_values, {})
133 self.assertEqual(a._trait_values, {})
134 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
134 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
135 self.assertEqual(a.x, 11)
135 self.assertEqual(a.x, 11)
136 self.assertEqual(a._trait_values, {'x': 11})
136 self.assertEqual(a._trait_values, {'x': 11})
137
137
138
138
139
139
140 class TestHasTraitsMeta(TestCase):
140 class TestHasTraitsMeta(TestCase):
141
141
142 def test_metaclass(self):
142 def test_metaclass(self):
143 self.assertEqual(type(HasTraits), MetaHasTraits)
143 self.assertEqual(type(HasTraits), MetaHasTraits)
144
144
145 class A(HasTraits):
145 class A(HasTraits):
146 a = Int
146 a = Int
147
147
148 a = A()
148 a = A()
149 self.assertEqual(type(a.__class__), MetaHasTraits)
149 self.assertEqual(type(a.__class__), MetaHasTraits)
150 self.assertEqual(a.a,0)
150 self.assertEqual(a.a,0)
151 a.a = 10
151 a.a = 10
152 self.assertEqual(a.a,10)
152 self.assertEqual(a.a,10)
153
153
154 class B(HasTraits):
154 class B(HasTraits):
155 b = Int()
155 b = Int()
156
156
157 b = B()
157 b = B()
158 self.assertEqual(b.b,0)
158 self.assertEqual(b.b,0)
159 b.b = 10
159 b.b = 10
160 self.assertEqual(b.b,10)
160 self.assertEqual(b.b,10)
161
161
162 class C(HasTraits):
162 class C(HasTraits):
163 c = Int(30)
163 c = Int(30)
164
164
165 c = C()
165 c = C()
166 self.assertEqual(c.c,30)
166 self.assertEqual(c.c,30)
167 c.c = 10
167 c.c = 10
168 self.assertEqual(c.c,10)
168 self.assertEqual(c.c,10)
169
169
170 def test_this_class(self):
170 def test_this_class(self):
171 class A(HasTraits):
171 class A(HasTraits):
172 t = This()
172 t = This()
173 tt = This()
173 tt = This()
174 class B(A):
174 class B(A):
175 tt = This()
175 tt = This()
176 ttt = This()
176 ttt = This()
177 self.assertEqual(A.t.this_class, A)
177 self.assertEqual(A.t.this_class, A)
178 self.assertEqual(B.t.this_class, A)
178 self.assertEqual(B.t.this_class, A)
179 self.assertEqual(B.tt.this_class, B)
179 self.assertEqual(B.tt.this_class, B)
180 self.assertEqual(B.ttt.this_class, B)
180 self.assertEqual(B.ttt.this_class, B)
181
181
182 class TestHasTraitsNotify(TestCase):
182 class TestHasTraitsNotify(TestCase):
183
183
184 def setUp(self):
184 def setUp(self):
185 self._notify1 = []
185 self._notify1 = []
186 self._notify2 = []
186 self._notify2 = []
187
187
188 def notify1(self, name, old, new):
188 def notify1(self, name, old, new):
189 self._notify1.append((name, old, new))
189 self._notify1.append((name, old, new))
190
190
191 def notify2(self, name, old, new):
191 def notify2(self, name, old, new):
192 self._notify2.append((name, old, new))
192 self._notify2.append((name, old, new))
193
193
194 def test_notify_all(self):
194 def test_notify_all(self):
195
195
196 class A(HasTraits):
196 class A(HasTraits):
197 a = Int
197 a = Int
198 b = Float
198 b = Float
199
199
200 a = A()
200 a = A()
201 a.on_trait_change(self.notify1)
201 a.on_trait_change(self.notify1)
202 a.a = 0
202 a.a = 0
203 self.assertEqual(len(self._notify1),0)
203 self.assertEqual(len(self._notify1),0)
204 a.b = 0.0
204 a.b = 0.0
205 self.assertEqual(len(self._notify1),0)
205 self.assertEqual(len(self._notify1),0)
206 a.a = 10
206 a.a = 10
207 self.assertTrue(('a',0,10) in self._notify1)
207 self.assertTrue(('a',0,10) in self._notify1)
208 a.b = 10.0
208 a.b = 10.0
209 self.assertTrue(('b',0.0,10.0) in self._notify1)
209 self.assertTrue(('b',0.0,10.0) in self._notify1)
210 self.assertRaises(TraitError,setattr,a,'a','bad string')
210 self.assertRaises(TraitError,setattr,a,'a','bad string')
211 self.assertRaises(TraitError,setattr,a,'b','bad string')
211 self.assertRaises(TraitError,setattr,a,'b','bad string')
212 self._notify1 = []
212 self._notify1 = []
213 a.on_trait_change(self.notify1,remove=True)
213 a.on_trait_change(self.notify1,remove=True)
214 a.a = 20
214 a.a = 20
215 a.b = 20.0
215 a.b = 20.0
216 self.assertEqual(len(self._notify1),0)
216 self.assertEqual(len(self._notify1),0)
217
217
218 def test_notify_one(self):
218 def test_notify_one(self):
219
219
220 class A(HasTraits):
220 class A(HasTraits):
221 a = Int
221 a = Int
222 b = Float
222 b = Float
223
223
224 a = A()
224 a = A()
225 a.on_trait_change(self.notify1, 'a')
225 a.on_trait_change(self.notify1, 'a')
226 a.a = 0
226 a.a = 0
227 self.assertEqual(len(self._notify1),0)
227 self.assertEqual(len(self._notify1),0)
228 a.a = 10
228 a.a = 10
229 self.assertTrue(('a',0,10) in self._notify1)
229 self.assertTrue(('a',0,10) in self._notify1)
230 self.assertRaises(TraitError,setattr,a,'a','bad string')
230 self.assertRaises(TraitError,setattr,a,'a','bad string')
231
231
232 def test_subclass(self):
232 def test_subclass(self):
233
233
234 class A(HasTraits):
234 class A(HasTraits):
235 a = Int
235 a = Int
236
236
237 class B(A):
237 class B(A):
238 b = Float
238 b = Float
239
239
240 b = B()
240 b = B()
241 self.assertEqual(b.a,0)
241 self.assertEqual(b.a,0)
242 self.assertEqual(b.b,0.0)
242 self.assertEqual(b.b,0.0)
243 b.a = 100
243 b.a = 100
244 b.b = 100.0
244 b.b = 100.0
245 self.assertEqual(b.a,100)
245 self.assertEqual(b.a,100)
246 self.assertEqual(b.b,100.0)
246 self.assertEqual(b.b,100.0)
247
247
248 def test_notify_subclass(self):
248 def test_notify_subclass(self):
249
249
250 class A(HasTraits):
250 class A(HasTraits):
251 a = Int
251 a = Int
252
252
253 class B(A):
253 class B(A):
254 b = Float
254 b = Float
255
255
256 b = B()
256 b = B()
257 b.on_trait_change(self.notify1, 'a')
257 b.on_trait_change(self.notify1, 'a')
258 b.on_trait_change(self.notify2, 'b')
258 b.on_trait_change(self.notify2, 'b')
259 b.a = 0
259 b.a = 0
260 b.b = 0.0
260 b.b = 0.0
261 self.assertEqual(len(self._notify1),0)
261 self.assertEqual(len(self._notify1),0)
262 self.assertEqual(len(self._notify2),0)
262 self.assertEqual(len(self._notify2),0)
263 b.a = 10
263 b.a = 10
264 b.b = 10.0
264 b.b = 10.0
265 self.assertTrue(('a',0,10) in self._notify1)
265 self.assertTrue(('a',0,10) in self._notify1)
266 self.assertTrue(('b',0.0,10.0) in self._notify2)
266 self.assertTrue(('b',0.0,10.0) in self._notify2)
267
267
268 def test_static_notify(self):
268 def test_static_notify(self):
269
269
270 class A(HasTraits):
270 class A(HasTraits):
271 a = Int
271 a = Int
272 _notify1 = []
272 _notify1 = []
273 def _a_changed(self, name, old, new):
273 def _a_changed(self, name, old, new):
274 self._notify1.append((name, old, new))
274 self._notify1.append((name, old, new))
275
275
276 a = A()
276 a = A()
277 a.a = 0
277 a.a = 0
278 # This is broken!!!
278 # This is broken!!!
279 self.assertEqual(len(a._notify1),0)
279 self.assertEqual(len(a._notify1),0)
280 a.a = 10
280 a.a = 10
281 self.assertTrue(('a',0,10) in a._notify1)
281 self.assertTrue(('a',0,10) in a._notify1)
282
282
283 class B(A):
283 class B(A):
284 b = Float
284 b = Float
285 _notify2 = []
285 _notify2 = []
286 def _b_changed(self, name, old, new):
286 def _b_changed(self, name, old, new):
287 self._notify2.append((name, old, new))
287 self._notify2.append((name, old, new))
288
288
289 b = B()
289 b = B()
290 b.a = 10
290 b.a = 10
291 b.b = 10.0
291 b.b = 10.0
292 self.assertTrue(('a',0,10) in b._notify1)
292 self.assertTrue(('a',0,10) in b._notify1)
293 self.assertTrue(('b',0.0,10.0) in b._notify2)
293 self.assertTrue(('b',0.0,10.0) in b._notify2)
294
294
295 def test_notify_args(self):
295 def test_notify_args(self):
296
296
297 def callback0():
297 def callback0():
298 self.cb = ()
298 self.cb = ()
299 def callback1(name):
299 def callback1(name):
300 self.cb = (name,)
300 self.cb = (name,)
301 def callback2(name, new):
301 def callback2(name, new):
302 self.cb = (name, new)
302 self.cb = (name, new)
303 def callback3(name, old, new):
303 def callback3(name, old, new):
304 self.cb = (name, old, new)
304 self.cb = (name, old, new)
305
305
306 class A(HasTraits):
306 class A(HasTraits):
307 a = Int
307 a = Int
308
308
309 a = A()
309 a = A()
310 a.on_trait_change(callback0, 'a')
310 a.on_trait_change(callback0, 'a')
311 a.a = 10
311 a.a = 10
312 self.assertEqual(self.cb,())
312 self.assertEqual(self.cb,())
313 a.on_trait_change(callback0, 'a', remove=True)
313 a.on_trait_change(callback0, 'a', remove=True)
314
314
315 a.on_trait_change(callback1, 'a')
315 a.on_trait_change(callback1, 'a')
316 a.a = 100
316 a.a = 100
317 self.assertEqual(self.cb,('a',))
317 self.assertEqual(self.cb,('a',))
318 a.on_trait_change(callback1, 'a', remove=True)
318 a.on_trait_change(callback1, 'a', remove=True)
319
319
320 a.on_trait_change(callback2, 'a')
320 a.on_trait_change(callback2, 'a')
321 a.a = 1000
321 a.a = 1000
322 self.assertEqual(self.cb,('a',1000))
322 self.assertEqual(self.cb,('a',1000))
323 a.on_trait_change(callback2, 'a', remove=True)
323 a.on_trait_change(callback2, 'a', remove=True)
324
324
325 a.on_trait_change(callback3, 'a')
325 a.on_trait_change(callback3, 'a')
326 a.a = 10000
326 a.a = 10000
327 self.assertEqual(self.cb,('a',1000,10000))
327 self.assertEqual(self.cb,('a',1000,10000))
328 a.on_trait_change(callback3, 'a', remove=True)
328 a.on_trait_change(callback3, 'a', remove=True)
329
329
330 self.assertEqual(len(a._trait_notifiers['a']),0)
330 self.assertEqual(len(a._trait_notifiers['a']),0)
331
331
332 def test_notify_only_once(self):
332 def test_notify_only_once(self):
333
333
334 class A(HasTraits):
334 class A(HasTraits):
335 listen_to = ['a']
335 listen_to = ['a']
336
336
337 a = Int(0)
337 a = Int(0)
338 b = 0
338 b = 0
339
339
340 def __init__(self, **kwargs):
340 def __init__(self, **kwargs):
341 super(A, self).__init__(**kwargs)
341 super(A, self).__init__(**kwargs)
342 self.on_trait_change(self.listener1, ['a'])
342 self.on_trait_change(self.listener1, ['a'])
343
343
344 def listener1(self, name, old, new):
344 def listener1(self, name, old, new):
345 self.b += 1
345 self.b += 1
346
346
347 class B(A):
347 class B(A):
348
348
349 c = 0
349 c = 0
350 d = 0
350 d = 0
351
351
352 def __init__(self, **kwargs):
352 def __init__(self, **kwargs):
353 super(B, self).__init__(**kwargs)
353 super(B, self).__init__(**kwargs)
354 self.on_trait_change(self.listener2)
354 self.on_trait_change(self.listener2)
355
355
356 def listener2(self, name, old, new):
356 def listener2(self, name, old, new):
357 self.c += 1
357 self.c += 1
358
358
359 def _a_changed(self, name, old, new):
359 def _a_changed(self, name, old, new):
360 self.d += 1
360 self.d += 1
361
361
362 b = B()
362 b = B()
363 b.a += 1
363 b.a += 1
364 self.assertEqual(b.b, b.c)
364 self.assertEqual(b.b, b.c)
365 self.assertEqual(b.b, b.d)
365 self.assertEqual(b.b, b.d)
366 b.a += 1
366 b.a += 1
367 self.assertEqual(b.b, b.c)
367 self.assertEqual(b.b, b.c)
368 self.assertEqual(b.b, b.d)
368 self.assertEqual(b.b, b.d)
369
369
370
370
371 class TestHasTraits(TestCase):
371 class TestHasTraits(TestCase):
372
372
373 def test_trait_names(self):
373 def test_trait_names(self):
374 class A(HasTraits):
374 class A(HasTraits):
375 i = Int
375 i = Int
376 f = Float
376 f = Float
377 a = A()
377 a = A()
378 self.assertEqual(sorted(a.trait_names()),['f','i'])
378 self.assertEqual(sorted(a.trait_names()),['f','i'])
379 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
379 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
380
380
381 def test_trait_metadata(self):
381 def test_trait_metadata(self):
382 class A(HasTraits):
382 class A(HasTraits):
383 i = Int(config_key='MY_VALUE')
383 i = Int(config_key='MY_VALUE')
384 a = A()
384 a = A()
385 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
385 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
386
386
387 def test_trait_metadata_default(self):
387 def test_trait_metadata_default(self):
388 class A(HasTraits):
388 class A(HasTraits):
389 i = Int()
389 i = Int()
390 a = A()
390 a = A()
391 self.assertEqual(a.trait_metadata('i', 'config_key'), None)
391 self.assertEqual(a.trait_metadata('i', 'config_key'), None)
392 self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
392 self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
393
393
394 def test_traits(self):
394 def test_traits(self):
395 class A(HasTraits):
395 class A(HasTraits):
396 i = Int
396 i = Int
397 f = Float
397 f = Float
398 a = A()
398 a = A()
399 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
399 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
400 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
400 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
401
401
402 def test_traits_metadata(self):
402 def test_traits_metadata(self):
403 class A(HasTraits):
403 class A(HasTraits):
404 i = Int(config_key='VALUE1', other_thing='VALUE2')
404 i = Int(config_key='VALUE1', other_thing='VALUE2')
405 f = Float(config_key='VALUE3', other_thing='VALUE2')
405 f = Float(config_key='VALUE3', other_thing='VALUE2')
406 j = Int(0)
406 j = Int(0)
407 a = A()
407 a = A()
408 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
408 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
409 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
409 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
410 self.assertEqual(traits, dict(i=A.i))
410 self.assertEqual(traits, dict(i=A.i))
411
411
412 # This passes, but it shouldn't because I am replicating a bug in
412 # This passes, but it shouldn't because I am replicating a bug in
413 # traits.
413 # traits.
414 traits = a.traits(config_key=lambda v: True)
414 traits = a.traits(config_key=lambda v: True)
415 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
415 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
416
416
417 def test_init(self):
417 def test_init(self):
418 class A(HasTraits):
418 class A(HasTraits):
419 i = Int()
419 i = Int()
420 x = Float()
420 x = Float()
421 a = A(i=1, x=10.0)
421 a = A(i=1, x=10.0)
422 self.assertEqual(a.i, 1)
422 self.assertEqual(a.i, 1)
423 self.assertEqual(a.x, 10.0)
423 self.assertEqual(a.x, 10.0)
424
424
425 def test_positional_args(self):
425 def test_positional_args(self):
426 class A(HasTraits):
426 class A(HasTraits):
427 i = Int(0)
427 i = Int(0)
428 def __init__(self, i):
428 def __init__(self, i):
429 super(A, self).__init__()
429 super(A, self).__init__()
430 self.i = i
430 self.i = i
431
431
432 a = A(5)
432 a = A(5)
433 self.assertEqual(a.i, 5)
433 self.assertEqual(a.i, 5)
434 # should raise TypeError if no positional arg given
434 # should raise TypeError if no positional arg given
435 self.assertRaises(TypeError, A)
435 self.assertRaises(TypeError, A)
436
436
437 #-----------------------------------------------------------------------------
437 #-----------------------------------------------------------------------------
438 # Tests for specific trait types
438 # Tests for specific trait types
439 #-----------------------------------------------------------------------------
439 #-----------------------------------------------------------------------------
440
440
441
441
442 class TestType(TestCase):
442 class TestType(TestCase):
443
443
444 def test_default(self):
444 def test_default(self):
445
445
446 class B(object): pass
446 class B(object): pass
447 class A(HasTraits):
447 class A(HasTraits):
448 klass = Type
448 klass = Type
449
449
450 a = A()
450 a = A()
451 self.assertEqual(a.klass, None)
451 self.assertEqual(a.klass, None)
452
452
453 a.klass = B
453 a.klass = B
454 self.assertEqual(a.klass, B)
454 self.assertEqual(a.klass, B)
455 self.assertRaises(TraitError, setattr, a, 'klass', 10)
455 self.assertRaises(TraitError, setattr, a, 'klass', 10)
456
456
457 def test_value(self):
457 def test_value(self):
458
458
459 class B(object): pass
459 class B(object): pass
460 class C(object): pass
460 class C(object): pass
461 class A(HasTraits):
461 class A(HasTraits):
462 klass = Type(B)
462 klass = Type(B)
463
463
464 a = A()
464 a = A()
465 self.assertEqual(a.klass, B)
465 self.assertEqual(a.klass, B)
466 self.assertRaises(TraitError, setattr, a, 'klass', C)
466 self.assertRaises(TraitError, setattr, a, 'klass', C)
467 self.assertRaises(TraitError, setattr, a, 'klass', object)
467 self.assertRaises(TraitError, setattr, a, 'klass', object)
468 a.klass = B
468 a.klass = B
469
469
470 def test_allow_none(self):
470 def test_allow_none(self):
471
471
472 class B(object): pass
472 class B(object): pass
473 class C(B): pass
473 class C(B): pass
474 class A(HasTraits):
474 class A(HasTraits):
475 klass = Type(B, allow_none=False)
475 klass = Type(B, allow_none=False)
476
476
477 a = A()
477 a = A()
478 self.assertEqual(a.klass, B)
478 self.assertEqual(a.klass, B)
479 self.assertRaises(TraitError, setattr, a, 'klass', None)
479 self.assertRaises(TraitError, setattr, a, 'klass', None)
480 a.klass = C
480 a.klass = C
481 self.assertEqual(a.klass, C)
481 self.assertEqual(a.klass, C)
482
482
483 def test_validate_klass(self):
483 def test_validate_klass(self):
484
484
485 class A(HasTraits):
485 class A(HasTraits):
486 klass = Type('no strings allowed')
486 klass = Type('no strings allowed')
487
487
488 self.assertRaises(ImportError, A)
488 self.assertRaises(ImportError, A)
489
489
490 class A(HasTraits):
490 class A(HasTraits):
491 klass = Type('rub.adub.Duck')
491 klass = Type('rub.adub.Duck')
492
492
493 self.assertRaises(ImportError, A)
493 self.assertRaises(ImportError, A)
494
494
495 def test_validate_default(self):
495 def test_validate_default(self):
496
496
497 class B(object): pass
497 class B(object): pass
498 class A(HasTraits):
498 class A(HasTraits):
499 klass = Type('bad default', B)
499 klass = Type('bad default', B)
500
500
501 self.assertRaises(ImportError, A)
501 self.assertRaises(ImportError, A)
502
502
503 class C(HasTraits):
503 class C(HasTraits):
504 klass = Type(None, B, allow_none=False)
504 klass = Type(None, B, allow_none=False)
505
505
506 self.assertRaises(TraitError, C)
506 self.assertRaises(TraitError, C)
507
507
508 def test_str_klass(self):
508 def test_str_klass(self):
509
509
510 class A(HasTraits):
510 class A(HasTraits):
511 klass = Type('IPython.utils.ipstruct.Struct')
511 klass = Type('IPython.utils.ipstruct.Struct')
512
512
513 from IPython.utils.ipstruct import Struct
513 from IPython.utils.ipstruct import Struct
514 a = A()
514 a = A()
515 a.klass = Struct
515 a.klass = Struct
516 self.assertEqual(a.klass, Struct)
516 self.assertEqual(a.klass, Struct)
517
517
518 self.assertRaises(TraitError, setattr, a, 'klass', 10)
518 self.assertRaises(TraitError, setattr, a, 'klass', 10)
519
519
520 def test_set_str_klass(self):
520 def test_set_str_klass(self):
521
521
522 class A(HasTraits):
522 class A(HasTraits):
523 klass = Type()
523 klass = Type()
524
524
525 a = A(klass='IPython.utils.ipstruct.Struct')
525 a = A(klass='IPython.utils.ipstruct.Struct')
526 from IPython.utils.ipstruct import Struct
526 from IPython.utils.ipstruct import Struct
527 self.assertEqual(a.klass, Struct)
527 self.assertEqual(a.klass, Struct)
528
528
529 class TestInstance(TestCase):
529 class TestInstance(TestCase):
530
530
531 def test_basic(self):
531 def test_basic(self):
532 class Foo(object): pass
532 class Foo(object): pass
533 class Bar(Foo): pass
533 class Bar(Foo): pass
534 class Bah(object): pass
534 class Bah(object): pass
535
535
536 class A(HasTraits):
536 class A(HasTraits):
537 inst = Instance(Foo)
537 inst = Instance(Foo)
538
538
539 a = A()
539 a = A()
540 self.assertTrue(a.inst is None)
540 self.assertTrue(a.inst is None)
541 a.inst = Foo()
541 a.inst = Foo()
542 self.assertTrue(isinstance(a.inst, Foo))
542 self.assertTrue(isinstance(a.inst, Foo))
543 a.inst = Bar()
543 a.inst = Bar()
544 self.assertTrue(isinstance(a.inst, Foo))
544 self.assertTrue(isinstance(a.inst, Foo))
545 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
545 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
546 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
546 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
547 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
547 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
548
548
549 def test_default_klass(self):
549 def test_default_klass(self):
550 class Foo(object): pass
550 class Foo(object): pass
551 class Bar(Foo): pass
551 class Bar(Foo): pass
552 class Bah(object): pass
552 class Bah(object): pass
553
553
554 class FooInstance(Instance):
554 class FooInstance(Instance):
555 klass = Foo
555 klass = Foo
556
556
557 class A(HasTraits):
557 class A(HasTraits):
558 inst = FooInstance()
558 inst = FooInstance()
559
559
560 a = A()
560 a = A()
561 self.assertTrue(a.inst is None)
561 self.assertTrue(a.inst is None)
562 a.inst = Foo()
562 a.inst = Foo()
563 self.assertTrue(isinstance(a.inst, Foo))
563 self.assertTrue(isinstance(a.inst, Foo))
564 a.inst = Bar()
564 a.inst = Bar()
565 self.assertTrue(isinstance(a.inst, Foo))
565 self.assertTrue(isinstance(a.inst, Foo))
566 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
566 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
567 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
567 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
568 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
568 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
569
569
570 def test_unique_default_value(self):
570 def test_unique_default_value(self):
571 class Foo(object): pass
571 class Foo(object): pass
572 class A(HasTraits):
572 class A(HasTraits):
573 inst = Instance(Foo,(),{})
573 inst = Instance(Foo,(),{})
574
574
575 a = A()
575 a = A()
576 b = A()
576 b = A()
577 self.assertTrue(a.inst is not b.inst)
577 self.assertTrue(a.inst is not b.inst)
578
578
579 def test_args_kw(self):
579 def test_args_kw(self):
580 class Foo(object):
580 class Foo(object):
581 def __init__(self, c): self.c = c
581 def __init__(self, c): self.c = c
582 class Bar(object): pass
582 class Bar(object): pass
583 class Bah(object):
583 class Bah(object):
584 def __init__(self, c, d):
584 def __init__(self, c, d):
585 self.c = c; self.d = d
585 self.c = c; self.d = d
586
586
587 class A(HasTraits):
587 class A(HasTraits):
588 inst = Instance(Foo, (10,))
588 inst = Instance(Foo, (10,))
589 a = A()
589 a = A()
590 self.assertEqual(a.inst.c, 10)
590 self.assertEqual(a.inst.c, 10)
591
591
592 class B(HasTraits):
592 class B(HasTraits):
593 inst = Instance(Bah, args=(10,), kw=dict(d=20))
593 inst = Instance(Bah, args=(10,), kw=dict(d=20))
594 b = B()
594 b = B()
595 self.assertEqual(b.inst.c, 10)
595 self.assertEqual(b.inst.c, 10)
596 self.assertEqual(b.inst.d, 20)
596 self.assertEqual(b.inst.d, 20)
597
597
598 class C(HasTraits):
598 class C(HasTraits):
599 inst = Instance(Foo)
599 inst = Instance(Foo)
600 c = C()
600 c = C()
601 self.assertTrue(c.inst is None)
601 self.assertTrue(c.inst is None)
602
602
603 def test_bad_default(self):
603 def test_bad_default(self):
604 class Foo(object): pass
604 class Foo(object): pass
605
605
606 class A(HasTraits):
606 class A(HasTraits):
607 inst = Instance(Foo, allow_none=False)
607 inst = Instance(Foo, allow_none=False)
608
608
609 self.assertRaises(TraitError, A)
609 self.assertRaises(TraitError, A)
610
610
611 def test_instance(self):
611 def test_instance(self):
612 class Foo(object): pass
612 class Foo(object): pass
613
613
614 def inner():
614 def inner():
615 class A(HasTraits):
615 class A(HasTraits):
616 inst = Instance(Foo())
616 inst = Instance(Foo())
617
617
618 self.assertRaises(TraitError, inner)
618 self.assertRaises(TraitError, inner)
619
619
620
620
621 class TestThis(TestCase):
621 class TestThis(TestCase):
622
622
623 def test_this_class(self):
623 def test_this_class(self):
624 class Foo(HasTraits):
624 class Foo(HasTraits):
625 this = This
625 this = This
626
626
627 f = Foo()
627 f = Foo()
628 self.assertEqual(f.this, None)
628 self.assertEqual(f.this, None)
629 g = Foo()
629 g = Foo()
630 f.this = g
630 f.this = g
631 self.assertEqual(f.this, g)
631 self.assertEqual(f.this, g)
632 self.assertRaises(TraitError, setattr, f, 'this', 10)
632 self.assertRaises(TraitError, setattr, f, 'this', 10)
633
633
634 def test_this_inst(self):
634 def test_this_inst(self):
635 class Foo(HasTraits):
635 class Foo(HasTraits):
636 this = This()
636 this = This()
637
637
638 f = Foo()
638 f = Foo()
639 f.this = Foo()
639 f.this = Foo()
640 self.assertTrue(isinstance(f.this, Foo))
640 self.assertTrue(isinstance(f.this, Foo))
641
641
642 def test_subclass(self):
642 def test_subclass(self):
643 class Foo(HasTraits):
643 class Foo(HasTraits):
644 t = This()
644 t = This()
645 class Bar(Foo):
645 class Bar(Foo):
646 pass
646 pass
647 f = Foo()
647 f = Foo()
648 b = Bar()
648 b = Bar()
649 f.t = b
649 f.t = b
650 b.t = f
650 b.t = f
651 self.assertEqual(f.t, b)
651 self.assertEqual(f.t, b)
652 self.assertEqual(b.t, f)
652 self.assertEqual(b.t, f)
653
653
654 def test_subclass_override(self):
654 def test_subclass_override(self):
655 class Foo(HasTraits):
655 class Foo(HasTraits):
656 t = This()
656 t = This()
657 class Bar(Foo):
657 class Bar(Foo):
658 t = This()
658 t = This()
659 f = Foo()
659 f = Foo()
660 b = Bar()
660 b = Bar()
661 f.t = b
661 f.t = b
662 self.assertEqual(f.t, b)
662 self.assertEqual(f.t, b)
663 self.assertRaises(TraitError, setattr, b, 't', f)
663 self.assertRaises(TraitError, setattr, b, 't', f)
664
664
665 def test_this_in_container(self):
665 def test_this_in_container(self):
666
666
667 class Tree(HasTraits):
667 class Tree(HasTraits):
668 value = Unicode()
668 value = Unicode()
669 leaves = List(This())
669 leaves = List(This())
670
670
671 tree = Tree(
671 tree = Tree(
672 value='foo',
672 value='foo',
673 leaves=[Tree('bar'), Tree('buzz')]
673 leaves=[Tree('bar'), Tree('buzz')]
674 )
674 )
675
675
676 with self.assertRaises(TraitError):
676 with self.assertRaises(TraitError):
677 tree.leaves = [1, 2]
677 tree.leaves = [1, 2]
678
678
679 class TraitTestBase(TestCase):
679 class TraitTestBase(TestCase):
680 """A best testing class for basic trait types."""
680 """A best testing class for basic trait types."""
681
681
682 def assign(self, value):
682 def assign(self, value):
683 self.obj.value = value
683 self.obj.value = value
684
684
685 def coerce(self, value):
685 def coerce(self, value):
686 return value
686 return value
687
687
688 def test_good_values(self):
688 def test_good_values(self):
689 if hasattr(self, '_good_values'):
689 if hasattr(self, '_good_values'):
690 for value in self._good_values:
690 for value in self._good_values:
691 self.assign(value)
691 self.assign(value)
692 self.assertEqual(self.obj.value, self.coerce(value))
692 self.assertEqual(self.obj.value, self.coerce(value))
693
693
694 def test_bad_values(self):
694 def test_bad_values(self):
695 if hasattr(self, '_bad_values'):
695 if hasattr(self, '_bad_values'):
696 for value in self._bad_values:
696 for value in self._bad_values:
697 try:
697 try:
698 self.assertRaises(TraitError, self.assign, value)
698 self.assertRaises(TraitError, self.assign, value)
699 except AssertionError:
699 except AssertionError:
700 assert False, value
700 assert False, value
701
701
702 def test_default_value(self):
702 def test_default_value(self):
703 if hasattr(self, '_default_value'):
703 if hasattr(self, '_default_value'):
704 self.assertEqual(self._default_value, self.obj.value)
704 self.assertEqual(self._default_value, self.obj.value)
705
705
706 def test_allow_none(self):
706 def test_allow_none(self):
707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
708 None in self._bad_values):
708 None in self._bad_values):
709 trait=self.obj.traits()['value']
709 trait=self.obj.traits()['value']
710 try:
710 try:
711 trait.allow_none = True
711 trait.allow_none = True
712 self._bad_values.remove(None)
712 self._bad_values.remove(None)
713 #skip coerce. Allow None casts None to None.
713 #skip coerce. Allow None casts None to None.
714 self.assign(None)
714 self.assign(None)
715 self.assertEqual(self.obj.value,None)
715 self.assertEqual(self.obj.value,None)
716 self.test_good_values()
716 self.test_good_values()
717 self.test_bad_values()
717 self.test_bad_values()
718 finally:
718 finally:
719 #tear down
719 #tear down
720 trait.allow_none = False
720 trait.allow_none = False
721 self._bad_values.append(None)
721 self._bad_values.append(None)
722
722
723 def tearDown(self):
723 def tearDown(self):
724 # restore default value after tests, if set
724 # restore default value after tests, if set
725 if hasattr(self, '_default_value'):
725 if hasattr(self, '_default_value'):
726 self.obj.value = self._default_value
726 self.obj.value = self._default_value
727
727
728
728
729 class AnyTrait(HasTraits):
729 class AnyTrait(HasTraits):
730
730
731 value = Any
731 value = Any
732
732
733 class AnyTraitTest(TraitTestBase):
733 class AnyTraitTest(TraitTestBase):
734
734
735 obj = AnyTrait()
735 obj = AnyTrait()
736
736
737 _default_value = None
737 _default_value = None
738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
739 _bad_values = []
739 _bad_values = []
740
740
741 class UnionTrait(HasTraits):
741 class UnionTrait(HasTraits):
742
742
743 value = Union([Type(), Bool()])
743 value = Union([Type(), Bool()])
744
744
745 class UnionTraitTest(TraitTestBase):
745 class UnionTraitTest(TraitTestBase):
746
746
747 obj = UnionTrait(value='IPython.utils.ipstruct.Struct')
747 obj = UnionTrait(value='IPython.utils.ipstruct.Struct')
748 _good_values = [int, float, True]
748 _good_values = [int, float, True]
749 _bad_values = [[], (0,), 1j]
749 _bad_values = [[], (0,), 1j]
750
750
751 class OrTrait(HasTraits):
751 class OrTrait(HasTraits):
752
752
753 value = Bool() | Unicode()
753 value = Bool() | Unicode()
754
754
755 class OrTraitTest(TraitTestBase):
755 class OrTraitTest(TraitTestBase):
756
756
757 obj = OrTrait()
757 obj = OrTrait()
758 _good_values = [True, False, 'ten']
758 _good_values = [True, False, 'ten']
759 _bad_values = [[], (0,), 1j]
759 _bad_values = [[], (0,), 1j]
760
760
761 class IntTrait(HasTraits):
761 class IntTrait(HasTraits):
762
762
763 value = Int(99)
763 value = Int(99)
764
764
765 class TestInt(TraitTestBase):
765 class TestInt(TraitTestBase):
766
766
767 obj = IntTrait()
767 obj = IntTrait()
768 _default_value = 99
768 _default_value = 99
769 _good_values = [10, -10]
769 _good_values = [10, -10]
770 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
770 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
771 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
771 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
772 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
772 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
773 if not py3compat.PY3:
773 if not py3compat.PY3:
774 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
774 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
775
775
776
776
777 class LongTrait(HasTraits):
777 class LongTrait(HasTraits):
778
778
779 value = Long(99 if py3compat.PY3 else long(99))
779 value = Long(99 if py3compat.PY3 else long(99))
780
780
781 class TestLong(TraitTestBase):
781 class TestLong(TraitTestBase):
782
782
783 obj = LongTrait()
783 obj = LongTrait()
784
784
785 _default_value = 99 if py3compat.PY3 else long(99)
785 _default_value = 99 if py3compat.PY3 else long(99)
786 _good_values = [10, -10]
786 _good_values = [10, -10]
787 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
787 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
788 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
788 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
789 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
789 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
790 u'-10.1']
790 u'-10.1']
791 if not py3compat.PY3:
791 if not py3compat.PY3:
792 # maxint undefined on py3, because int == long
792 # maxint undefined on py3, because int == long
793 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
793 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
794 _bad_values.extend([[long(10)], (long(10),)])
794 _bad_values.extend([[long(10)], (long(10),)])
795
795
796 @skipif(py3compat.PY3, "not relevant on py3")
796 @skipif(py3compat.PY3, "not relevant on py3")
797 def test_cast_small(self):
797 def test_cast_small(self):
798 """Long casts ints to long"""
798 """Long casts ints to long"""
799 self.obj.value = 10
799 self.obj.value = 10
800 self.assertEqual(type(self.obj.value), long)
800 self.assertEqual(type(self.obj.value), long)
801
801
802
802
803 class IntegerTrait(HasTraits):
803 class IntegerTrait(HasTraits):
804 value = Integer(1)
804 value = Integer(1)
805
805
806 class TestInteger(TestLong):
806 class TestInteger(TestLong):
807 obj = IntegerTrait()
807 obj = IntegerTrait()
808 _default_value = 1
808 _default_value = 1
809
809
810 def coerce(self, n):
810 def coerce(self, n):
811 return int(n)
811 return int(n)
812
812
813 @skipif(py3compat.PY3, "not relevant on py3")
813 @skipif(py3compat.PY3, "not relevant on py3")
814 def test_cast_small(self):
814 def test_cast_small(self):
815 """Integer casts small longs to int"""
815 """Integer casts small longs to int"""
816 if py3compat.PY3:
816 if py3compat.PY3:
817 raise SkipTest("not relevant on py3")
817 raise SkipTest("not relevant on py3")
818
818
819 self.obj.value = long(100)
819 self.obj.value = long(100)
820 self.assertEqual(type(self.obj.value), int)
820 self.assertEqual(type(self.obj.value), int)
821
821
822
822
823 class FloatTrait(HasTraits):
823 class FloatTrait(HasTraits):
824
824
825 value = Float(99.0)
825 value = Float(99.0)
826
826
827 class TestFloat(TraitTestBase):
827 class TestFloat(TraitTestBase):
828
828
829 obj = FloatTrait()
829 obj = FloatTrait()
830
830
831 _default_value = 99.0
831 _default_value = 99.0
832 _good_values = [10, -10, 10.1, -10.1]
832 _good_values = [10, -10, 10.1, -10.1]
833 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
833 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
834 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
834 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
835 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
835 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
836 if not py3compat.PY3:
836 if not py3compat.PY3:
837 _bad_values.extend([long(10), long(-10)])
837 _bad_values.extend([long(10), long(-10)])
838
838
839
839
840 class ComplexTrait(HasTraits):
840 class ComplexTrait(HasTraits):
841
841
842 value = Complex(99.0-99.0j)
842 value = Complex(99.0-99.0j)
843
843
844 class TestComplex(TraitTestBase):
844 class TestComplex(TraitTestBase):
845
845
846 obj = ComplexTrait()
846 obj = ComplexTrait()
847
847
848 _default_value = 99.0-99.0j
848 _default_value = 99.0-99.0j
849 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
849 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
850 10.1j, 10.1+10.1j, 10.1-10.1j]
850 10.1j, 10.1+10.1j, 10.1-10.1j]
851 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
851 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
852 if not py3compat.PY3:
852 if not py3compat.PY3:
853 _bad_values.extend([long(10), long(-10)])
853 _bad_values.extend([long(10), long(-10)])
854
854
855
855
856 class BytesTrait(HasTraits):
856 class BytesTrait(HasTraits):
857
857
858 value = Bytes(b'string')
858 value = Bytes(b'string')
859
859
860 class TestBytes(TraitTestBase):
860 class TestBytes(TraitTestBase):
861
861
862 obj = BytesTrait()
862 obj = BytesTrait()
863
863
864 _default_value = b'string'
864 _default_value = b'string'
865 _good_values = [b'10', b'-10', b'10L',
865 _good_values = [b'10', b'-10', b'10L',
866 b'-10L', b'10.1', b'-10.1', b'string']
866 b'-10L', b'10.1', b'-10.1', b'string']
867 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
867 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
868 ['ten'],{'ten': 10},(10,), None, u'string']
868 ['ten'],{'ten': 10},(10,), None, u'string']
869 if not py3compat.PY3:
869 if not py3compat.PY3:
870 _bad_values.extend([long(10), long(-10)])
870 _bad_values.extend([long(10), long(-10)])
871
871
872
872
873 class UnicodeTrait(HasTraits):
873 class UnicodeTrait(HasTraits):
874
874
875 value = Unicode(u'unicode')
875 value = Unicode(u'unicode')
876
876
877 class TestUnicode(TraitTestBase):
877 class TestUnicode(TraitTestBase):
878
878
879 obj = UnicodeTrait()
879 obj = UnicodeTrait()
880
880
881 _default_value = u'unicode'
881 _default_value = u'unicode'
882 _good_values = ['10', '-10', '10L', '-10L', '10.1',
882 _good_values = ['10', '-10', '10L', '-10L', '10.1',
883 '-10.1', '', u'', 'string', u'string', u"€"]
883 '-10.1', '', u'', 'string', u'string', u"€"]
884 _bad_values = [10, -10, 10.1, -10.1, 1j,
884 _bad_values = [10, -10, 10.1, -10.1, 1j,
885 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
885 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
886 if not py3compat.PY3:
886 if not py3compat.PY3:
887 _bad_values.extend([long(10), long(-10)])
887 _bad_values.extend([long(10), long(-10)])
888
888
889
889
890 class ObjectNameTrait(HasTraits):
890 class ObjectNameTrait(HasTraits):
891 value = ObjectName("abc")
891 value = ObjectName("abc")
892
892
893 class TestObjectName(TraitTestBase):
893 class TestObjectName(TraitTestBase):
894 obj = ObjectNameTrait()
894 obj = ObjectNameTrait()
895
895
896 _default_value = "abc"
896 _default_value = "abc"
897 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
897 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
898 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
898 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
899 None, object(), object]
899 None, object(), object]
900 if sys.version_info[0] < 3:
900 if sys.version_info[0] < 3:
901 _bad_values.append(u"ΓΎ")
901 _bad_values.append(u"ΓΎ")
902 else:
902 else:
903 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
903 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
904
904
905
905
906 class DottedObjectNameTrait(HasTraits):
906 class DottedObjectNameTrait(HasTraits):
907 value = DottedObjectName("a.b")
907 value = DottedObjectName("a.b")
908
908
909 class TestDottedObjectName(TraitTestBase):
909 class TestDottedObjectName(TraitTestBase):
910 obj = DottedObjectNameTrait()
910 obj = DottedObjectNameTrait()
911
911
912 _default_value = "a.b"
912 _default_value = "a.b"
913 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
913 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
914 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
914 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
915 if sys.version_info[0] < 3:
915 if sys.version_info[0] < 3:
916 _bad_values.append(u"t.ΓΎ")
916 _bad_values.append(u"t.ΓΎ")
917 else:
917 else:
918 _good_values.append(u"t.ΓΎ")
918 _good_values.append(u"t.ΓΎ")
919
919
920
920
921 class TCPAddressTrait(HasTraits):
921 class TCPAddressTrait(HasTraits):
922
922
923 value = TCPAddress()
923 value = TCPAddress()
924
924
925 class TestTCPAddress(TraitTestBase):
925 class TestTCPAddress(TraitTestBase):
926
926
927 obj = TCPAddressTrait()
927 obj = TCPAddressTrait()
928
928
929 _default_value = ('127.0.0.1',0)
929 _default_value = ('127.0.0.1',0)
930 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
930 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
931 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
931 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
932
932
933 class ListTrait(HasTraits):
933 class ListTrait(HasTraits):
934
934
935 value = List(Int)
935 value = List(Int)
936
936
937 class TestList(TraitTestBase):
937 class TestList(TraitTestBase):
938
938
939 obj = ListTrait()
939 obj = ListTrait()
940
940
941 _default_value = []
941 _default_value = []
942 _good_values = [[], [1], list(range(10)), (1,2)]
942 _good_values = [[], [1], list(range(10)), (1,2)]
943 _bad_values = [10, [1,'a'], 'a']
943 _bad_values = [10, [1,'a'], 'a']
944
944
945 def coerce(self, value):
945 def coerce(self, value):
946 if value is not None:
946 if value is not None:
947 value = list(value)
947 value = list(value)
948 return value
948 return value
949
949
950 class Foo(object):
950 class Foo(object):
951 pass
951 pass
952
952
953 class NoneInstanceListTrait(HasTraits):
953 class NoneInstanceListTrait(HasTraits):
954
954
955 value = List(Instance(Foo, allow_none=False))
955 value = List(Instance(Foo, allow_none=False))
956
956
957 class TestNoneInstanceList(TraitTestBase):
957 class TestNoneInstanceList(TraitTestBase):
958
958
959 obj = NoneInstanceListTrait()
959 obj = NoneInstanceListTrait()
960
960
961 _default_value = []
961 _default_value = []
962 _good_values = [[Foo(), Foo()], []]
962 _good_values = [[Foo(), Foo()], []]
963 _bad_values = [[None], [Foo(), None]]
963 _bad_values = [[None], [Foo(), None]]
964
964
965
965
966 class InstanceListTrait(HasTraits):
966 class InstanceListTrait(HasTraits):
967
967
968 value = List(Instance(__name__+'.Foo'))
968 value = List(Instance(__name__+'.Foo'))
969
969
970 class TestInstanceList(TraitTestBase):
970 class TestInstanceList(TraitTestBase):
971
971
972 obj = InstanceListTrait()
972 obj = InstanceListTrait()
973
973
974 def test_klass(self):
974 def test_klass(self):
975 """Test that the instance klass is properly assigned."""
975 """Test that the instance klass is properly assigned."""
976 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
976 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
977
977
978 _default_value = []
978 _default_value = []
979 _good_values = [[Foo(), Foo(), None], None]
979 _good_values = [[Foo(), Foo(), None], []]
980 _bad_values = [['1', 2,], '1', [Foo]]
980 _bad_values = [['1', 2,], '1', [Foo], None]
981
981
982 class LenListTrait(HasTraits):
982 class LenListTrait(HasTraits):
983
983
984 value = List(Int, [0], minlen=1, maxlen=2)
984 value = List(Int, [0], minlen=1, maxlen=2)
985
985
986 class TestLenList(TraitTestBase):
986 class TestLenList(TraitTestBase):
987
987
988 obj = LenListTrait()
988 obj = LenListTrait()
989
989
990 _default_value = [0]
990 _default_value = [0]
991 _good_values = [[1], [1,2], (1,2)]
991 _good_values = [[1], [1,2], (1,2)]
992 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
992 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
993
993
994 def coerce(self, value):
994 def coerce(self, value):
995 if value is not None:
995 if value is not None:
996 value = list(value)
996 value = list(value)
997 return value
997 return value
998
998
999 class TupleTrait(HasTraits):
999 class TupleTrait(HasTraits):
1000
1000
1001 value = Tuple(Int(allow_none=True))
1001 value = Tuple(Int(allow_none=True))
1002
1002
1003 class TestTupleTrait(TraitTestBase):
1003 class TestTupleTrait(TraitTestBase):
1004
1004
1005 obj = TupleTrait()
1005 obj = TupleTrait()
1006
1006
1007 _default_value = None
1007 _default_value = None
1008 _good_values = [(1,), None, (0,), [1], (None,)]
1008 _good_values = [(1,), None, (0,), [1], (None,)]
1009 _bad_values = [10, (1,2), ('a'), ()]
1009 _bad_values = [10, (1,2), ('a'), ()]
1010
1010
1011 def coerce(self, value):
1011 def coerce(self, value):
1012 if value is not None:
1012 if value is not None:
1013 value = tuple(value)
1013 value = tuple(value)
1014 return value
1014 return value
1015
1015
1016 def test_invalid_args(self):
1016 def test_invalid_args(self):
1017 self.assertRaises(TypeError, Tuple, 5)
1017 self.assertRaises(TypeError, Tuple, 5)
1018 self.assertRaises(TypeError, Tuple, default_value='hello')
1018 self.assertRaises(TypeError, Tuple, default_value='hello')
1019 t = Tuple(Int, CBytes, default_value=(1,5))
1019 t = Tuple(Int, CBytes, default_value=(1,5))
1020
1020
1021 class LooseTupleTrait(HasTraits):
1021 class LooseTupleTrait(HasTraits):
1022
1022
1023 value = Tuple((1,2,3))
1023 value = Tuple((1,2,3))
1024
1024
1025 class TestLooseTupleTrait(TraitTestBase):
1025 class TestLooseTupleTrait(TraitTestBase):
1026
1026
1027 obj = LooseTupleTrait()
1027 obj = LooseTupleTrait()
1028
1028
1029 _default_value = (1,2,3)
1029 _default_value = (1,2,3)
1030 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
1030 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
1031 _bad_values = [10, 'hello', {}]
1031 _bad_values = [10, 'hello', {}]
1032
1032
1033 def coerce(self, value):
1033 def coerce(self, value):
1034 if value is not None:
1034 if value is not None:
1035 value = tuple(value)
1035 value = tuple(value)
1036 return value
1036 return value
1037
1037
1038 def test_invalid_args(self):
1038 def test_invalid_args(self):
1039 self.assertRaises(TypeError, Tuple, 5)
1039 self.assertRaises(TypeError, Tuple, 5)
1040 self.assertRaises(TypeError, Tuple, default_value='hello')
1040 self.assertRaises(TypeError, Tuple, default_value='hello')
1041 t = Tuple(Int, CBytes, default_value=(1,5))
1041 t = Tuple(Int, CBytes, default_value=(1,5))
1042
1042
1043
1043
1044 class MultiTupleTrait(HasTraits):
1044 class MultiTupleTrait(HasTraits):
1045
1045
1046 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1046 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1047
1047
1048 class TestMultiTuple(TraitTestBase):
1048 class TestMultiTuple(TraitTestBase):
1049
1049
1050 obj = MultiTupleTrait()
1050 obj = MultiTupleTrait()
1051
1051
1052 _default_value = (99,b'bottles')
1052 _default_value = (99,b'bottles')
1053 _good_values = [(1,b'a'), (2,b'b')]
1053 _good_values = [(1,b'a'), (2,b'b')]
1054 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1054 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1055
1055
1056 class CRegExpTrait(HasTraits):
1056 class CRegExpTrait(HasTraits):
1057
1057
1058 value = CRegExp(r'')
1058 value = CRegExp(r'')
1059
1059
1060 class TestCRegExp(TraitTestBase):
1060 class TestCRegExp(TraitTestBase):
1061
1061
1062 def coerce(self, value):
1062 def coerce(self, value):
1063 return re.compile(value)
1063 return re.compile(value)
1064
1064
1065 obj = CRegExpTrait()
1065 obj = CRegExpTrait()
1066
1066
1067 _default_value = re.compile(r'')
1067 _default_value = re.compile(r'')
1068 _good_values = [r'\d+', re.compile(r'\d+')]
1068 _good_values = [r'\d+', re.compile(r'\d+')]
1069 _bad_values = ['(', None, ()]
1069 _bad_values = ['(', None, ()]
1070
1070
1071 class DictTrait(HasTraits):
1071 class DictTrait(HasTraits):
1072 value = Dict()
1072 value = Dict()
1073
1073
1074 def test_dict_assignment():
1074 def test_dict_assignment():
1075 d = dict()
1075 d = dict()
1076 c = DictTrait()
1076 c = DictTrait()
1077 c.value = d
1077 c.value = d
1078 d['a'] = 5
1078 d['a'] = 5
1079 nt.assert_equal(d, c.value)
1079 nt.assert_equal(d, c.value)
1080 nt.assert_true(c.value is d)
1080 nt.assert_true(c.value is d)
1081
1081
1082 def test_dict_default_value():
1082 def test_dict_default_value():
1083 """Check that the `{}` default value of the Dict traitlet constructor is
1083 """Check that the `{}` default value of the Dict traitlet constructor is
1084 actually copied."""
1084 actually copied."""
1085
1085
1086 d1, d2 = Dict(), Dict()
1086 d1, d2 = Dict(), Dict()
1087 nt.assert_false(d1.get_default_value() is d2.get_default_value())
1087 nt.assert_false(d1.get_default_value() is d2.get_default_value())
1088
1088
1089
1089
1090 class TestValidationHook(TestCase):
1090 class TestValidationHook(TestCase):
1091
1091
1092 def test_parity_trait(self):
1092 def test_parity_trait(self):
1093 """Verify that the early validation hook is effective"""
1093 """Verify that the early validation hook is effective"""
1094
1094
1095 class Parity(HasTraits):
1095 class Parity(HasTraits):
1096
1096
1097 value = Int(0)
1097 value = Int(0)
1098 parity = Enum(['odd', 'even'], default_value='even', allow_none=False)
1098 parity = Enum(['odd', 'even'], default_value='even', allow_none=False)
1099
1099
1100 def _value_validate(self, value, trait):
1100 def _value_validate(self, value, trait):
1101 if self.parity == 'even' and value % 2:
1101 if self.parity == 'even' and value % 2:
1102 raise TraitError('Expected an even number')
1102 raise TraitError('Expected an even number')
1103 if self.parity == 'odd' and (value % 2 == 0):
1103 if self.parity == 'odd' and (value % 2 == 0):
1104 raise TraitError('Expected an odd number')
1104 raise TraitError('Expected an odd number')
1105 return value
1105 return value
1106
1106
1107 u = Parity()
1107 u = Parity()
1108 u.parity = 'odd'
1108 u.parity = 'odd'
1109 u.value = 1 # OK
1109 u.value = 1 # OK
1110 with self.assertRaises(TraitError):
1110 with self.assertRaises(TraitError):
1111 u.value = 2 # Trait Error
1111 u.value = 2 # Trait Error
1112
1112
1113 u.parity = 'even'
1113 u.parity = 'even'
1114 u.value = 2 # OK
1114 u.value = 2 # OK
1115
1115
1116
1116
1117 class TestLink(TestCase):
1117 class TestLink(TestCase):
1118
1118
1119 def test_connect_same(self):
1119 def test_connect_same(self):
1120 """Verify two traitlets of the same type can be linked together using link."""
1120 """Verify two traitlets of the same type can be linked together using link."""
1121
1121
1122 # Create two simple classes with Int traitlets.
1122 # Create two simple classes with Int traitlets.
1123 class A(HasTraits):
1123 class A(HasTraits):
1124 value = Int()
1124 value = Int()
1125 a = A(value=9)
1125 a = A(value=9)
1126 b = A(value=8)
1126 b = A(value=8)
1127
1127
1128 # Conenct the two classes.
1128 # Conenct the two classes.
1129 c = link((a, 'value'), (b, 'value'))
1129 c = link((a, 'value'), (b, 'value'))
1130
1130
1131 # Make sure the values are the same at the point of linking.
1131 # Make sure the values are the same at the point of linking.
1132 self.assertEqual(a.value, b.value)
1132 self.assertEqual(a.value, b.value)
1133
1133
1134 # Change one of the values to make sure they stay in sync.
1134 # Change one of the values to make sure they stay in sync.
1135 a.value = 5
1135 a.value = 5
1136 self.assertEqual(a.value, b.value)
1136 self.assertEqual(a.value, b.value)
1137 b.value = 6
1137 b.value = 6
1138 self.assertEqual(a.value, b.value)
1138 self.assertEqual(a.value, b.value)
1139
1139
1140 def test_link_different(self):
1140 def test_link_different(self):
1141 """Verify two traitlets of different types can be linked together using link."""
1141 """Verify two traitlets of different types can be linked together using link."""
1142
1142
1143 # Create two simple classes with Int traitlets.
1143 # Create two simple classes with Int traitlets.
1144 class A(HasTraits):
1144 class A(HasTraits):
1145 value = Int()
1145 value = Int()
1146 class B(HasTraits):
1146 class B(HasTraits):
1147 count = Int()
1147 count = Int()
1148 a = A(value=9)
1148 a = A(value=9)
1149 b = B(count=8)
1149 b = B(count=8)
1150
1150
1151 # Conenct the two classes.
1151 # Conenct the two classes.
1152 c = link((a, 'value'), (b, 'count'))
1152 c = link((a, 'value'), (b, 'count'))
1153
1153
1154 # Make sure the values are the same at the point of linking.
1154 # Make sure the values are the same at the point of linking.
1155 self.assertEqual(a.value, b.count)
1155 self.assertEqual(a.value, b.count)
1156
1156
1157 # Change one of the values to make sure they stay in sync.
1157 # Change one of the values to make sure they stay in sync.
1158 a.value = 5
1158 a.value = 5
1159 self.assertEqual(a.value, b.count)
1159 self.assertEqual(a.value, b.count)
1160 b.count = 4
1160 b.count = 4
1161 self.assertEqual(a.value, b.count)
1161 self.assertEqual(a.value, b.count)
1162
1162
1163 def test_unlink(self):
1163 def test_unlink(self):
1164 """Verify two linked traitlets can be unlinked."""
1164 """Verify two linked traitlets can be unlinked."""
1165
1165
1166 # Create two simple classes with Int traitlets.
1166 # Create two simple classes with Int traitlets.
1167 class A(HasTraits):
1167 class A(HasTraits):
1168 value = Int()
1168 value = Int()
1169 a = A(value=9)
1169 a = A(value=9)
1170 b = A(value=8)
1170 b = A(value=8)
1171
1171
1172 # Connect the two classes.
1172 # Connect the two classes.
1173 c = link((a, 'value'), (b, 'value'))
1173 c = link((a, 'value'), (b, 'value'))
1174 a.value = 4
1174 a.value = 4
1175 c.unlink()
1175 c.unlink()
1176
1176
1177 # Change one of the values to make sure they don't stay in sync.
1177 # Change one of the values to make sure they don't stay in sync.
1178 a.value = 5
1178 a.value = 5
1179 self.assertNotEqual(a.value, b.value)
1179 self.assertNotEqual(a.value, b.value)
1180
1180
1181 def test_callbacks(self):
1181 def test_callbacks(self):
1182 """Verify two linked traitlets have their callbacks called once."""
1182 """Verify two linked traitlets have their callbacks called once."""
1183
1183
1184 # Create two simple classes with Int traitlets.
1184 # Create two simple classes with Int traitlets.
1185 class A(HasTraits):
1185 class A(HasTraits):
1186 value = Int()
1186 value = Int()
1187 class B(HasTraits):
1187 class B(HasTraits):
1188 count = Int()
1188 count = Int()
1189 a = A(value=9)
1189 a = A(value=9)
1190 b = B(count=8)
1190 b = B(count=8)
1191
1191
1192 # Register callbacks that count.
1192 # Register callbacks that count.
1193 callback_count = []
1193 callback_count = []
1194 def a_callback(name, old, new):
1194 def a_callback(name, old, new):
1195 callback_count.append('a')
1195 callback_count.append('a')
1196 a.on_trait_change(a_callback, 'value')
1196 a.on_trait_change(a_callback, 'value')
1197 def b_callback(name, old, new):
1197 def b_callback(name, old, new):
1198 callback_count.append('b')
1198 callback_count.append('b')
1199 b.on_trait_change(b_callback, 'count')
1199 b.on_trait_change(b_callback, 'count')
1200
1200
1201 # Connect the two classes.
1201 # Connect the two classes.
1202 c = link((a, 'value'), (b, 'count'))
1202 c = link((a, 'value'), (b, 'count'))
1203
1203
1204 # Make sure b's count was set to a's value once.
1204 # Make sure b's count was set to a's value once.
1205 self.assertEqual(''.join(callback_count), 'b')
1205 self.assertEqual(''.join(callback_count), 'b')
1206 del callback_count[:]
1206 del callback_count[:]
1207
1207
1208 # Make sure a's value was set to b's count once.
1208 # Make sure a's value was set to b's count once.
1209 b.count = 5
1209 b.count = 5
1210 self.assertEqual(''.join(callback_count), 'ba')
1210 self.assertEqual(''.join(callback_count), 'ba')
1211 del callback_count[:]
1211 del callback_count[:]
1212
1212
1213 # Make sure b's count was set to a's value once.
1213 # Make sure b's count was set to a's value once.
1214 a.value = 4
1214 a.value = 4
1215 self.assertEqual(''.join(callback_count), 'ab')
1215 self.assertEqual(''.join(callback_count), 'ab')
1216 del callback_count[:]
1216 del callback_count[:]
1217
1217
1218 class TestDirectionalLink(TestCase):
1218 class TestDirectionalLink(TestCase):
1219 def test_connect_same(self):
1219 def test_connect_same(self):
1220 """Verify two traitlets of the same type can be linked together using directional_link."""
1220 """Verify two traitlets of the same type can be linked together using directional_link."""
1221
1221
1222 # Create two simple classes with Int traitlets.
1222 # Create two simple classes with Int traitlets.
1223 class A(HasTraits):
1223 class A(HasTraits):
1224 value = Int()
1224 value = Int()
1225 a = A(value=9)
1225 a = A(value=9)
1226 b = A(value=8)
1226 b = A(value=8)
1227
1227
1228 # Conenct the two classes.
1228 # Conenct the two classes.
1229 c = directional_link((a, 'value'), (b, 'value'))
1229 c = directional_link((a, 'value'), (b, 'value'))
1230
1230
1231 # Make sure the values are the same at the point of linking.
1231 # Make sure the values are the same at the point of linking.
1232 self.assertEqual(a.value, b.value)
1232 self.assertEqual(a.value, b.value)
1233
1233
1234 # Change one the value of the source and check that it synchronizes the target.
1234 # Change one the value of the source and check that it synchronizes the target.
1235 a.value = 5
1235 a.value = 5
1236 self.assertEqual(b.value, 5)
1236 self.assertEqual(b.value, 5)
1237 # Change one the value of the target and check that it has no impact on the source
1237 # Change one the value of the target and check that it has no impact on the source
1238 b.value = 6
1238 b.value = 6
1239 self.assertEqual(a.value, 5)
1239 self.assertEqual(a.value, 5)
1240
1240
1241 def test_link_different(self):
1241 def test_link_different(self):
1242 """Verify two traitlets of different types can be linked together using link."""
1242 """Verify two traitlets of different types can be linked together using link."""
1243
1243
1244 # Create two simple classes with Int traitlets.
1244 # Create two simple classes with Int traitlets.
1245 class A(HasTraits):
1245 class A(HasTraits):
1246 value = Int()
1246 value = Int()
1247 class B(HasTraits):
1247 class B(HasTraits):
1248 count = Int()
1248 count = Int()
1249 a = A(value=9)
1249 a = A(value=9)
1250 b = B(count=8)
1250 b = B(count=8)
1251
1251
1252 # Conenct the two classes.
1252 # Conenct the two classes.
1253 c = directional_link((a, 'value'), (b, 'count'))
1253 c = directional_link((a, 'value'), (b, 'count'))
1254
1254
1255 # Make sure the values are the same at the point of linking.
1255 # Make sure the values are the same at the point of linking.
1256 self.assertEqual(a.value, b.count)
1256 self.assertEqual(a.value, b.count)
1257
1257
1258 # Change one the value of the source and check that it synchronizes the target.
1258 # Change one the value of the source and check that it synchronizes the target.
1259 a.value = 5
1259 a.value = 5
1260 self.assertEqual(b.count, 5)
1260 self.assertEqual(b.count, 5)
1261 # Change one the value of the target and check that it has no impact on the source
1261 # Change one the value of the target and check that it has no impact on the source
1262 b.value = 6
1262 b.value = 6
1263 self.assertEqual(a.value, 5)
1263 self.assertEqual(a.value, 5)
1264
1264
1265 def test_unlink(self):
1265 def test_unlink(self):
1266 """Verify two linked traitlets can be unlinked."""
1266 """Verify two linked traitlets can be unlinked."""
1267
1267
1268 # Create two simple classes with Int traitlets.
1268 # Create two simple classes with Int traitlets.
1269 class A(HasTraits):
1269 class A(HasTraits):
1270 value = Int()
1270 value = Int()
1271 a = A(value=9)
1271 a = A(value=9)
1272 b = A(value=8)
1272 b = A(value=8)
1273
1273
1274 # Connect the two classes.
1274 # Connect the two classes.
1275 c = directional_link((a, 'value'), (b, 'value'))
1275 c = directional_link((a, 'value'), (b, 'value'))
1276 a.value = 4
1276 a.value = 4
1277 c.unlink()
1277 c.unlink()
1278
1278
1279 # Change one of the values to make sure they don't stay in sync.
1279 # Change one of the values to make sure they don't stay in sync.
1280 a.value = 5
1280 a.value = 5
1281 self.assertNotEqual(a.value, b.value)
1281 self.assertNotEqual(a.value, b.value)
1282
1282
1283 class Pickleable(HasTraits):
1283 class Pickleable(HasTraits):
1284 i = Int()
1284 i = Int()
1285 j = Int()
1285 j = Int()
1286
1286
1287 def _i_default(self):
1287 def _i_default(self):
1288 return 1
1288 return 1
1289
1289
1290 def _i_changed(self, name, old, new):
1290 def _i_changed(self, name, old, new):
1291 self.j = new
1291 self.j = new
1292
1292
1293 def test_pickle_hastraits():
1293 def test_pickle_hastraits():
1294 c = Pickleable()
1294 c = Pickleable()
1295 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1295 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1296 p = pickle.dumps(c, protocol)
1296 p = pickle.dumps(c, protocol)
1297 c2 = pickle.loads(p)
1297 c2 = pickle.loads(p)
1298 nt.assert_equal(c2.i, c.i)
1298 nt.assert_equal(c2.i, c.i)
1299 nt.assert_equal(c2.j, c.j)
1299 nt.assert_equal(c2.j, c.j)
1300
1300
1301 c.i = 5
1301 c.i = 5
1302 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1302 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1303 p = pickle.dumps(c, protocol)
1303 p = pickle.dumps(c, protocol)
1304 c2 = pickle.loads(p)
1304 c2 = pickle.loads(p)
1305 nt.assert_equal(c2.i, c.i)
1305 nt.assert_equal(c2.i, c.i)
1306 nt.assert_equal(c2.j, c.j)
1306 nt.assert_equal(c2.j, c.j)
1307
1307
1308 class TestEventful(TestCase):
1308 class TestEventful(TestCase):
1309
1309
1310 def test_list(self):
1310 def test_list(self):
1311 """Does the EventfulList work?"""
1311 """Does the EventfulList work?"""
1312 event_cache = []
1312 event_cache = []
1313
1313
1314 class A(HasTraits):
1314 class A(HasTraits):
1315 x = EventfulList([c for c in 'abc'])
1315 x = EventfulList([c for c in 'abc'])
1316 a = A()
1316 a = A()
1317 a.x.on_events(lambda i, x: event_cache.append('insert'), \
1317 a.x.on_events(lambda i, x: event_cache.append('insert'), \
1318 lambda i, x: event_cache.append('set'), \
1318 lambda i, x: event_cache.append('set'), \
1319 lambda i: event_cache.append('del'), \
1319 lambda i: event_cache.append('del'), \
1320 lambda: event_cache.append('reverse'), \
1320 lambda: event_cache.append('reverse'), \
1321 lambda *p, **k: event_cache.append('sort'))
1321 lambda *p, **k: event_cache.append('sort'))
1322
1322
1323 a.x.remove('c')
1323 a.x.remove('c')
1324 # ab
1324 # ab
1325 a.x.insert(0, 'z')
1325 a.x.insert(0, 'z')
1326 # zab
1326 # zab
1327 del a.x[1]
1327 del a.x[1]
1328 # zb
1328 # zb
1329 a.x.reverse()
1329 a.x.reverse()
1330 # bz
1330 # bz
1331 a.x[1] = 'o'
1331 a.x[1] = 'o'
1332 # bo
1332 # bo
1333 a.x.append('a')
1333 a.x.append('a')
1334 # boa
1334 # boa
1335 a.x.sort()
1335 a.x.sort()
1336 # abo
1336 # abo
1337
1337
1338 # Were the correct events captured?
1338 # Were the correct events captured?
1339 self.assertEqual(event_cache, ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort'])
1339 self.assertEqual(event_cache, ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort'])
1340
1340
1341 # Is the output correct?
1341 # Is the output correct?
1342 self.assertEqual(a.x, [c for c in 'abo'])
1342 self.assertEqual(a.x, [c for c in 'abo'])
1343
1343
1344 def test_dict(self):
1344 def test_dict(self):
1345 """Does the EventfulDict work?"""
1345 """Does the EventfulDict work?"""
1346 event_cache = []
1346 event_cache = []
1347
1347
1348 class A(HasTraits):
1348 class A(HasTraits):
1349 x = EventfulDict({c: c for c in 'abc'})
1349 x = EventfulDict({c: c for c in 'abc'})
1350 a = A()
1350 a = A()
1351 a.x.on_events(lambda k, v: event_cache.append('add'), \
1351 a.x.on_events(lambda k, v: event_cache.append('add'), \
1352 lambda k, v: event_cache.append('set'), \
1352 lambda k, v: event_cache.append('set'), \
1353 lambda k: event_cache.append('del'))
1353 lambda k: event_cache.append('del'))
1354
1354
1355 del a.x['c']
1355 del a.x['c']
1356 # ab
1356 # ab
1357 a.x['z'] = 1
1357 a.x['z'] = 1
1358 # abz
1358 # abz
1359 a.x['z'] = 'z'
1359 a.x['z'] = 'z'
1360 # abz
1360 # abz
1361 a.x.pop('a')
1361 a.x.pop('a')
1362 # bz
1362 # bz
1363
1363
1364 # Were the correct events captured?
1364 # Were the correct events captured?
1365 self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])
1365 self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])
1366
1366
1367 # Is the output correct?
1367 # Is the output correct?
1368 self.assertEqual(a.x, {c: c for c in 'bz'})
1368 self.assertEqual(a.x, {c: c for c in 'bz'})
1369
1369
1370 ###
1370 ###
1371 # Traits for Forward Declaration Tests
1371 # Traits for Forward Declaration Tests
1372 ###
1372 ###
1373 class ForwardDeclaredInstanceTrait(HasTraits):
1373 class ForwardDeclaredInstanceTrait(HasTraits):
1374
1374
1375 value = ForwardDeclaredInstance('ForwardDeclaredBar')
1375 value = ForwardDeclaredInstance('ForwardDeclaredBar')
1376
1376
1377 class ForwardDeclaredTypeTrait(HasTraits):
1377 class ForwardDeclaredTypeTrait(HasTraits):
1378
1378
1379 value = ForwardDeclaredType('ForwardDeclaredBar')
1379 value = ForwardDeclaredType('ForwardDeclaredBar')
1380
1380
1381 class ForwardDeclaredInstanceListTrait(HasTraits):
1381 class ForwardDeclaredInstanceListTrait(HasTraits):
1382
1382
1383 value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
1383 value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
1384
1384
1385 class ForwardDeclaredTypeListTrait(HasTraits):
1385 class ForwardDeclaredTypeListTrait(HasTraits):
1386
1386
1387 value = List(ForwardDeclaredType('ForwardDeclaredBar'))
1387 value = List(ForwardDeclaredType('ForwardDeclaredBar'))
1388 ###
1388 ###
1389 # End Traits for Forward Declaration Tests
1389 # End Traits for Forward Declaration Tests
1390 ###
1390 ###
1391
1391
1392 ###
1392 ###
1393 # Classes for Forward Declaration Tests
1393 # Classes for Forward Declaration Tests
1394 ###
1394 ###
1395 class ForwardDeclaredBar(object):
1395 class ForwardDeclaredBar(object):
1396 pass
1396 pass
1397
1397
1398 class ForwardDeclaredBarSub(ForwardDeclaredBar):
1398 class ForwardDeclaredBarSub(ForwardDeclaredBar):
1399 pass
1399 pass
1400 ###
1400 ###
1401 # End Classes for Forward Declaration Tests
1401 # End Classes for Forward Declaration Tests
1402 ###
1402 ###
1403
1403
1404 ###
1404 ###
1405 # Forward Declaration Tests
1405 # Forward Declaration Tests
1406 ###
1406 ###
1407 class TestForwardDeclaredInstanceTrait(TraitTestBase):
1407 class TestForwardDeclaredInstanceTrait(TraitTestBase):
1408
1408
1409 obj = ForwardDeclaredInstanceTrait()
1409 obj = ForwardDeclaredInstanceTrait()
1410 _default_value = None
1410 _default_value = None
1411 _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1411 _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1412 _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
1412 _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
1413
1413
1414 class TestForwardDeclaredTypeTrait(TraitTestBase):
1414 class TestForwardDeclaredTypeTrait(TraitTestBase):
1415
1415
1416 obj = ForwardDeclaredTypeTrait()
1416 obj = ForwardDeclaredTypeTrait()
1417 _default_value = None
1417 _default_value = None
1418 _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
1418 _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
1419 _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1419 _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1420
1420
1421 class TestForwardDeclaredInstanceList(TraitTestBase):
1421 class TestForwardDeclaredInstanceList(TraitTestBase):
1422
1422
1423 obj = ForwardDeclaredInstanceListTrait()
1423 obj = ForwardDeclaredInstanceListTrait()
1424
1424
1425 def test_klass(self):
1425 def test_klass(self):
1426 """Test that the instance klass is properly assigned."""
1426 """Test that the instance klass is properly assigned."""
1427 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1427 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1428
1428
1429 _default_value = []
1429 _default_value = []
1430 _good_values = [
1430 _good_values = [
1431 [ForwardDeclaredBar(), ForwardDeclaredBarSub(), None],
1431 [ForwardDeclaredBar(), ForwardDeclaredBarSub(), None],
1432 [None],
1432 [None],
1433 [],
1433 [],
1434 None,
1435 ]
1434 ]
1436 _bad_values = [
1435 _bad_values = [
1437 ForwardDeclaredBar(),
1436 ForwardDeclaredBar(),
1438 [ForwardDeclaredBar(), 3],
1437 [ForwardDeclaredBar(), 3],
1439 '1',
1438 '1',
1440 # Note that this is the type, not an instance.
1439 # Note that this is the type, not an instance.
1441 [ForwardDeclaredBar]
1440 [ForwardDeclaredBar],
1441 None,
1442 ]
1442 ]
1443
1443
1444 class TestForwardDeclaredTypeList(TraitTestBase):
1444 class TestForwardDeclaredTypeList(TraitTestBase):
1445
1445
1446 obj = ForwardDeclaredTypeListTrait()
1446 obj = ForwardDeclaredTypeListTrait()
1447
1447
1448 def test_klass(self):
1448 def test_klass(self):
1449 """Test that the instance klass is properly assigned."""
1449 """Test that the instance klass is properly assigned."""
1450 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1450 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1451
1451
1452 _default_value = []
1452 _default_value = []
1453 _good_values = [
1453 _good_values = [
1454 [ForwardDeclaredBar, ForwardDeclaredBarSub, None],
1454 [ForwardDeclaredBar, ForwardDeclaredBarSub, None],
1455 [],
1455 [],
1456 [None],
1456 [None],
1457 None,
1458 ]
1457 ]
1459 _bad_values = [
1458 _bad_values = [
1460 ForwardDeclaredBar,
1459 ForwardDeclaredBar,
1461 [ForwardDeclaredBar, 3],
1460 [ForwardDeclaredBar, 3],
1462 '1',
1461 '1',
1463 # Note that this is an instance, not the type.
1462 # Note that this is an instance, not the type.
1464 [ForwardDeclaredBar()]
1463 [ForwardDeclaredBar()],
1464 None,
1465 ]
1465 ]
1466 ###
1466 ###
1467 # End Forward Declaration Tests
1467 # End Forward Declaration Tests
1468 ###
1468 ###
@@ -1,1734 +1,1732 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Inheritance diagram:
31 Inheritance diagram:
32
32
33 .. inheritance-diagram:: IPython.utils.traitlets
33 .. inheritance-diagram:: IPython.utils.traitlets
34 :parts: 3
34 :parts: 3
35 """
35 """
36
36
37 # Copyright (c) IPython Development Team.
37 # Copyright (c) IPython Development Team.
38 # Distributed under the terms of the Modified BSD License.
38 # Distributed under the terms of the Modified BSD License.
39 #
39 #
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 # also under the terms of the Modified BSD License.
41 # also under the terms of the Modified BSD License.
42
42
43 import contextlib
43 import contextlib
44 import inspect
44 import inspect
45 import re
45 import re
46 import sys
46 import sys
47 import types
47 import types
48 from types import FunctionType
48 from types import FunctionType
49 try:
49 try:
50 from types import ClassType, InstanceType
50 from types import ClassType, InstanceType
51 ClassTypes = (ClassType, type)
51 ClassTypes = (ClassType, type)
52 except:
52 except:
53 ClassTypes = (type,)
53 ClassTypes = (type,)
54
54
55 from .importstring import import_item
55 from .importstring import import_item
56 from IPython.utils import py3compat
56 from IPython.utils import py3compat
57 from IPython.utils import eventful
57 from IPython.utils import eventful
58 from IPython.utils.py3compat import iteritems, string_types
58 from IPython.utils.py3compat import iteritems, string_types
59 from IPython.testing.skipdoctest import skip_doctest
59 from IPython.testing.skipdoctest import skip_doctest
60
60
61 SequenceTypes = (list, tuple, set, frozenset)
61 SequenceTypes = (list, tuple, set, frozenset)
62
62
63 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
64 # Basic classes
64 # Basic classes
65 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
66
66
67
67
68 class NoDefaultSpecified ( object ): pass
68 class NoDefaultSpecified ( object ): pass
69 NoDefaultSpecified = NoDefaultSpecified()
69 NoDefaultSpecified = NoDefaultSpecified()
70
70
71
71
72 class Undefined ( object ): pass
72 class Undefined ( object ): pass
73 Undefined = Undefined()
73 Undefined = Undefined()
74
74
75 class TraitError(Exception):
75 class TraitError(Exception):
76 pass
76 pass
77
77
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # Utilities
79 # Utilities
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
82
82
83 def class_of ( object ):
83 def class_of ( object ):
84 """ Returns a string containing the class name of an object with the
84 """ Returns a string containing the class name of an object with the
85 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
85 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
86 'a PlotValue').
86 'a PlotValue').
87 """
87 """
88 if isinstance( object, py3compat.string_types ):
88 if isinstance( object, py3compat.string_types ):
89 return add_article( object )
89 return add_article( object )
90
90
91 return add_article( object.__class__.__name__ )
91 return add_article( object.__class__.__name__ )
92
92
93
93
94 def add_article ( name ):
94 def add_article ( name ):
95 """ Returns a string containing the correct indefinite article ('a' or 'an')
95 """ Returns a string containing the correct indefinite article ('a' or 'an')
96 prefixed to the specified string.
96 prefixed to the specified string.
97 """
97 """
98 if name[:1].lower() in 'aeiou':
98 if name[:1].lower() in 'aeiou':
99 return 'an ' + name
99 return 'an ' + name
100
100
101 return 'a ' + name
101 return 'a ' + name
102
102
103
103
104 def repr_type(obj):
104 def repr_type(obj):
105 """ Return a string representation of a value and its type for readable
105 """ Return a string representation of a value and its type for readable
106 error messages.
106 error messages.
107 """
107 """
108 the_type = type(obj)
108 the_type = type(obj)
109 if (not py3compat.PY3) and the_type is InstanceType:
109 if (not py3compat.PY3) and the_type is InstanceType:
110 # Old-style class.
110 # Old-style class.
111 the_type = obj.__class__
111 the_type = obj.__class__
112 msg = '%r %r' % (obj, the_type)
112 msg = '%r %r' % (obj, the_type)
113 return msg
113 return msg
114
114
115
115
116 def is_trait(t):
116 def is_trait(t):
117 """ Returns whether the given value is an instance or subclass of TraitType.
117 """ Returns whether the given value is an instance or subclass of TraitType.
118 """
118 """
119 return (isinstance(t, TraitType) or
119 return (isinstance(t, TraitType) or
120 (isinstance(t, type) and issubclass(t, TraitType)))
120 (isinstance(t, type) and issubclass(t, TraitType)))
121
121
122
122
123 def parse_notifier_name(name):
123 def parse_notifier_name(name):
124 """Convert the name argument to a list of names.
124 """Convert the name argument to a list of names.
125
125
126 Examples
126 Examples
127 --------
127 --------
128
128
129 >>> parse_notifier_name('a')
129 >>> parse_notifier_name('a')
130 ['a']
130 ['a']
131 >>> parse_notifier_name(['a','b'])
131 >>> parse_notifier_name(['a','b'])
132 ['a', 'b']
132 ['a', 'b']
133 >>> parse_notifier_name(None)
133 >>> parse_notifier_name(None)
134 ['anytrait']
134 ['anytrait']
135 """
135 """
136 if isinstance(name, string_types):
136 if isinstance(name, string_types):
137 return [name]
137 return [name]
138 elif name is None:
138 elif name is None:
139 return ['anytrait']
139 return ['anytrait']
140 elif isinstance(name, (list, tuple)):
140 elif isinstance(name, (list, tuple)):
141 for n in name:
141 for n in name:
142 assert isinstance(n, string_types), "names must be strings"
142 assert isinstance(n, string_types), "names must be strings"
143 return name
143 return name
144
144
145
145
146 class _SimpleTest:
146 class _SimpleTest:
147 def __init__ ( self, value ): self.value = value
147 def __init__ ( self, value ): self.value = value
148 def __call__ ( self, test ):
148 def __call__ ( self, test ):
149 return test == self.value
149 return test == self.value
150 def __repr__(self):
150 def __repr__(self):
151 return "<SimpleTest(%r)" % self.value
151 return "<SimpleTest(%r)" % self.value
152 def __str__(self):
152 def __str__(self):
153 return self.__repr__()
153 return self.__repr__()
154
154
155
155
156 def getmembers(object, predicate=None):
156 def getmembers(object, predicate=None):
157 """A safe version of inspect.getmembers that handles missing attributes.
157 """A safe version of inspect.getmembers that handles missing attributes.
158
158
159 This is useful when there are descriptor based attributes that for
159 This is useful when there are descriptor based attributes that for
160 some reason raise AttributeError even though they exist. This happens
160 some reason raise AttributeError even though they exist. This happens
161 in zope.inteface with the __provides__ attribute.
161 in zope.inteface with the __provides__ attribute.
162 """
162 """
163 results = []
163 results = []
164 for key in dir(object):
164 for key in dir(object):
165 try:
165 try:
166 value = getattr(object, key)
166 value = getattr(object, key)
167 except AttributeError:
167 except AttributeError:
168 pass
168 pass
169 else:
169 else:
170 if not predicate or predicate(value):
170 if not predicate or predicate(value):
171 results.append((key, value))
171 results.append((key, value))
172 results.sort()
172 results.sort()
173 return results
173 return results
174
174
175 def _validate_link(*tuples):
175 def _validate_link(*tuples):
176 """Validate arguments for traitlet link functions"""
176 """Validate arguments for traitlet link functions"""
177 for t in tuples:
177 for t in tuples:
178 if not len(t) == 2:
178 if not len(t) == 2:
179 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
179 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
180 obj, trait_name = t
180 obj, trait_name = t
181 if not isinstance(obj, HasTraits):
181 if not isinstance(obj, HasTraits):
182 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
182 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
183 if not trait_name in obj.traits():
183 if not trait_name in obj.traits():
184 raise TypeError("%r has no trait %r" % (obj, trait_name))
184 raise TypeError("%r has no trait %r" % (obj, trait_name))
185
185
186 @skip_doctest
186 @skip_doctest
187 class link(object):
187 class link(object):
188 """Link traits from different objects together so they remain in sync.
188 """Link traits from different objects together so they remain in sync.
189
189
190 Parameters
190 Parameters
191 ----------
191 ----------
192 *args : pairs of objects/attributes
192 *args : pairs of objects/attributes
193
193
194 Examples
194 Examples
195 --------
195 --------
196
196
197 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
197 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
198 >>> obj1.value = 5 # updates other objects as well
198 >>> obj1.value = 5 # updates other objects as well
199 """
199 """
200 updating = False
200 updating = False
201 def __init__(self, *args):
201 def __init__(self, *args):
202 if len(args) < 2:
202 if len(args) < 2:
203 raise TypeError('At least two traitlets must be provided.')
203 raise TypeError('At least two traitlets must be provided.')
204 _validate_link(*args)
204 _validate_link(*args)
205
205
206 self.objects = {}
206 self.objects = {}
207
207
208 initial = getattr(args[0][0], args[0][1])
208 initial = getattr(args[0][0], args[0][1])
209 for obj, attr in args:
209 for obj, attr in args:
210 setattr(obj, attr, initial)
210 setattr(obj, attr, initial)
211
211
212 callback = self._make_closure(obj, attr)
212 callback = self._make_closure(obj, attr)
213 obj.on_trait_change(callback, attr)
213 obj.on_trait_change(callback, attr)
214 self.objects[(obj, attr)] = callback
214 self.objects[(obj, attr)] = callback
215
215
216 @contextlib.contextmanager
216 @contextlib.contextmanager
217 def _busy_updating(self):
217 def _busy_updating(self):
218 self.updating = True
218 self.updating = True
219 try:
219 try:
220 yield
220 yield
221 finally:
221 finally:
222 self.updating = False
222 self.updating = False
223
223
224 def _make_closure(self, sending_obj, sending_attr):
224 def _make_closure(self, sending_obj, sending_attr):
225 def update(name, old, new):
225 def update(name, old, new):
226 self._update(sending_obj, sending_attr, new)
226 self._update(sending_obj, sending_attr, new)
227 return update
227 return update
228
228
229 def _update(self, sending_obj, sending_attr, new):
229 def _update(self, sending_obj, sending_attr, new):
230 if self.updating:
230 if self.updating:
231 return
231 return
232 with self._busy_updating():
232 with self._busy_updating():
233 for obj, attr in self.objects.keys():
233 for obj, attr in self.objects.keys():
234 setattr(obj, attr, new)
234 setattr(obj, attr, new)
235
235
236 def unlink(self):
236 def unlink(self):
237 for key, callback in self.objects.items():
237 for key, callback in self.objects.items():
238 (obj, attr) = key
238 (obj, attr) = key
239 obj.on_trait_change(callback, attr, remove=True)
239 obj.on_trait_change(callback, attr, remove=True)
240
240
241 @skip_doctest
241 @skip_doctest
242 class directional_link(object):
242 class directional_link(object):
243 """Link the trait of a source object with traits of target objects.
243 """Link the trait of a source object with traits of target objects.
244
244
245 Parameters
245 Parameters
246 ----------
246 ----------
247 source : pair of object, name
247 source : pair of object, name
248 targets : pairs of objects/attributes
248 targets : pairs of objects/attributes
249
249
250 Examples
250 Examples
251 --------
251 --------
252
252
253 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
253 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
254 >>> src.value = 5 # updates target objects
254 >>> src.value = 5 # updates target objects
255 >>> tgt1.value = 6 # does not update other objects
255 >>> tgt1.value = 6 # does not update other objects
256 """
256 """
257 updating = False
257 updating = False
258
258
259 def __init__(self, source, *targets):
259 def __init__(self, source, *targets):
260 if len(targets) < 1:
260 if len(targets) < 1:
261 raise TypeError('At least two traitlets must be provided.')
261 raise TypeError('At least two traitlets must be provided.')
262 _validate_link(source, *targets)
262 _validate_link(source, *targets)
263 self.source = source
263 self.source = source
264 self.targets = targets
264 self.targets = targets
265
265
266 # Update current value
266 # Update current value
267 src_attr_value = getattr(source[0], source[1])
267 src_attr_value = getattr(source[0], source[1])
268 for obj, attr in targets:
268 for obj, attr in targets:
269 setattr(obj, attr, src_attr_value)
269 setattr(obj, attr, src_attr_value)
270
270
271 # Wire
271 # Wire
272 self.source[0].on_trait_change(self._update, self.source[1])
272 self.source[0].on_trait_change(self._update, self.source[1])
273
273
274 @contextlib.contextmanager
274 @contextlib.contextmanager
275 def _busy_updating(self):
275 def _busy_updating(self):
276 self.updating = True
276 self.updating = True
277 try:
277 try:
278 yield
278 yield
279 finally:
279 finally:
280 self.updating = False
280 self.updating = False
281
281
282 def _update(self, name, old, new):
282 def _update(self, name, old, new):
283 if self.updating:
283 if self.updating:
284 return
284 return
285 with self._busy_updating():
285 with self._busy_updating():
286 for obj, attr in self.targets:
286 for obj, attr in self.targets:
287 setattr(obj, attr, new)
287 setattr(obj, attr, new)
288
288
289 def unlink(self):
289 def unlink(self):
290 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
290 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
291 self.source = None
291 self.source = None
292 self.targets = []
292 self.targets = []
293
293
294 dlink = directional_link
294 dlink = directional_link
295
295
296 #-----------------------------------------------------------------------------
296 #-----------------------------------------------------------------------------
297 # Base TraitType for all traits
297 # Base TraitType for all traits
298 #-----------------------------------------------------------------------------
298 #-----------------------------------------------------------------------------
299
299
300
300
301 class TraitType(object):
301 class TraitType(object):
302 """A base class for all trait descriptors.
302 """A base class for all trait descriptors.
303
303
304 Notes
304 Notes
305 -----
305 -----
306 Our implementation of traits is based on Python's descriptor
306 Our implementation of traits is based on Python's descriptor
307 prototol. This class is the base class for all such descriptors. The
307 prototol. This class is the base class for all such descriptors. The
308 only magic we use is a custom metaclass for the main :class:`HasTraits`
308 only magic we use is a custom metaclass for the main :class:`HasTraits`
309 class that does the following:
309 class that does the following:
310
310
311 1. Sets the :attr:`name` attribute of every :class:`TraitType`
311 1. Sets the :attr:`name` attribute of every :class:`TraitType`
312 instance in the class dict to the name of the attribute.
312 instance in the class dict to the name of the attribute.
313 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
313 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
314 instance in the class dict to the *class* that declared the trait.
314 instance in the class dict to the *class* that declared the trait.
315 This is used by the :class:`This` trait to allow subclasses to
315 This is used by the :class:`This` trait to allow subclasses to
316 accept superclasses for :class:`This` values.
316 accept superclasses for :class:`This` values.
317 """
317 """
318
318
319
319
320 metadata = {}
320 metadata = {}
321 default_value = Undefined
321 default_value = Undefined
322 allow_none = False
322 allow_none = False
323 info_text = 'any value'
323 info_text = 'any value'
324
324
325 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
325 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
326 """Create a TraitType.
326 """Create a TraitType.
327 """
327 """
328 if default_value is not NoDefaultSpecified:
328 if default_value is not NoDefaultSpecified:
329 self.default_value = default_value
329 self.default_value = default_value
330 if allow_none is not None:
330 if allow_none is not None:
331 self.allow_none = allow_none
331 self.allow_none = allow_none
332
332
333 if len(metadata) > 0:
333 if len(metadata) > 0:
334 if len(self.metadata) > 0:
334 if len(self.metadata) > 0:
335 self._metadata = self.metadata.copy()
335 self._metadata = self.metadata.copy()
336 self._metadata.update(metadata)
336 self._metadata.update(metadata)
337 else:
337 else:
338 self._metadata = metadata
338 self._metadata = metadata
339 else:
339 else:
340 self._metadata = self.metadata
340 self._metadata = self.metadata
341
341
342 self.init()
342 self.init()
343
343
344 def init(self):
344 def init(self):
345 pass
345 pass
346
346
347 def get_default_value(self):
347 def get_default_value(self):
348 """Create a new instance of the default value."""
348 """Create a new instance of the default value."""
349 return self.default_value
349 return self.default_value
350
350
351 def instance_init(self, obj):
351 def instance_init(self, obj):
352 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
352 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
353
353
354 Some stages of initialization must be delayed until the parent
354 Some stages of initialization must be delayed until the parent
355 :class:`HasTraits` instance has been created. This method is
355 :class:`HasTraits` instance has been created. This method is
356 called in :meth:`HasTraits.__new__` after the instance has been
356 called in :meth:`HasTraits.__new__` after the instance has been
357 created.
357 created.
358
358
359 This method trigger the creation and validation of default values
359 This method trigger the creation and validation of default values
360 and also things like the resolution of str given class names in
360 and also things like the resolution of str given class names in
361 :class:`Type` and :class`Instance`.
361 :class:`Type` and :class`Instance`.
362
362
363 Parameters
363 Parameters
364 ----------
364 ----------
365 obj : :class:`HasTraits` instance
365 obj : :class:`HasTraits` instance
366 The parent :class:`HasTraits` instance that has just been
366 The parent :class:`HasTraits` instance that has just been
367 created.
367 created.
368 """
368 """
369 self.set_default_value(obj)
369 self.set_default_value(obj)
370
370
371 def set_default_value(self, obj):
371 def set_default_value(self, obj):
372 """Set the default value on a per instance basis.
372 """Set the default value on a per instance basis.
373
373
374 This method is called by :meth:`instance_init` to create and
374 This method is called by :meth:`instance_init` to create and
375 validate the default value. The creation and validation of
375 validate the default value. The creation and validation of
376 default values must be delayed until the parent :class:`HasTraits`
376 default values must be delayed until the parent :class:`HasTraits`
377 class has been instantiated.
377 class has been instantiated.
378 """
378 """
379 # Check for a deferred initializer defined in the same class as the
379 # Check for a deferred initializer defined in the same class as the
380 # trait declaration or above.
380 # trait declaration or above.
381 mro = type(obj).mro()
381 mro = type(obj).mro()
382 meth_name = '_%s_default' % self.name
382 meth_name = '_%s_default' % self.name
383 for cls in mro[:mro.index(self.this_class)+1]:
383 for cls in mro[:mro.index(self.this_class)+1]:
384 if meth_name in cls.__dict__:
384 if meth_name in cls.__dict__:
385 break
385 break
386 else:
386 else:
387 # We didn't find one. Do static initialization.
387 # We didn't find one. Do static initialization.
388 dv = self.get_default_value()
388 dv = self.get_default_value()
389 newdv = self._validate(obj, dv)
389 newdv = self._validate(obj, dv)
390 obj._trait_values[self.name] = newdv
390 obj._trait_values[self.name] = newdv
391 return
391 return
392 # Complete the dynamic initialization.
392 # Complete the dynamic initialization.
393 obj._trait_dyn_inits[self.name] = meth_name
393 obj._trait_dyn_inits[self.name] = meth_name
394
394
395 def __get__(self, obj, cls=None):
395 def __get__(self, obj, cls=None):
396 """Get the value of the trait by self.name for the instance.
396 """Get the value of the trait by self.name for the instance.
397
397
398 Default values are instantiated when :meth:`HasTraits.__new__`
398 Default values are instantiated when :meth:`HasTraits.__new__`
399 is called. Thus by the time this method gets called either the
399 is called. Thus by the time this method gets called either the
400 default value or a user defined value (they called :meth:`__set__`)
400 default value or a user defined value (they called :meth:`__set__`)
401 is in the :class:`HasTraits` instance.
401 is in the :class:`HasTraits` instance.
402 """
402 """
403 if obj is None:
403 if obj is None:
404 return self
404 return self
405 else:
405 else:
406 try:
406 try:
407 value = obj._trait_values[self.name]
407 value = obj._trait_values[self.name]
408 except KeyError:
408 except KeyError:
409 # Check for a dynamic initializer.
409 # Check for a dynamic initializer.
410 if self.name in obj._trait_dyn_inits:
410 if self.name in obj._trait_dyn_inits:
411 method = getattr(obj, obj._trait_dyn_inits[self.name])
411 method = getattr(obj, obj._trait_dyn_inits[self.name])
412 value = method()
412 value = method()
413 # FIXME: Do we really validate here?
413 # FIXME: Do we really validate here?
414 value = self._validate(obj, value)
414 value = self._validate(obj, value)
415 obj._trait_values[self.name] = value
415 obj._trait_values[self.name] = value
416 return value
416 return value
417 else:
417 else:
418 raise TraitError('Unexpected error in TraitType: '
418 raise TraitError('Unexpected error in TraitType: '
419 'both default value and dynamic initializer are '
419 'both default value and dynamic initializer are '
420 'absent.')
420 'absent.')
421 except Exception:
421 except Exception:
422 # HasTraits should call set_default_value to populate
422 # HasTraits should call set_default_value to populate
423 # this. So this should never be reached.
423 # this. So this should never be reached.
424 raise TraitError('Unexpected error in TraitType: '
424 raise TraitError('Unexpected error in TraitType: '
425 'default value not set properly')
425 'default value not set properly')
426 else:
426 else:
427 return value
427 return value
428
428
429 def __set__(self, obj, value):
429 def __set__(self, obj, value):
430 new_value = self._validate(obj, value)
430 new_value = self._validate(obj, value)
431 try:
431 try:
432 old_value = obj._trait_values[self.name]
432 old_value = obj._trait_values[self.name]
433 except KeyError:
433 except KeyError:
434 old_value = None
434 old_value = None
435
435
436 obj._trait_values[self.name] = new_value
436 obj._trait_values[self.name] = new_value
437 try:
437 try:
438 silent = bool(old_value == new_value)
438 silent = bool(old_value == new_value)
439 except:
439 except:
440 # if there is an error in comparing, default to notify
440 # if there is an error in comparing, default to notify
441 silent = False
441 silent = False
442 if silent is not True:
442 if silent is not True:
443 # we explicitly compare silent to True just in case the equality
443 # we explicitly compare silent to True just in case the equality
444 # comparison above returns something other than True/False
444 # comparison above returns something other than True/False
445 obj._notify_trait(self.name, old_value, new_value)
445 obj._notify_trait(self.name, old_value, new_value)
446
446
447 def _validate(self, obj, value):
447 def _validate(self, obj, value):
448 if value is None and self.allow_none:
448 if value is None and self.allow_none:
449 return value
449 return value
450 if hasattr(self, 'validate'):
450 if hasattr(self, 'validate'):
451 value = self.validate(obj, value)
451 value = self.validate(obj, value)
452 if hasattr(obj, '_%s_validate' % self.name):
452 if hasattr(obj, '_%s_validate' % self.name):
453 value = getattr(obj, '_%s_validate' % self.name)(value, self)
453 value = getattr(obj, '_%s_validate' % self.name)(value, self)
454 return value
454 return value
455
455
456 def __or__(self, other):
456 def __or__(self, other):
457 if isinstance(other, Union):
457 if isinstance(other, Union):
458 return Union([self] + other.trait_types)
458 return Union([self] + other.trait_types)
459 else:
459 else:
460 return Union([self, other])
460 return Union([self, other])
461
461
462 def info(self):
462 def info(self):
463 return self.info_text
463 return self.info_text
464
464
465 def error(self, obj, value):
465 def error(self, obj, value):
466 if obj is not None:
466 if obj is not None:
467 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
467 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
468 % (self.name, class_of(obj),
468 % (self.name, class_of(obj),
469 self.info(), repr_type(value))
469 self.info(), repr_type(value))
470 else:
470 else:
471 e = "The '%s' trait must be %s, but a value of %r was specified." \
471 e = "The '%s' trait must be %s, but a value of %r was specified." \
472 % (self.name, self.info(), repr_type(value))
472 % (self.name, self.info(), repr_type(value))
473 raise TraitError(e)
473 raise TraitError(e)
474
474
475 def get_metadata(self, key, default=None):
475 def get_metadata(self, key, default=None):
476 return getattr(self, '_metadata', {}).get(key, default)
476 return getattr(self, '_metadata', {}).get(key, default)
477
477
478 def set_metadata(self, key, value):
478 def set_metadata(self, key, value):
479 getattr(self, '_metadata', {})[key] = value
479 getattr(self, '_metadata', {})[key] = value
480
480
481
481
482 #-----------------------------------------------------------------------------
482 #-----------------------------------------------------------------------------
483 # The HasTraits implementation
483 # The HasTraits implementation
484 #-----------------------------------------------------------------------------
484 #-----------------------------------------------------------------------------
485
485
486
486
487 class MetaHasTraits(type):
487 class MetaHasTraits(type):
488 """A metaclass for HasTraits.
488 """A metaclass for HasTraits.
489
489
490 This metaclass makes sure that any TraitType class attributes are
490 This metaclass makes sure that any TraitType class attributes are
491 instantiated and sets their name attribute.
491 instantiated and sets their name attribute.
492 """
492 """
493
493
494 def __new__(mcls, name, bases, classdict):
494 def __new__(mcls, name, bases, classdict):
495 """Create the HasTraits class.
495 """Create the HasTraits class.
496
496
497 This instantiates all TraitTypes in the class dict and sets their
497 This instantiates all TraitTypes in the class dict and sets their
498 :attr:`name` attribute.
498 :attr:`name` attribute.
499 """
499 """
500 # print "MetaHasTraitlets (mcls, name): ", mcls, name
500 # print "MetaHasTraitlets (mcls, name): ", mcls, name
501 # print "MetaHasTraitlets (bases): ", bases
501 # print "MetaHasTraitlets (bases): ", bases
502 # print "MetaHasTraitlets (classdict): ", classdict
502 # print "MetaHasTraitlets (classdict): ", classdict
503 for k,v in iteritems(classdict):
503 for k,v in iteritems(classdict):
504 if isinstance(v, TraitType):
504 if isinstance(v, TraitType):
505 v.name = k
505 v.name = k
506 elif inspect.isclass(v):
506 elif inspect.isclass(v):
507 if issubclass(v, TraitType):
507 if issubclass(v, TraitType):
508 vinst = v()
508 vinst = v()
509 vinst.name = k
509 vinst.name = k
510 classdict[k] = vinst
510 classdict[k] = vinst
511 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
511 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
512
512
513 def __init__(cls, name, bases, classdict):
513 def __init__(cls, name, bases, classdict):
514 """Finish initializing the HasTraits class.
514 """Finish initializing the HasTraits class.
515
515
516 This sets the :attr:`this_class` attribute of each TraitType in the
516 This sets the :attr:`this_class` attribute of each TraitType in the
517 class dict to the newly created class ``cls``.
517 class dict to the newly created class ``cls``.
518 """
518 """
519 for k, v in iteritems(classdict):
519 for k, v in iteritems(classdict):
520 if isinstance(v, TraitType):
520 if isinstance(v, TraitType):
521 v.this_class = cls
521 v.this_class = cls
522 super(MetaHasTraits, cls).__init__(name, bases, classdict)
522 super(MetaHasTraits, cls).__init__(name, bases, classdict)
523
523
524 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
524 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
525
525
526 def __new__(cls, *args, **kw):
526 def __new__(cls, *args, **kw):
527 # This is needed because object.__new__ only accepts
527 # This is needed because object.__new__ only accepts
528 # the cls argument.
528 # the cls argument.
529 new_meth = super(HasTraits, cls).__new__
529 new_meth = super(HasTraits, cls).__new__
530 if new_meth is object.__new__:
530 if new_meth is object.__new__:
531 inst = new_meth(cls)
531 inst = new_meth(cls)
532 else:
532 else:
533 inst = new_meth(cls, **kw)
533 inst = new_meth(cls, **kw)
534 inst._trait_values = {}
534 inst._trait_values = {}
535 inst._trait_notifiers = {}
535 inst._trait_notifiers = {}
536 inst._trait_dyn_inits = {}
536 inst._trait_dyn_inits = {}
537 # Here we tell all the TraitType instances to set their default
537 # Here we tell all the TraitType instances to set their default
538 # values on the instance.
538 # values on the instance.
539 for key in dir(cls):
539 for key in dir(cls):
540 # Some descriptors raise AttributeError like zope.interface's
540 # Some descriptors raise AttributeError like zope.interface's
541 # __provides__ attributes even though they exist. This causes
541 # __provides__ attributes even though they exist. This causes
542 # AttributeErrors even though they are listed in dir(cls).
542 # AttributeErrors even though they are listed in dir(cls).
543 try:
543 try:
544 value = getattr(cls, key)
544 value = getattr(cls, key)
545 except AttributeError:
545 except AttributeError:
546 pass
546 pass
547 else:
547 else:
548 if isinstance(value, TraitType):
548 if isinstance(value, TraitType):
549 value.instance_init(inst)
549 value.instance_init(inst)
550
550
551 return inst
551 return inst
552
552
553 def __init__(self, *args, **kw):
553 def __init__(self, *args, **kw):
554 # Allow trait values to be set using keyword arguments.
554 # Allow trait values to be set using keyword arguments.
555 # We need to use setattr for this to trigger validation and
555 # We need to use setattr for this to trigger validation and
556 # notifications.
556 # notifications.
557 for key, value in iteritems(kw):
557 for key, value in iteritems(kw):
558 setattr(self, key, value)
558 setattr(self, key, value)
559
559
560 def _notify_trait(self, name, old_value, new_value):
560 def _notify_trait(self, name, old_value, new_value):
561
561
562 # First dynamic ones
562 # First dynamic ones
563 callables = []
563 callables = []
564 callables.extend(self._trait_notifiers.get(name,[]))
564 callables.extend(self._trait_notifiers.get(name,[]))
565 callables.extend(self._trait_notifiers.get('anytrait',[]))
565 callables.extend(self._trait_notifiers.get('anytrait',[]))
566
566
567 # Now static ones
567 # Now static ones
568 try:
568 try:
569 cb = getattr(self, '_%s_changed' % name)
569 cb = getattr(self, '_%s_changed' % name)
570 except:
570 except:
571 pass
571 pass
572 else:
572 else:
573 callables.append(cb)
573 callables.append(cb)
574
574
575 # Call them all now
575 # Call them all now
576 for c in callables:
576 for c in callables:
577 # Traits catches and logs errors here. I allow them to raise
577 # Traits catches and logs errors here. I allow them to raise
578 if callable(c):
578 if callable(c):
579 argspec = inspect.getargspec(c)
579 argspec = inspect.getargspec(c)
580 nargs = len(argspec[0])
580 nargs = len(argspec[0])
581 # Bound methods have an additional 'self' argument
581 # Bound methods have an additional 'self' argument
582 # I don't know how to treat unbound methods, but they
582 # I don't know how to treat unbound methods, but they
583 # can't really be used for callbacks.
583 # can't really be used for callbacks.
584 if isinstance(c, types.MethodType):
584 if isinstance(c, types.MethodType):
585 offset = -1
585 offset = -1
586 else:
586 else:
587 offset = 0
587 offset = 0
588 if nargs + offset == 0:
588 if nargs + offset == 0:
589 c()
589 c()
590 elif nargs + offset == 1:
590 elif nargs + offset == 1:
591 c(name)
591 c(name)
592 elif nargs + offset == 2:
592 elif nargs + offset == 2:
593 c(name, new_value)
593 c(name, new_value)
594 elif nargs + offset == 3:
594 elif nargs + offset == 3:
595 c(name, old_value, new_value)
595 c(name, old_value, new_value)
596 else:
596 else:
597 raise TraitError('a trait changed callback '
597 raise TraitError('a trait changed callback '
598 'must have 0-3 arguments.')
598 'must have 0-3 arguments.')
599 else:
599 else:
600 raise TraitError('a trait changed callback '
600 raise TraitError('a trait changed callback '
601 'must be callable.')
601 'must be callable.')
602
602
603
603
604 def _add_notifiers(self, handler, name):
604 def _add_notifiers(self, handler, name):
605 if name not in self._trait_notifiers:
605 if name not in self._trait_notifiers:
606 nlist = []
606 nlist = []
607 self._trait_notifiers[name] = nlist
607 self._trait_notifiers[name] = nlist
608 else:
608 else:
609 nlist = self._trait_notifiers[name]
609 nlist = self._trait_notifiers[name]
610 if handler not in nlist:
610 if handler not in nlist:
611 nlist.append(handler)
611 nlist.append(handler)
612
612
613 def _remove_notifiers(self, handler, name):
613 def _remove_notifiers(self, handler, name):
614 if name in self._trait_notifiers:
614 if name in self._trait_notifiers:
615 nlist = self._trait_notifiers[name]
615 nlist = self._trait_notifiers[name]
616 try:
616 try:
617 index = nlist.index(handler)
617 index = nlist.index(handler)
618 except ValueError:
618 except ValueError:
619 pass
619 pass
620 else:
620 else:
621 del nlist[index]
621 del nlist[index]
622
622
623 def on_trait_change(self, handler, name=None, remove=False):
623 def on_trait_change(self, handler, name=None, remove=False):
624 """Setup a handler to be called when a trait changes.
624 """Setup a handler to be called when a trait changes.
625
625
626 This is used to setup dynamic notifications of trait changes.
626 This is used to setup dynamic notifications of trait changes.
627
627
628 Static handlers can be created by creating methods on a HasTraits
628 Static handlers can be created by creating methods on a HasTraits
629 subclass with the naming convention '_[traitname]_changed'. Thus,
629 subclass with the naming convention '_[traitname]_changed'. Thus,
630 to create static handler for the trait 'a', create the method
630 to create static handler for the trait 'a', create the method
631 _a_changed(self, name, old, new) (fewer arguments can be used, see
631 _a_changed(self, name, old, new) (fewer arguments can be used, see
632 below).
632 below).
633
633
634 Parameters
634 Parameters
635 ----------
635 ----------
636 handler : callable
636 handler : callable
637 A callable that is called when a trait changes. Its
637 A callable that is called when a trait changes. Its
638 signature can be handler(), handler(name), handler(name, new)
638 signature can be handler(), handler(name), handler(name, new)
639 or handler(name, old, new).
639 or handler(name, old, new).
640 name : list, str, None
640 name : list, str, None
641 If None, the handler will apply to all traits. If a list
641 If None, the handler will apply to all traits. If a list
642 of str, handler will apply to all names in the list. If a
642 of str, handler will apply to all names in the list. If a
643 str, the handler will apply just to that name.
643 str, the handler will apply just to that name.
644 remove : bool
644 remove : bool
645 If False (the default), then install the handler. If True
645 If False (the default), then install the handler. If True
646 then unintall it.
646 then unintall it.
647 """
647 """
648 if remove:
648 if remove:
649 names = parse_notifier_name(name)
649 names = parse_notifier_name(name)
650 for n in names:
650 for n in names:
651 self._remove_notifiers(handler, n)
651 self._remove_notifiers(handler, n)
652 else:
652 else:
653 names = parse_notifier_name(name)
653 names = parse_notifier_name(name)
654 for n in names:
654 for n in names:
655 self._add_notifiers(handler, n)
655 self._add_notifiers(handler, n)
656
656
657 @classmethod
657 @classmethod
658 def class_trait_names(cls, **metadata):
658 def class_trait_names(cls, **metadata):
659 """Get a list of all the names of this class' traits.
659 """Get a list of all the names of this class' traits.
660
660
661 This method is just like the :meth:`trait_names` method,
661 This method is just like the :meth:`trait_names` method,
662 but is unbound.
662 but is unbound.
663 """
663 """
664 return cls.class_traits(**metadata).keys()
664 return cls.class_traits(**metadata).keys()
665
665
666 @classmethod
666 @classmethod
667 def class_traits(cls, **metadata):
667 def class_traits(cls, **metadata):
668 """Get a `dict` of all the traits of this class. The dictionary
668 """Get a `dict` of all the traits of this class. The dictionary
669 is keyed on the name and the values are the TraitType objects.
669 is keyed on the name and the values are the TraitType objects.
670
670
671 This method is just like the :meth:`traits` method, but is unbound.
671 This method is just like the :meth:`traits` method, but is unbound.
672
672
673 The TraitTypes returned don't know anything about the values
673 The TraitTypes returned don't know anything about the values
674 that the various HasTrait's instances are holding.
674 that the various HasTrait's instances are holding.
675
675
676 The metadata kwargs allow functions to be passed in which
676 The metadata kwargs allow functions to be passed in which
677 filter traits based on metadata values. The functions should
677 filter traits based on metadata values. The functions should
678 take a single value as an argument and return a boolean. If
678 take a single value as an argument and return a boolean. If
679 any function returns False, then the trait is not included in
679 any function returns False, then the trait is not included in
680 the output. This does not allow for any simple way of
680 the output. This does not allow for any simple way of
681 testing that a metadata name exists and has any
681 testing that a metadata name exists and has any
682 value because get_metadata returns None if a metadata key
682 value because get_metadata returns None if a metadata key
683 doesn't exist.
683 doesn't exist.
684 """
684 """
685 traits = dict([memb for memb in getmembers(cls) if
685 traits = dict([memb for memb in getmembers(cls) if
686 isinstance(memb[1], TraitType)])
686 isinstance(memb[1], TraitType)])
687
687
688 if len(metadata) == 0:
688 if len(metadata) == 0:
689 return traits
689 return traits
690
690
691 for meta_name, meta_eval in metadata.items():
691 for meta_name, meta_eval in metadata.items():
692 if type(meta_eval) is not FunctionType:
692 if type(meta_eval) is not FunctionType:
693 metadata[meta_name] = _SimpleTest(meta_eval)
693 metadata[meta_name] = _SimpleTest(meta_eval)
694
694
695 result = {}
695 result = {}
696 for name, trait in traits.items():
696 for name, trait in traits.items():
697 for meta_name, meta_eval in metadata.items():
697 for meta_name, meta_eval in metadata.items():
698 if not meta_eval(trait.get_metadata(meta_name)):
698 if not meta_eval(trait.get_metadata(meta_name)):
699 break
699 break
700 else:
700 else:
701 result[name] = trait
701 result[name] = trait
702
702
703 return result
703 return result
704
704
705 def trait_names(self, **metadata):
705 def trait_names(self, **metadata):
706 """Get a list of all the names of this class' traits."""
706 """Get a list of all the names of this class' traits."""
707 return self.traits(**metadata).keys()
707 return self.traits(**metadata).keys()
708
708
709 def traits(self, **metadata):
709 def traits(self, **metadata):
710 """Get a `dict` of all the traits of this class. The dictionary
710 """Get a `dict` of all the traits of this class. The dictionary
711 is keyed on the name and the values are the TraitType objects.
711 is keyed on the name and the values are the TraitType objects.
712
712
713 The TraitTypes returned don't know anything about the values
713 The TraitTypes returned don't know anything about the values
714 that the various HasTrait's instances are holding.
714 that the various HasTrait's instances are holding.
715
715
716 The metadata kwargs allow functions to be passed in which
716 The metadata kwargs allow functions to be passed in which
717 filter traits based on metadata values. The functions should
717 filter traits based on metadata values. The functions should
718 take a single value as an argument and return a boolean. If
718 take a single value as an argument and return a boolean. If
719 any function returns False, then the trait is not included in
719 any function returns False, then the trait is not included in
720 the output. This does not allow for any simple way of
720 the output. This does not allow for any simple way of
721 testing that a metadata name exists and has any
721 testing that a metadata name exists and has any
722 value because get_metadata returns None if a metadata key
722 value because get_metadata returns None if a metadata key
723 doesn't exist.
723 doesn't exist.
724 """
724 """
725 traits = dict([memb for memb in getmembers(self.__class__) if
725 traits = dict([memb for memb in getmembers(self.__class__) if
726 isinstance(memb[1], TraitType)])
726 isinstance(memb[1], TraitType)])
727
727
728 if len(metadata) == 0:
728 if len(metadata) == 0:
729 return traits
729 return traits
730
730
731 for meta_name, meta_eval in metadata.items():
731 for meta_name, meta_eval in metadata.items():
732 if type(meta_eval) is not FunctionType:
732 if type(meta_eval) is not FunctionType:
733 metadata[meta_name] = _SimpleTest(meta_eval)
733 metadata[meta_name] = _SimpleTest(meta_eval)
734
734
735 result = {}
735 result = {}
736 for name, trait in traits.items():
736 for name, trait in traits.items():
737 for meta_name, meta_eval in metadata.items():
737 for meta_name, meta_eval in metadata.items():
738 if not meta_eval(trait.get_metadata(meta_name)):
738 if not meta_eval(trait.get_metadata(meta_name)):
739 break
739 break
740 else:
740 else:
741 result[name] = trait
741 result[name] = trait
742
742
743 return result
743 return result
744
744
745 def trait_metadata(self, traitname, key, default=None):
745 def trait_metadata(self, traitname, key, default=None):
746 """Get metadata values for trait by key."""
746 """Get metadata values for trait by key."""
747 try:
747 try:
748 trait = getattr(self.__class__, traitname)
748 trait = getattr(self.__class__, traitname)
749 except AttributeError:
749 except AttributeError:
750 raise TraitError("Class %s does not have a trait named %s" %
750 raise TraitError("Class %s does not have a trait named %s" %
751 (self.__class__.__name__, traitname))
751 (self.__class__.__name__, traitname))
752 else:
752 else:
753 return trait.get_metadata(key, default)
753 return trait.get_metadata(key, default)
754
754
755 #-----------------------------------------------------------------------------
755 #-----------------------------------------------------------------------------
756 # Actual TraitTypes implementations/subclasses
756 # Actual TraitTypes implementations/subclasses
757 #-----------------------------------------------------------------------------
757 #-----------------------------------------------------------------------------
758
758
759 #-----------------------------------------------------------------------------
759 #-----------------------------------------------------------------------------
760 # TraitTypes subclasses for handling classes and instances of classes
760 # TraitTypes subclasses for handling classes and instances of classes
761 #-----------------------------------------------------------------------------
761 #-----------------------------------------------------------------------------
762
762
763
763
764 class ClassBasedTraitType(TraitType):
764 class ClassBasedTraitType(TraitType):
765 """
765 """
766 A trait with error reporting and string -> type resolution for Type,
766 A trait with error reporting and string -> type resolution for Type,
767 Instance and This.
767 Instance and This.
768 """
768 """
769
769
770 def _resolve_string(self, string):
770 def _resolve_string(self, string):
771 """
771 """
772 Resolve a string supplied for a type into an actual object.
772 Resolve a string supplied for a type into an actual object.
773 """
773 """
774 return import_item(string)
774 return import_item(string)
775
775
776 def error(self, obj, value):
776 def error(self, obj, value):
777 kind = type(value)
777 kind = type(value)
778 if (not py3compat.PY3) and kind is InstanceType:
778 if (not py3compat.PY3) and kind is InstanceType:
779 msg = 'class %s' % value.__class__.__name__
779 msg = 'class %s' % value.__class__.__name__
780 else:
780 else:
781 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
781 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
782
782
783 if obj is not None:
783 if obj is not None:
784 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
784 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
785 % (self.name, class_of(obj),
785 % (self.name, class_of(obj),
786 self.info(), msg)
786 self.info(), msg)
787 else:
787 else:
788 e = "The '%s' trait must be %s, but a value of %r was specified." \
788 e = "The '%s' trait must be %s, but a value of %r was specified." \
789 % (self.name, self.info(), msg)
789 % (self.name, self.info(), msg)
790
790
791 raise TraitError(e)
791 raise TraitError(e)
792
792
793
793
794 class Type(ClassBasedTraitType):
794 class Type(ClassBasedTraitType):
795 """A trait whose value must be a subclass of a specified class."""
795 """A trait whose value must be a subclass of a specified class."""
796
796
797 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
797 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
798 """Construct a Type trait
798 """Construct a Type trait
799
799
800 A Type trait specifies that its values must be subclasses of
800 A Type trait specifies that its values must be subclasses of
801 a particular class.
801 a particular class.
802
802
803 If only ``default_value`` is given, it is used for the ``klass`` as
803 If only ``default_value`` is given, it is used for the ``klass`` as
804 well.
804 well.
805
805
806 Parameters
806 Parameters
807 ----------
807 ----------
808 default_value : class, str or None
808 default_value : class, str or None
809 The default value must be a subclass of klass. If an str,
809 The default value must be a subclass of klass. If an str,
810 the str must be a fully specified class name, like 'foo.bar.Bah'.
810 the str must be a fully specified class name, like 'foo.bar.Bah'.
811 The string is resolved into real class, when the parent
811 The string is resolved into real class, when the parent
812 :class:`HasTraits` class is instantiated.
812 :class:`HasTraits` class is instantiated.
813 klass : class, str, None
813 klass : class, str, None
814 Values of this trait must be a subclass of klass. The klass
814 Values of this trait must be a subclass of klass. The klass
815 may be specified in a string like: 'foo.bar.MyClass'.
815 may be specified in a string like: 'foo.bar.MyClass'.
816 The string is resolved into real class, when the parent
816 The string is resolved into real class, when the parent
817 :class:`HasTraits` class is instantiated.
817 :class:`HasTraits` class is instantiated.
818 allow_none : boolean
818 allow_none : bool [ default True ]
819 Indicates whether None is allowed as an assignable value. Even if
819 Indicates whether None is allowed as an assignable value. Even if
820 ``False``, the default value may be ``None``.
820 ``False``, the default value may be ``None``.
821 """
821 """
822 if default_value is None:
822 if default_value is None:
823 if klass is None:
823 if klass is None:
824 klass = object
824 klass = object
825 elif klass is None:
825 elif klass is None:
826 klass = default_value
826 klass = default_value
827
827
828 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
828 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
829 raise TraitError("A Type trait must specify a class.")
829 raise TraitError("A Type trait must specify a class.")
830
830
831 self.klass = klass
831 self.klass = klass
832
832
833 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
833 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
834
834
835 def validate(self, obj, value):
835 def validate(self, obj, value):
836 """Validates that the value is a valid object instance."""
836 """Validates that the value is a valid object instance."""
837 if isinstance(value, py3compat.string_types):
837 if isinstance(value, py3compat.string_types):
838 try:
838 try:
839 value = self._resolve_string(value)
839 value = self._resolve_string(value)
840 except ImportError:
840 except ImportError:
841 raise TraitError("The '%s' trait of %s instance must be a type, but "
841 raise TraitError("The '%s' trait of %s instance must be a type, but "
842 "%r could not be imported" % (self.name, obj, value))
842 "%r could not be imported" % (self.name, obj, value))
843 try:
843 try:
844 if issubclass(value, self.klass):
844 if issubclass(value, self.klass):
845 return value
845 return value
846 except:
846 except:
847 pass
847 pass
848
848
849 self.error(obj, value)
849 self.error(obj, value)
850
850
851 def info(self):
851 def info(self):
852 """ Returns a description of the trait."""
852 """ Returns a description of the trait."""
853 if isinstance(self.klass, py3compat.string_types):
853 if isinstance(self.klass, py3compat.string_types):
854 klass = self.klass
854 klass = self.klass
855 else:
855 else:
856 klass = self.klass.__name__
856 klass = self.klass.__name__
857 result = 'a subclass of ' + klass
857 result = 'a subclass of ' + klass
858 if self.allow_none:
858 if self.allow_none:
859 return result + ' or None'
859 return result + ' or None'
860 return result
860 return result
861
861
862 def instance_init(self, obj):
862 def instance_init(self, obj):
863 self._resolve_classes()
863 self._resolve_classes()
864 super(Type, self).instance_init(obj)
864 super(Type, self).instance_init(obj)
865
865
866 def _resolve_classes(self):
866 def _resolve_classes(self):
867 if isinstance(self.klass, py3compat.string_types):
867 if isinstance(self.klass, py3compat.string_types):
868 self.klass = self._resolve_string(self.klass)
868 self.klass = self._resolve_string(self.klass)
869 if isinstance(self.default_value, py3compat.string_types):
869 if isinstance(self.default_value, py3compat.string_types):
870 self.default_value = self._resolve_string(self.default_value)
870 self.default_value = self._resolve_string(self.default_value)
871
871
872 def get_default_value(self):
872 def get_default_value(self):
873 return self.default_value
873 return self.default_value
874
874
875
875
876 class DefaultValueGenerator(object):
876 class DefaultValueGenerator(object):
877 """A class for generating new default value instances."""
877 """A class for generating new default value instances."""
878
878
879 def __init__(self, *args, **kw):
879 def __init__(self, *args, **kw):
880 self.args = args
880 self.args = args
881 self.kw = kw
881 self.kw = kw
882
882
883 def generate(self, klass):
883 def generate(self, klass):
884 return klass(*self.args, **self.kw)
884 return klass(*self.args, **self.kw)
885
885
886
886
887 class Instance(ClassBasedTraitType):
887 class Instance(ClassBasedTraitType):
888 """A trait whose value must be an instance of a specified class.
888 """A trait whose value must be an instance of a specified class.
889
889
890 The value can also be an instance of a subclass of the specified class.
890 The value can also be an instance of a subclass of the specified class.
891
891
892 Subclasses can declare default classes by overriding the klass attribute
892 Subclasses can declare default classes by overriding the klass attribute
893 """
893 """
894
894
895 klass = None
895 klass = None
896
896
897 def __init__(self, klass=None, args=None, kw=None,
897 def __init__(self, klass=None, args=None, kw=None,
898 allow_none=True, **metadata ):
898 allow_none=True, **metadata ):
899 """Construct an Instance trait.
899 """Construct an Instance trait.
900
900
901 This trait allows values that are instances of a particular
901 This trait allows values that are instances of a particular
902 class or its subclasses. Our implementation is quite different
902 class or its subclasses. Our implementation is quite different
903 from that of enthough.traits as we don't allow instances to be used
903 from that of enthough.traits as we don't allow instances to be used
904 for klass and we handle the ``args`` and ``kw`` arguments differently.
904 for klass and we handle the ``args`` and ``kw`` arguments differently.
905
905
906 Parameters
906 Parameters
907 ----------
907 ----------
908 klass : class, str
908 klass : class, str
909 The class that forms the basis for the trait. Class names
909 The class that forms the basis for the trait. Class names
910 can also be specified as strings, like 'foo.bar.Bar'.
910 can also be specified as strings, like 'foo.bar.Bar'.
911 args : tuple
911 args : tuple
912 Positional arguments for generating the default value.
912 Positional arguments for generating the default value.
913 kw : dict
913 kw : dict
914 Keyword arguments for generating the default value.
914 Keyword arguments for generating the default value.
915 allow_none : bool
915 allow_none : bool [default True]
916 Indicates whether None is allowed as a value.
916 Indicates whether None is allowed as a value.
917
917
918 Notes
918 Notes
919 -----
919 -----
920 If both ``args`` and ``kw`` are None, then the default value is None.
920 If both ``args`` and ``kw`` are None, then the default value is None.
921 If ``args`` is a tuple and ``kw`` is a dict, then the default is
921 If ``args`` is a tuple and ``kw`` is a dict, then the default is
922 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
922 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
923 None, the None is replaced by ``()`` or ``{}``, respectively.
923 None, the None is replaced by ``()`` or ``{}``, respectively.
924 """
924 """
925 if klass is None:
925 if klass is None:
926 klass = self.klass
926 klass = self.klass
927
927
928 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
928 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
929 self.klass = klass
929 self.klass = klass
930 else:
930 else:
931 raise TraitError('The klass attribute must be a class'
931 raise TraitError('The klass attribute must be a class'
932 ' not: %r' % klass)
932 ' not: %r' % klass)
933
933
934 # self.klass is a class, so handle default_value
934 # self.klass is a class, so handle default_value
935 if args is None and kw is None:
935 if args is None and kw is None:
936 default_value = None
936 default_value = None
937 else:
937 else:
938 if args is None:
938 if args is None:
939 # kw is not None
939 # kw is not None
940 args = ()
940 args = ()
941 elif kw is None:
941 elif kw is None:
942 # args is not None
942 # args is not None
943 kw = {}
943 kw = {}
944
944
945 if not isinstance(kw, dict):
945 if not isinstance(kw, dict):
946 raise TraitError("The 'kw' argument must be a dict or None.")
946 raise TraitError("The 'kw' argument must be a dict or None.")
947 if not isinstance(args, tuple):
947 if not isinstance(args, tuple):
948 raise TraitError("The 'args' argument must be a tuple or None.")
948 raise TraitError("The 'args' argument must be a tuple or None.")
949
949
950 default_value = DefaultValueGenerator(*args, **kw)
950 default_value = DefaultValueGenerator(*args, **kw)
951
951
952 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
952 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
953
953
954 def validate(self, obj, value):
954 def validate(self, obj, value):
955 if isinstance(value, self.klass):
955 if isinstance(value, self.klass):
956 return value
956 return value
957 else:
957 else:
958 self.error(obj, value)
958 self.error(obj, value)
959
959
960 def info(self):
960 def info(self):
961 if isinstance(self.klass, py3compat.string_types):
961 if isinstance(self.klass, py3compat.string_types):
962 klass = self.klass
962 klass = self.klass
963 else:
963 else:
964 klass = self.klass.__name__
964 klass = self.klass.__name__
965 result = class_of(klass)
965 result = class_of(klass)
966 if self.allow_none:
966 if self.allow_none:
967 return result + ' or None'
967 return result + ' or None'
968
968
969 return result
969 return result
970
970
971 def instance_init(self, obj):
971 def instance_init(self, obj):
972 self._resolve_classes()
972 self._resolve_classes()
973 super(Instance, self).instance_init(obj)
973 super(Instance, self).instance_init(obj)
974
974
975 def _resolve_classes(self):
975 def _resolve_classes(self):
976 if isinstance(self.klass, py3compat.string_types):
976 if isinstance(self.klass, py3compat.string_types):
977 self.klass = self._resolve_string(self.klass)
977 self.klass = self._resolve_string(self.klass)
978
978
979 def get_default_value(self):
979 def get_default_value(self):
980 """Instantiate a default value instance.
980 """Instantiate a default value instance.
981
981
982 This is called when the containing HasTraits classes'
982 This is called when the containing HasTraits classes'
983 :meth:`__new__` method is called to ensure that a unique instance
983 :meth:`__new__` method is called to ensure that a unique instance
984 is created for each HasTraits instance.
984 is created for each HasTraits instance.
985 """
985 """
986 dv = self.default_value
986 dv = self.default_value
987 if isinstance(dv, DefaultValueGenerator):
987 if isinstance(dv, DefaultValueGenerator):
988 return dv.generate(self.klass)
988 return dv.generate(self.klass)
989 else:
989 else:
990 return dv
990 return dv
991
991
992
992
993 class ForwardDeclaredMixin(object):
993 class ForwardDeclaredMixin(object):
994 """
994 """
995 Mixin for forward-declared versions of Instance and Type.
995 Mixin for forward-declared versions of Instance and Type.
996 """
996 """
997 def _resolve_string(self, string):
997 def _resolve_string(self, string):
998 """
998 """
999 Find the specified class name by looking for it in the module in which
999 Find the specified class name by looking for it in the module in which
1000 our this_class attribute was defined.
1000 our this_class attribute was defined.
1001 """
1001 """
1002 modname = self.this_class.__module__
1002 modname = self.this_class.__module__
1003 return import_item('.'.join([modname, string]))
1003 return import_item('.'.join([modname, string]))
1004
1004
1005
1005
1006 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1006 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1007 """
1007 """
1008 Forward-declared version of Type.
1008 Forward-declared version of Type.
1009 """
1009 """
1010 pass
1010 pass
1011
1011
1012
1012
1013 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1013 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1014 """
1014 """
1015 Forward-declared version of Instance.
1015 Forward-declared version of Instance.
1016 """
1016 """
1017 pass
1017 pass
1018
1018
1019
1019
1020 class This(ClassBasedTraitType):
1020 class This(ClassBasedTraitType):
1021 """A trait for instances of the class containing this trait.
1021 """A trait for instances of the class containing this trait.
1022
1022
1023 Because how how and when class bodies are executed, the ``This``
1023 Because how how and when class bodies are executed, the ``This``
1024 trait can only have a default value of None. This, and because we
1024 trait can only have a default value of None. This, and because we
1025 always validate default values, ``allow_none`` is *always* true.
1025 always validate default values, ``allow_none`` is *always* true.
1026 """
1026 """
1027
1027
1028 info_text = 'an instance of the same type as the receiver or None'
1028 info_text = 'an instance of the same type as the receiver or None'
1029
1029
1030 def __init__(self, **metadata):
1030 def __init__(self, **metadata):
1031 super(This, self).__init__(None, **metadata)
1031 super(This, self).__init__(None, **metadata)
1032
1032
1033 def validate(self, obj, value):
1033 def validate(self, obj, value):
1034 # What if value is a superclass of obj.__class__? This is
1034 # What if value is a superclass of obj.__class__? This is
1035 # complicated if it was the superclass that defined the This
1035 # complicated if it was the superclass that defined the This
1036 # trait.
1036 # trait.
1037 if isinstance(value, self.this_class) or (value is None):
1037 if isinstance(value, self.this_class) or (value is None):
1038 return value
1038 return value
1039 else:
1039 else:
1040 self.error(obj, value)
1040 self.error(obj, value)
1041
1041
1042
1042
1043 class Union(TraitType):
1043 class Union(TraitType):
1044 """A trait type representing a Union type."""
1044 """A trait type representing a Union type."""
1045
1045
1046 def __init__(self, trait_types, **metadata):
1046 def __init__(self, trait_types, **metadata):
1047 """Construct a Union trait.
1047 """Construct a Union trait.
1048
1048
1049 This trait allows values that are allowed by at least one of the
1049 This trait allows values that are allowed by at least one of the
1050 specified trait types. A Union traitlet cannot have metadata on
1050 specified trait types. A Union traitlet cannot have metadata on
1051 its own, besides the metadata of the listed types.
1051 its own, besides the metadata of the listed types.
1052
1052
1053 Parameters
1053 Parameters
1054 ----------
1054 ----------
1055 trait_types: sequence
1055 trait_types: sequence
1056 The list of trait types of length at least 1.
1056 The list of trait types of length at least 1.
1057
1057
1058 Notes
1058 Notes
1059 -----
1059 -----
1060 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1060 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1061 with the validation function of Float, then Bool, and finally Int.
1061 with the validation function of Float, then Bool, and finally Int.
1062 """
1062 """
1063 self.trait_types = trait_types
1063 self.trait_types = trait_types
1064 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1064 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1065 self.default_value = self.trait_types[0].get_default_value()
1065 self.default_value = self.trait_types[0].get_default_value()
1066 super(Union, self).__init__(**metadata)
1066 super(Union, self).__init__(**metadata)
1067
1067
1068 def _resolve_classes(self):
1068 def _resolve_classes(self):
1069 for trait_type in self.trait_types:
1069 for trait_type in self.trait_types:
1070 trait_type.name = self.name
1070 trait_type.name = self.name
1071 trait_type.this_class = self.this_class
1071 trait_type.this_class = self.this_class
1072 if hasattr(trait_type, '_resolve_classes'):
1072 if hasattr(trait_type, '_resolve_classes'):
1073 trait_type._resolve_classes()
1073 trait_type._resolve_classes()
1074
1074
1075 def instance_init(self, obj):
1075 def instance_init(self, obj):
1076 self._resolve_classes()
1076 self._resolve_classes()
1077 super(Union, self).instance_init(obj)
1077 super(Union, self).instance_init(obj)
1078
1078
1079 def validate(self, obj, value):
1079 def validate(self, obj, value):
1080 for trait_type in self.trait_types:
1080 for trait_type in self.trait_types:
1081 try:
1081 try:
1082 v = trait_type._validate(obj, value)
1082 v = trait_type._validate(obj, value)
1083 self._metadata = trait_type._metadata
1083 self._metadata = trait_type._metadata
1084 return v
1084 return v
1085 except TraitError:
1085 except TraitError:
1086 continue
1086 continue
1087 self.error(obj, value)
1087 self.error(obj, value)
1088
1088
1089 def __or__(self, other):
1089 def __or__(self, other):
1090 if isinstance(other, Union):
1090 if isinstance(other, Union):
1091 return Union(self.trait_types + other.trait_types)
1091 return Union(self.trait_types + other.trait_types)
1092 else:
1092 else:
1093 return Union(self.trait_types + [other])
1093 return Union(self.trait_types + [other])
1094
1094
1095 #-----------------------------------------------------------------------------
1095 #-----------------------------------------------------------------------------
1096 # Basic TraitTypes implementations/subclasses
1096 # Basic TraitTypes implementations/subclasses
1097 #-----------------------------------------------------------------------------
1097 #-----------------------------------------------------------------------------
1098
1098
1099
1099
1100 class Any(TraitType):
1100 class Any(TraitType):
1101 default_value = None
1101 default_value = None
1102 info_text = 'any value'
1102 info_text = 'any value'
1103
1103
1104
1104
1105 class Int(TraitType):
1105 class Int(TraitType):
1106 """An int trait."""
1106 """An int trait."""
1107
1107
1108 default_value = 0
1108 default_value = 0
1109 info_text = 'an int'
1109 info_text = 'an int'
1110
1110
1111 def validate(self, obj, value):
1111 def validate(self, obj, value):
1112 if isinstance(value, int):
1112 if isinstance(value, int):
1113 return value
1113 return value
1114 self.error(obj, value)
1114 self.error(obj, value)
1115
1115
1116 class CInt(Int):
1116 class CInt(Int):
1117 """A casting version of the int trait."""
1117 """A casting version of the int trait."""
1118
1118
1119 def validate(self, obj, value):
1119 def validate(self, obj, value):
1120 try:
1120 try:
1121 return int(value)
1121 return int(value)
1122 except:
1122 except:
1123 self.error(obj, value)
1123 self.error(obj, value)
1124
1124
1125 if py3compat.PY3:
1125 if py3compat.PY3:
1126 Long, CLong = Int, CInt
1126 Long, CLong = Int, CInt
1127 Integer = Int
1127 Integer = Int
1128 else:
1128 else:
1129 class Long(TraitType):
1129 class Long(TraitType):
1130 """A long integer trait."""
1130 """A long integer trait."""
1131
1131
1132 default_value = 0
1132 default_value = 0
1133 info_text = 'a long'
1133 info_text = 'a long'
1134
1134
1135 def validate(self, obj, value):
1135 def validate(self, obj, value):
1136 if isinstance(value, long):
1136 if isinstance(value, long):
1137 return value
1137 return value
1138 if isinstance(value, int):
1138 if isinstance(value, int):
1139 return long(value)
1139 return long(value)
1140 self.error(obj, value)
1140 self.error(obj, value)
1141
1141
1142
1142
1143 class CLong(Long):
1143 class CLong(Long):
1144 """A casting version of the long integer trait."""
1144 """A casting version of the long integer trait."""
1145
1145
1146 def validate(self, obj, value):
1146 def validate(self, obj, value):
1147 try:
1147 try:
1148 return long(value)
1148 return long(value)
1149 except:
1149 except:
1150 self.error(obj, value)
1150 self.error(obj, value)
1151
1151
1152 class Integer(TraitType):
1152 class Integer(TraitType):
1153 """An integer trait.
1153 """An integer trait.
1154
1154
1155 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1155 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1156
1156
1157 default_value = 0
1157 default_value = 0
1158 info_text = 'an integer'
1158 info_text = 'an integer'
1159
1159
1160 def validate(self, obj, value):
1160 def validate(self, obj, value):
1161 if isinstance(value, int):
1161 if isinstance(value, int):
1162 return value
1162 return value
1163 if isinstance(value, long):
1163 if isinstance(value, long):
1164 # downcast longs that fit in int:
1164 # downcast longs that fit in int:
1165 # note that int(n > sys.maxint) returns a long, so
1165 # note that int(n > sys.maxint) returns a long, so
1166 # we don't need a condition on this cast
1166 # we don't need a condition on this cast
1167 return int(value)
1167 return int(value)
1168 if sys.platform == "cli":
1168 if sys.platform == "cli":
1169 from System import Int64
1169 from System import Int64
1170 if isinstance(value, Int64):
1170 if isinstance(value, Int64):
1171 return int(value)
1171 return int(value)
1172 self.error(obj, value)
1172 self.error(obj, value)
1173
1173
1174
1174
1175 class Float(TraitType):
1175 class Float(TraitType):
1176 """A float trait."""
1176 """A float trait."""
1177
1177
1178 default_value = 0.0
1178 default_value = 0.0
1179 info_text = 'a float'
1179 info_text = 'a float'
1180
1180
1181 def validate(self, obj, value):
1181 def validate(self, obj, value):
1182 if isinstance(value, float):
1182 if isinstance(value, float):
1183 return value
1183 return value
1184 if isinstance(value, int):
1184 if isinstance(value, int):
1185 return float(value)
1185 return float(value)
1186 self.error(obj, value)
1186 self.error(obj, value)
1187
1187
1188
1188
1189 class CFloat(Float):
1189 class CFloat(Float):
1190 """A casting version of the float trait."""
1190 """A casting version of the float trait."""
1191
1191
1192 def validate(self, obj, value):
1192 def validate(self, obj, value):
1193 try:
1193 try:
1194 return float(value)
1194 return float(value)
1195 except:
1195 except:
1196 self.error(obj, value)
1196 self.error(obj, value)
1197
1197
1198 class Complex(TraitType):
1198 class Complex(TraitType):
1199 """A trait for complex numbers."""
1199 """A trait for complex numbers."""
1200
1200
1201 default_value = 0.0 + 0.0j
1201 default_value = 0.0 + 0.0j
1202 info_text = 'a complex number'
1202 info_text = 'a complex number'
1203
1203
1204 def validate(self, obj, value):
1204 def validate(self, obj, value):
1205 if isinstance(value, complex):
1205 if isinstance(value, complex):
1206 return value
1206 return value
1207 if isinstance(value, (float, int)):
1207 if isinstance(value, (float, int)):
1208 return complex(value)
1208 return complex(value)
1209 self.error(obj, value)
1209 self.error(obj, value)
1210
1210
1211
1211
1212 class CComplex(Complex):
1212 class CComplex(Complex):
1213 """A casting version of the complex number trait."""
1213 """A casting version of the complex number trait."""
1214
1214
1215 def validate (self, obj, value):
1215 def validate (self, obj, value):
1216 try:
1216 try:
1217 return complex(value)
1217 return complex(value)
1218 except:
1218 except:
1219 self.error(obj, value)
1219 self.error(obj, value)
1220
1220
1221 # We should always be explicit about whether we're using bytes or unicode, both
1221 # We should always be explicit about whether we're using bytes or unicode, both
1222 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1222 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1223 # we don't have a Str type.
1223 # we don't have a Str type.
1224 class Bytes(TraitType):
1224 class Bytes(TraitType):
1225 """A trait for byte strings."""
1225 """A trait for byte strings."""
1226
1226
1227 default_value = b''
1227 default_value = b''
1228 info_text = 'a bytes object'
1228 info_text = 'a bytes object'
1229
1229
1230 def validate(self, obj, value):
1230 def validate(self, obj, value):
1231 if isinstance(value, bytes):
1231 if isinstance(value, bytes):
1232 return value
1232 return value
1233 self.error(obj, value)
1233 self.error(obj, value)
1234
1234
1235
1235
1236 class CBytes(Bytes):
1236 class CBytes(Bytes):
1237 """A casting version of the byte string trait."""
1237 """A casting version of the byte string trait."""
1238
1238
1239 def validate(self, obj, value):
1239 def validate(self, obj, value):
1240 try:
1240 try:
1241 return bytes(value)
1241 return bytes(value)
1242 except:
1242 except:
1243 self.error(obj, value)
1243 self.error(obj, value)
1244
1244
1245
1245
1246 class Unicode(TraitType):
1246 class Unicode(TraitType):
1247 """A trait for unicode strings."""
1247 """A trait for unicode strings."""
1248
1248
1249 default_value = u''
1249 default_value = u''
1250 info_text = 'a unicode string'
1250 info_text = 'a unicode string'
1251
1251
1252 def validate(self, obj, value):
1252 def validate(self, obj, value):
1253 if isinstance(value, py3compat.unicode_type):
1253 if isinstance(value, py3compat.unicode_type):
1254 return value
1254 return value
1255 if isinstance(value, bytes):
1255 if isinstance(value, bytes):
1256 try:
1256 try:
1257 return value.decode('ascii', 'strict')
1257 return value.decode('ascii', 'strict')
1258 except UnicodeDecodeError:
1258 except UnicodeDecodeError:
1259 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1259 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1260 raise TraitError(msg.format(value, self.name, class_of(obj)))
1260 raise TraitError(msg.format(value, self.name, class_of(obj)))
1261 self.error(obj, value)
1261 self.error(obj, value)
1262
1262
1263
1263
1264 class CUnicode(Unicode):
1264 class CUnicode(Unicode):
1265 """A casting version of the unicode trait."""
1265 """A casting version of the unicode trait."""
1266
1266
1267 def validate(self, obj, value):
1267 def validate(self, obj, value):
1268 try:
1268 try:
1269 return py3compat.unicode_type(value)
1269 return py3compat.unicode_type(value)
1270 except:
1270 except:
1271 self.error(obj, value)
1271 self.error(obj, value)
1272
1272
1273
1273
1274 class ObjectName(TraitType):
1274 class ObjectName(TraitType):
1275 """A string holding a valid object name in this version of Python.
1275 """A string holding a valid object name in this version of Python.
1276
1276
1277 This does not check that the name exists in any scope."""
1277 This does not check that the name exists in any scope."""
1278 info_text = "a valid object identifier in Python"
1278 info_text = "a valid object identifier in Python"
1279
1279
1280 if py3compat.PY3:
1280 if py3compat.PY3:
1281 # Python 3:
1281 # Python 3:
1282 coerce_str = staticmethod(lambda _,s: s)
1282 coerce_str = staticmethod(lambda _,s: s)
1283
1283
1284 else:
1284 else:
1285 # Python 2:
1285 # Python 2:
1286 def coerce_str(self, obj, value):
1286 def coerce_str(self, obj, value):
1287 "In Python 2, coerce ascii-only unicode to str"
1287 "In Python 2, coerce ascii-only unicode to str"
1288 if isinstance(value, unicode):
1288 if isinstance(value, unicode):
1289 try:
1289 try:
1290 return str(value)
1290 return str(value)
1291 except UnicodeEncodeError:
1291 except UnicodeEncodeError:
1292 self.error(obj, value)
1292 self.error(obj, value)
1293 return value
1293 return value
1294
1294
1295 def validate(self, obj, value):
1295 def validate(self, obj, value):
1296 value = self.coerce_str(obj, value)
1296 value = self.coerce_str(obj, value)
1297
1297
1298 if isinstance(value, string_types) and py3compat.isidentifier(value):
1298 if isinstance(value, string_types) and py3compat.isidentifier(value):
1299 return value
1299 return value
1300 self.error(obj, value)
1300 self.error(obj, value)
1301
1301
1302 class DottedObjectName(ObjectName):
1302 class DottedObjectName(ObjectName):
1303 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1303 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1304 def validate(self, obj, value):
1304 def validate(self, obj, value):
1305 value = self.coerce_str(obj, value)
1305 value = self.coerce_str(obj, value)
1306
1306
1307 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1307 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1308 return value
1308 return value
1309 self.error(obj, value)
1309 self.error(obj, value)
1310
1310
1311
1311
1312 class Bool(TraitType):
1312 class Bool(TraitType):
1313 """A boolean (True, False) trait."""
1313 """A boolean (True, False) trait."""
1314
1314
1315 default_value = False
1315 default_value = False
1316 info_text = 'a boolean'
1316 info_text = 'a boolean'
1317
1317
1318 def validate(self, obj, value):
1318 def validate(self, obj, value):
1319 if isinstance(value, bool):
1319 if isinstance(value, bool):
1320 return value
1320 return value
1321 self.error(obj, value)
1321 self.error(obj, value)
1322
1322
1323
1323
1324 class CBool(Bool):
1324 class CBool(Bool):
1325 """A casting version of the boolean trait."""
1325 """A casting version of the boolean trait."""
1326
1326
1327 def validate(self, obj, value):
1327 def validate(self, obj, value):
1328 try:
1328 try:
1329 return bool(value)
1329 return bool(value)
1330 except:
1330 except:
1331 self.error(obj, value)
1331 self.error(obj, value)
1332
1332
1333
1333
1334 class Enum(TraitType):
1334 class Enum(TraitType):
1335 """An enum that whose value must be in a given sequence."""
1335 """An enum that whose value must be in a given sequence."""
1336
1336
1337 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1337 def __init__(self, values, default_value=None, **metadata):
1338 self.values = values
1338 self.values = values
1339 super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata)
1339 super(Enum, self).__init__(default_value, **metadata)
1340
1340
1341 def validate(self, obj, value):
1341 def validate(self, obj, value):
1342 if value in self.values:
1342 if value in self.values:
1343 return value
1343 return value
1344 self.error(obj, value)
1344 self.error(obj, value)
1345
1345
1346 def info(self):
1346 def info(self):
1347 """ Returns a description of the trait."""
1347 """ Returns a description of the trait."""
1348 result = 'any of ' + repr(self.values)
1348 result = 'any of ' + repr(self.values)
1349 if self.allow_none:
1349 if self.allow_none:
1350 return result + ' or None'
1350 return result + ' or None'
1351 return result
1351 return result
1352
1352
1353 class CaselessStrEnum(Enum):
1353 class CaselessStrEnum(Enum):
1354 """An enum of strings that are caseless in validate."""
1354 """An enum of strings that are caseless in validate."""
1355
1355
1356 def validate(self, obj, value):
1356 def validate(self, obj, value):
1357 if not isinstance(value, py3compat.string_types):
1357 if not isinstance(value, py3compat.string_types):
1358 self.error(obj, value)
1358 self.error(obj, value)
1359
1359
1360 for v in self.values:
1360 for v in self.values:
1361 if v.lower() == value.lower():
1361 if v.lower() == value.lower():
1362 return v
1362 return v
1363 self.error(obj, value)
1363 self.error(obj, value)
1364
1364
1365 class Container(Instance):
1365 class Container(Instance):
1366 """An instance of a container (list, set, etc.)
1366 """An instance of a container (list, set, etc.)
1367
1367
1368 To be subclassed by overriding klass.
1368 To be subclassed by overriding klass.
1369 """
1369 """
1370 klass = None
1370 klass = None
1371 _cast_types = ()
1371 _cast_types = ()
1372 _valid_defaults = SequenceTypes
1372 _valid_defaults = SequenceTypes
1373 _trait = None
1373 _trait = None
1374
1374
1375 def __init__(self, trait=None, default_value=None, allow_none=True,
1375 def __init__(self, trait=None, default_value=None, allow_none=False,
1376 **metadata):
1376 **metadata):
1377 """Create a container trait type from a list, set, or tuple.
1377 """Create a container trait type from a list, set, or tuple.
1378
1378
1379 The default value is created by doing ``List(default_value)``,
1379 The default value is created by doing ``List(default_value)``,
1380 which creates a copy of the ``default_value``.
1380 which creates a copy of the ``default_value``.
1381
1381
1382 ``trait`` can be specified, which restricts the type of elements
1382 ``trait`` can be specified, which restricts the type of elements
1383 in the container to that TraitType.
1383 in the container to that TraitType.
1384
1384
1385 If only one arg is given and it is not a Trait, it is taken as
1385 If only one arg is given and it is not a Trait, it is taken as
1386 ``default_value``:
1386 ``default_value``:
1387
1387
1388 ``c = List([1,2,3])``
1388 ``c = List([1,2,3])``
1389
1389
1390 Parameters
1390 Parameters
1391 ----------
1391 ----------
1392
1392
1393 trait : TraitType [ optional ]
1393 trait : TraitType [ optional ]
1394 the type for restricting the contents of the Container. If unspecified,
1394 the type for restricting the contents of the Container. If unspecified,
1395 types are not checked.
1395 types are not checked.
1396
1396
1397 default_value : SequenceType [ optional ]
1397 default_value : SequenceType [ optional ]
1398 The default value for the Trait. Must be list/tuple/set, and
1398 The default value for the Trait. Must be list/tuple/set, and
1399 will be cast to the container type.
1399 will be cast to the container type.
1400
1400
1401 allow_none : Bool [ default True ]
1401 allow_none : bool [ default False ]
1402 Whether to allow the value to be None
1402 Whether to allow the value to be None
1403
1403
1404 **metadata : any
1404 **metadata : any
1405 further keys for extensions to the Trait (e.g. config)
1405 further keys for extensions to the Trait (e.g. config)
1406
1406
1407 """
1407 """
1408 # allow List([values]):
1408 # allow List([values]):
1409 if default_value is None and not is_trait(trait):
1409 if default_value is None and not is_trait(trait):
1410 default_value = trait
1410 default_value = trait
1411 trait = None
1411 trait = None
1412
1412
1413 if default_value is None:
1413 if default_value is None:
1414 args = ()
1414 args = ()
1415 elif isinstance(default_value, self._valid_defaults):
1415 elif isinstance(default_value, self._valid_defaults):
1416 args = (default_value,)
1416 args = (default_value,)
1417 else:
1417 else:
1418 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1418 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1419
1419
1420 if is_trait(trait):
1420 if is_trait(trait):
1421 self._trait = trait() if isinstance(trait, type) else trait
1421 self._trait = trait() if isinstance(trait, type) else trait
1422 self._trait.name = 'element'
1422 self._trait.name = 'element'
1423 elif trait is not None:
1423 elif trait is not None:
1424 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1424 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1425
1425
1426 super(Container,self).__init__(klass=self.klass, args=args,
1426 super(Container,self).__init__(klass=self.klass, args=args,
1427 allow_none=allow_none, **metadata)
1427 allow_none=allow_none, **metadata)
1428
1428
1429 def element_error(self, obj, element, validator):
1429 def element_error(self, obj, element, validator):
1430 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1430 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1431 % (self.name, class_of(obj), validator.info(), repr_type(element))
1431 % (self.name, class_of(obj), validator.info(), repr_type(element))
1432 raise TraitError(e)
1432 raise TraitError(e)
1433
1433
1434 def validate(self, obj, value):
1434 def validate(self, obj, value):
1435 if isinstance(value, self._cast_types):
1435 if isinstance(value, self._cast_types):
1436 value = self.klass(value)
1436 value = self.klass(value)
1437 value = super(Container, self).validate(obj, value)
1437 value = super(Container, self).validate(obj, value)
1438 if value is None:
1438 if value is None:
1439 return value
1439 return value
1440
1440
1441 value = self.validate_elements(obj, value)
1441 value = self.validate_elements(obj, value)
1442
1442
1443 return value
1443 return value
1444
1444
1445 def validate_elements(self, obj, value):
1445 def validate_elements(self, obj, value):
1446 validated = []
1446 validated = []
1447 if self._trait is None or isinstance(self._trait, Any):
1447 if self._trait is None or isinstance(self._trait, Any):
1448 return value
1448 return value
1449 for v in value:
1449 for v in value:
1450 try:
1450 try:
1451 v = self._trait._validate(obj, v)
1451 v = self._trait._validate(obj, v)
1452 except TraitError:
1452 except TraitError:
1453 self.element_error(obj, v, self._trait)
1453 self.element_error(obj, v, self._trait)
1454 else:
1454 else:
1455 validated.append(v)
1455 validated.append(v)
1456 return self.klass(validated)
1456 return self.klass(validated)
1457
1457
1458 def instance_init(self, obj):
1458 def instance_init(self, obj):
1459 if isinstance(self._trait, TraitType):
1459 if isinstance(self._trait, TraitType):
1460 self._trait.this_class = self.this_class
1460 self._trait.this_class = self.this_class
1461 if hasattr(self._trait, '_resolve_classes'):
1461 if hasattr(self._trait, '_resolve_classes'):
1462 self._trait._resolve_classes()
1462 self._trait._resolve_classes()
1463 super(Container, self).instance_init(obj)
1463 super(Container, self).instance_init(obj)
1464
1464
1465
1465
1466 class List(Container):
1466 class List(Container):
1467 """An instance of a Python list."""
1467 """An instance of a Python list."""
1468 klass = list
1468 klass = list
1469 _cast_types = (tuple,)
1469 _cast_types = (tuple,)
1470
1470
1471 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize,
1471 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1472 allow_none=True, **metadata):
1473 """Create a List trait type from a list, set, or tuple.
1472 """Create a List trait type from a list, set, or tuple.
1474
1473
1475 The default value is created by doing ``List(default_value)``,
1474 The default value is created by doing ``List(default_value)``,
1476 which creates a copy of the ``default_value``.
1475 which creates a copy of the ``default_value``.
1477
1476
1478 ``trait`` can be specified, which restricts the type of elements
1477 ``trait`` can be specified, which restricts the type of elements
1479 in the container to that TraitType.
1478 in the container to that TraitType.
1480
1479
1481 If only one arg is given and it is not a Trait, it is taken as
1480 If only one arg is given and it is not a Trait, it is taken as
1482 ``default_value``:
1481 ``default_value``:
1483
1482
1484 ``c = List([1,2,3])``
1483 ``c = List([1,2,3])``
1485
1484
1486 Parameters
1485 Parameters
1487 ----------
1486 ----------
1488
1487
1489 trait : TraitType [ optional ]
1488 trait : TraitType [ optional ]
1490 the type for restricting the contents of the Container. If unspecified,
1489 the type for restricting the contents of the Container. If unspecified,
1491 types are not checked.
1490 types are not checked.
1492
1491
1493 default_value : SequenceType [ optional ]
1492 default_value : SequenceType [ optional ]
1494 The default value for the Trait. Must be list/tuple/set, and
1493 The default value for the Trait. Must be list/tuple/set, and
1495 will be cast to the container type.
1494 will be cast to the container type.
1496
1495
1497 minlen : Int [ default 0 ]
1496 minlen : Int [ default 0 ]
1498 The minimum length of the input list
1497 The minimum length of the input list
1499
1498
1500 maxlen : Int [ default sys.maxsize ]
1499 maxlen : Int [ default sys.maxsize ]
1501 The maximum length of the input list
1500 The maximum length of the input list
1502
1501
1503 allow_none : Bool [ default True ]
1502 allow_none : bool [ default False ]
1504 Whether to allow the value to be None
1503 Whether to allow the value to be None
1505
1504
1506 **metadata : any
1505 **metadata : any
1507 further keys for extensions to the Trait (e.g. config)
1506 further keys for extensions to the Trait (e.g. config)
1508
1507
1509 """
1508 """
1510 self._minlen = minlen
1509 self._minlen = minlen
1511 self._maxlen = maxlen
1510 self._maxlen = maxlen
1512 super(List, self).__init__(trait=trait, default_value=default_value,
1511 super(List, self).__init__(trait=trait, default_value=default_value,
1513 allow_none=allow_none, **metadata)
1512 **metadata)
1514
1513
1515 def length_error(self, obj, value):
1514 def length_error(self, obj, value):
1516 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1515 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1517 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1516 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1518 raise TraitError(e)
1517 raise TraitError(e)
1519
1518
1520 def validate_elements(self, obj, value):
1519 def validate_elements(self, obj, value):
1521 length = len(value)
1520 length = len(value)
1522 if length < self._minlen or length > self._maxlen:
1521 if length < self._minlen or length > self._maxlen:
1523 self.length_error(obj, value)
1522 self.length_error(obj, value)
1524
1523
1525 return super(List, self).validate_elements(obj, value)
1524 return super(List, self).validate_elements(obj, value)
1526
1525
1527 def validate(self, obj, value):
1526 def validate(self, obj, value):
1528 value = super(List, self).validate(obj, value)
1527 value = super(List, self).validate(obj, value)
1529
1528
1530 value = self.validate_elements(obj, value)
1529 value = self.validate_elements(obj, value)
1531
1530
1532 return value
1531 return value
1533
1532
1534
1533
1535
1534
1536 class Set(List):
1535 class Set(List):
1537 """An instance of a Python set."""
1536 """An instance of a Python set."""
1538 klass = set
1537 klass = set
1539 _cast_types = (tuple, list)
1538 _cast_types = (tuple, list)
1540
1539
1541 class Tuple(Container):
1540 class Tuple(Container):
1542 """An instance of a Python tuple."""
1541 """An instance of a Python tuple."""
1543 klass = tuple
1542 klass = tuple
1544 _cast_types = (list,)
1543 _cast_types = (list,)
1545
1544
1546 def __init__(self, *traits, **metadata):
1545 def __init__(self, *traits, **metadata):
1547 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1546 """Tuple(*traits, default_value=None, **medatata)
1548
1547
1549 Create a tuple from a list, set, or tuple.
1548 Create a tuple from a list, set, or tuple.
1550
1549
1551 Create a fixed-type tuple with Traits:
1550 Create a fixed-type tuple with Traits:
1552
1551
1553 ``t = Tuple(Int, Str, CStr)``
1552 ``t = Tuple(Int, Str, CStr)``
1554
1553
1555 would be length 3, with Int,Str,CStr for each element.
1554 would be length 3, with Int,Str,CStr for each element.
1556
1555
1557 If only one arg is given and it is not a Trait, it is taken as
1556 If only one arg is given and it is not a Trait, it is taken as
1558 default_value:
1557 default_value:
1559
1558
1560 ``t = Tuple((1,2,3))``
1559 ``t = Tuple((1,2,3))``
1561
1560
1562 Otherwise, ``default_value`` *must* be specified by keyword.
1561 Otherwise, ``default_value`` *must* be specified by keyword.
1563
1562
1564 Parameters
1563 Parameters
1565 ----------
1564 ----------
1566
1565
1567 *traits : TraitTypes [ optional ]
1566 *traits : TraitTypes [ optional ]
1568 the tsype for restricting the contents of the Tuple. If unspecified,
1567 the tsype for restricting the contents of the Tuple. If unspecified,
1569 types are not checked. If specified, then each positional argument
1568 types are not checked. If specified, then each positional argument
1570 corresponds to an element of the tuple. Tuples defined with traits
1569 corresponds to an element of the tuple. Tuples defined with traits
1571 are of fixed length.
1570 are of fixed length.
1572
1571
1573 default_value : SequenceType [ optional ]
1572 default_value : SequenceType [ optional ]
1574 The default value for the Tuple. Must be list/tuple/set, and
1573 The default value for the Tuple. Must be list/tuple/set, and
1575 will be cast to a tuple. If `traits` are specified, the
1574 will be cast to a tuple. If `traits` are specified, the
1576 `default_value` must conform to the shape and type they specify.
1575 `default_value` must conform to the shape and type they specify.
1577
1576
1578 allow_none : Bool [ default True ]
1577 allow_none : bool [ default False ]
1579 Whether to allow the value to be None
1578 Whether to allow the value to be None
1580
1579
1581 **metadata : any
1580 **metadata : any
1582 further keys for extensions to the Trait (e.g. config)
1581 further keys for extensions to the Trait (e.g. config)
1583
1582
1584 """
1583 """
1585 default_value = metadata.pop('default_value', None)
1584 default_value = metadata.pop('default_value', None)
1586 allow_none = metadata.pop('allow_none', True)
1585 allow_none = metadata.pop('allow_none', True)
1587
1586
1588 # allow Tuple((values,)):
1587 # allow Tuple((values,)):
1589 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1588 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1590 default_value = traits[0]
1589 default_value = traits[0]
1591 traits = ()
1590 traits = ()
1592
1591
1593 if default_value is None:
1592 if default_value is None:
1594 args = ()
1593 args = ()
1595 elif isinstance(default_value, self._valid_defaults):
1594 elif isinstance(default_value, self._valid_defaults):
1596 args = (default_value,)
1595 args = (default_value,)
1597 else:
1596 else:
1598 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1597 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1599
1598
1600 self._traits = []
1599 self._traits = []
1601 for trait in traits:
1600 for trait in traits:
1602 t = trait() if isinstance(trait, type) else trait
1601 t = trait() if isinstance(trait, type) else trait
1603 t.name = 'element'
1602 t.name = 'element'
1604 self._traits.append(t)
1603 self._traits.append(t)
1605
1604
1606 if self._traits and default_value is None:
1605 if self._traits and default_value is None:
1607 # don't allow default to be an empty container if length is specified
1606 # don't allow default to be an empty container if length is specified
1608 args = None
1607 args = None
1609 super(Container,self).__init__(klass=self.klass, args=args,
1608 super(Container,self).__init__(klass=self.klass, args=args, **metadata)
1610 allow_none=allow_none, **metadata)
1611
1609
1612 def validate_elements(self, obj, value):
1610 def validate_elements(self, obj, value):
1613 if not self._traits:
1611 if not self._traits:
1614 # nothing to validate
1612 # nothing to validate
1615 return value
1613 return value
1616 if len(value) != len(self._traits):
1614 if len(value) != len(self._traits):
1617 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1615 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1618 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1616 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1619 raise TraitError(e)
1617 raise TraitError(e)
1620
1618
1621 validated = []
1619 validated = []
1622 for t,v in zip(self._traits, value):
1620 for t,v in zip(self._traits, value):
1623 try:
1621 try:
1624 v = t._validate(obj, v)
1622 v = t._validate(obj, v)
1625 except TraitError:
1623 except TraitError:
1626 self.element_error(obj, v, t)
1624 self.element_error(obj, v, t)
1627 else:
1625 else:
1628 validated.append(v)
1626 validated.append(v)
1629 return tuple(validated)
1627 return tuple(validated)
1630
1628
1631 def instance_init(self, obj):
1629 def instance_init(self, obj):
1632 for trait in self._traits:
1630 for trait in self._traits:
1633 if isinstance(trait, TraitType):
1631 if isinstance(trait, TraitType):
1634 trait.this_class = self.this_class
1632 trait.this_class = self.this_class
1635 if hasattr(trait, '_resolve_classes'):
1633 if hasattr(trait, '_resolve_classes'):
1636 trait._resolve_classes()
1634 trait._resolve_classes()
1637 super(Container, self).instance_init(obj)
1635 super(Container, self).instance_init(obj)
1638
1636
1639
1637
1640 class Dict(Instance):
1638 class Dict(Instance):
1641 """An instance of a Python dict."""
1639 """An instance of a Python dict."""
1642
1640
1643 def __init__(self, default_value={}, allow_none=True, **metadata):
1641 def __init__(self, default_value={}, allow_none=False, **metadata):
1644 """Create a dict trait type from a dict.
1642 """Create a dict trait type from a dict.
1645
1643
1646 The default value is created by doing ``dict(default_value)``,
1644 The default value is created by doing ``dict(default_value)``,
1647 which creates a copy of the ``default_value``.
1645 which creates a copy of the ``default_value``.
1648 """
1646 """
1649 if default_value is None:
1647 if default_value is None:
1650 args = None
1648 args = None
1651 elif isinstance(default_value, dict):
1649 elif isinstance(default_value, dict):
1652 args = (default_value,)
1650 args = (default_value,)
1653 elif isinstance(default_value, SequenceTypes):
1651 elif isinstance(default_value, SequenceTypes):
1654 args = (default_value,)
1652 args = (default_value,)
1655 else:
1653 else:
1656 raise TypeError('default value of Dict was %s' % default_value)
1654 raise TypeError('default value of Dict was %s' % default_value)
1657
1655
1658 super(Dict,self).__init__(klass=dict, args=args,
1656 super(Dict,self).__init__(klass=dict, args=args,
1659 allow_none=allow_none, **metadata)
1657 allow_none=allow_none, **metadata)
1660
1658
1661
1659
1662 class EventfulDict(Instance):
1660 class EventfulDict(Instance):
1663 """An instance of an EventfulDict."""
1661 """An instance of an EventfulDict."""
1664
1662
1665 def __init__(self, default_value={}, allow_none=True, **metadata):
1663 def __init__(self, default_value={}, allow_none=False, **metadata):
1666 """Create a EventfulDict trait type from a dict.
1664 """Create a EventfulDict trait type from a dict.
1667
1665
1668 The default value is created by doing
1666 The default value is created by doing
1669 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1667 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1670 ``default_value``.
1668 ``default_value``.
1671 """
1669 """
1672 if default_value is None:
1670 if default_value is None:
1673 args = None
1671 args = None
1674 elif isinstance(default_value, dict):
1672 elif isinstance(default_value, dict):
1675 args = (default_value,)
1673 args = (default_value,)
1676 elif isinstance(default_value, SequenceTypes):
1674 elif isinstance(default_value, SequenceTypes):
1677 args = (default_value,)
1675 args = (default_value,)
1678 else:
1676 else:
1679 raise TypeError('default value of EventfulDict was %s' % default_value)
1677 raise TypeError('default value of EventfulDict was %s' % default_value)
1680
1678
1681 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1679 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1682 allow_none=allow_none, **metadata)
1680 allow_none=allow_none, **metadata)
1683
1681
1684
1682
1685 class EventfulList(Instance):
1683 class EventfulList(Instance):
1686 """An instance of an EventfulList."""
1684 """An instance of an EventfulList."""
1687
1685
1688 def __init__(self, default_value=None, allow_none=True, **metadata):
1686 def __init__(self, default_value=None, allow_none=False, **metadata):
1689 """Create a EventfulList trait type from a dict.
1687 """Create a EventfulList trait type from a dict.
1690
1688
1691 The default value is created by doing
1689 The default value is created by doing
1692 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1690 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1693 ``default_value``.
1691 ``default_value``.
1694 """
1692 """
1695 if default_value is None:
1693 if default_value is None:
1696 args = ((),)
1694 args = ((),)
1697 else:
1695 else:
1698 args = (default_value,)
1696 args = (default_value,)
1699
1697
1700 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1698 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1701 allow_none=allow_none, **metadata)
1699 allow_none=allow_none, **metadata)
1702
1700
1703
1701
1704 class TCPAddress(TraitType):
1702 class TCPAddress(TraitType):
1705 """A trait for an (ip, port) tuple.
1703 """A trait for an (ip, port) tuple.
1706
1704
1707 This allows for both IPv4 IP addresses as well as hostnames.
1705 This allows for both IPv4 IP addresses as well as hostnames.
1708 """
1706 """
1709
1707
1710 default_value = ('127.0.0.1', 0)
1708 default_value = ('127.0.0.1', 0)
1711 info_text = 'an (ip, port) tuple'
1709 info_text = 'an (ip, port) tuple'
1712
1710
1713 def validate(self, obj, value):
1711 def validate(self, obj, value):
1714 if isinstance(value, tuple):
1712 if isinstance(value, tuple):
1715 if len(value) == 2:
1713 if len(value) == 2:
1716 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1714 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1717 port = value[1]
1715 port = value[1]
1718 if port >= 0 and port <= 65535:
1716 if port >= 0 and port <= 65535:
1719 return value
1717 return value
1720 self.error(obj, value)
1718 self.error(obj, value)
1721
1719
1722 class CRegExp(TraitType):
1720 class CRegExp(TraitType):
1723 """A casting compiled regular expression trait.
1721 """A casting compiled regular expression trait.
1724
1722
1725 Accepts both strings and compiled regular expressions. The resulting
1723 Accepts both strings and compiled regular expressions. The resulting
1726 attribute will be a compiled regular expression."""
1724 attribute will be a compiled regular expression."""
1727
1725
1728 info_text = 'a regular expression'
1726 info_text = 'a regular expression'
1729
1727
1730 def validate(self, obj, value):
1728 def validate(self, obj, value):
1731 try:
1729 try:
1732 return re.compile(value)
1730 return re.compile(value)
1733 except:
1731 except:
1734 self.error(obj, value)
1732 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now