##// END OF EJS Templates
Merge pull request #7708 from SylvainCorlay/allow_none...
Min RK -
r20763:454aa2cb merge
parent child Browse files
Show More
@@ -1,432 +1,432 b''
1 1 # encoding: utf-8
2 2 """
3 3 A mixin for :class:`~IPython.core.application.Application` classes that
4 4 launch InteractiveShell instances, load extensions, etc.
5 5 """
6 6
7 7 # Copyright (c) IPython Development Team.
8 8 # Distributed under the terms of the Modified BSD License.
9 9
10 10 from __future__ import absolute_import
11 11 from __future__ import print_function
12 12
13 13 import glob
14 14 import os
15 15 import sys
16 16
17 17 from IPython.config.application import boolean_flag
18 18 from IPython.config.configurable import Configurable
19 19 from IPython.config.loader import Config
20 20 from IPython.core import pylabtools
21 21 from IPython.utils import py3compat
22 22 from IPython.utils.contexts import preserve_keys
23 23 from IPython.utils.path import filefind
24 24 from IPython.utils.traitlets import (
25 25 Unicode, Instance, List, Bool, CaselessStrEnum
26 26 )
27 27 from IPython.lib.inputhook import guis
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Aliases and Flags
31 31 #-----------------------------------------------------------------------------
32 32
33 33 gui_keys = tuple(sorted([ key for key in guis if key is not None ]))
34 34
35 35 backend_keys = sorted(pylabtools.backends.keys())
36 36 backend_keys.insert(0, 'auto')
37 37
38 38 shell_flags = {}
39 39
40 40 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
41 41 addflag('autoindent', 'InteractiveShell.autoindent',
42 42 'Turn on autoindenting.', 'Turn off autoindenting.'
43 43 )
44 44 addflag('automagic', 'InteractiveShell.automagic',
45 45 """Turn on the auto calling of magic commands. Type %%magic at the
46 46 IPython prompt for more information.""",
47 47 'Turn off the auto calling of magic commands.'
48 48 )
49 49 addflag('pdb', 'InteractiveShell.pdb',
50 50 "Enable auto calling the pdb debugger after every exception.",
51 51 "Disable auto calling the pdb debugger after every exception."
52 52 )
53 53 # pydb flag doesn't do any config, as core.debugger switches on import,
54 54 # which is before parsing. This just allows the flag to be passed.
55 55 shell_flags.update(dict(
56 56 pydb = ({},
57 57 """Use the third party 'pydb' package as debugger, instead of pdb.
58 58 Requires that pydb is installed."""
59 59 )
60 60 ))
61 61 addflag('pprint', 'PlainTextFormatter.pprint',
62 62 "Enable auto pretty printing of results.",
63 63 "Disable auto pretty printing of results."
64 64 )
65 65 addflag('color-info', 'InteractiveShell.color_info',
66 66 """IPython can display information about objects via a set of functions,
67 67 and optionally can use colors for this, syntax highlighting
68 68 source code and various other elements. This is on by default, but can cause
69 69 problems with some pagers. If you see such problems, you can disable the
70 70 colours.""",
71 71 "Disable using colors for info related things."
72 72 )
73 73 addflag('deep-reload', 'InteractiveShell.deep_reload',
74 74 """Enable deep (recursive) reloading by default. IPython can use the
75 75 deep_reload module which reloads changes in modules recursively (it
76 76 replaces the reload() function, so you don't need to change anything to
77 77 use it). deep_reload() forces a full reload of modules whose code may
78 78 have changed, which the default reload() function does not. When
79 79 deep_reload is off, IPython will use the normal reload(), but
80 80 deep_reload will still be available as dreload(). This feature is off
81 81 by default [which means that you have both normal reload() and
82 82 dreload()].""",
83 83 "Disable deep (recursive) reloading by default."
84 84 )
85 85 nosep_config = Config()
86 86 nosep_config.InteractiveShell.separate_in = ''
87 87 nosep_config.InteractiveShell.separate_out = ''
88 88 nosep_config.InteractiveShell.separate_out2 = ''
89 89
90 90 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
91 91 shell_flags['pylab'] = (
92 92 {'InteractiveShellApp' : {'pylab' : 'auto'}},
93 93 """Pre-load matplotlib and numpy for interactive use with
94 94 the default matplotlib backend."""
95 95 )
96 96 shell_flags['matplotlib'] = (
97 97 {'InteractiveShellApp' : {'matplotlib' : 'auto'}},
98 98 """Configure matplotlib for interactive use with
99 99 the default matplotlib backend."""
100 100 )
101 101
102 102 # it's possible we don't want short aliases for *all* of these:
103 103 shell_aliases = dict(
104 104 autocall='InteractiveShell.autocall',
105 105 colors='InteractiveShell.colors',
106 106 logfile='InteractiveShell.logfile',
107 107 logappend='InteractiveShell.logappend',
108 108 c='InteractiveShellApp.code_to_run',
109 109 m='InteractiveShellApp.module_to_run',
110 110 ext='InteractiveShellApp.extra_extension',
111 111 gui='InteractiveShellApp.gui',
112 112 pylab='InteractiveShellApp.pylab',
113 113 matplotlib='InteractiveShellApp.matplotlib',
114 114 )
115 115 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
116 116
117 117 #-----------------------------------------------------------------------------
118 118 # Main classes and functions
119 119 #-----------------------------------------------------------------------------
120 120
121 121 class InteractiveShellApp(Configurable):
122 122 """A Mixin for applications that start InteractiveShell instances.
123 123
124 124 Provides configurables for loading extensions and executing files
125 125 as part of configuring a Shell environment.
126 126
127 127 The following methods should be called by the :meth:`initialize` method
128 128 of the subclass:
129 129
130 130 - :meth:`init_path`
131 131 - :meth:`init_shell` (to be implemented by the subclass)
132 132 - :meth:`init_gui_pylab`
133 133 - :meth:`init_extensions`
134 134 - :meth:`init_code`
135 135 """
136 136 extensions = List(Unicode, config=True,
137 137 help="A list of dotted module names of IPython extensions to load."
138 138 )
139 139 extra_extension = Unicode('', config=True,
140 140 help="dotted module name of an IPython extension to load."
141 141 )
142 142
143 143 reraise_ipython_extension_failures = Bool(
144 144 False,
145 145 config=True,
146 146 help="Reraise exceptions encountered loading IPython extensions?",
147 147 )
148 148
149 149 # Extensions that are always loaded (not configurable)
150 150 default_extensions = List(Unicode, [u'storemagic'], config=False)
151 151
152 152 hide_initial_ns = Bool(True, config=True,
153 153 help="""Should variables loaded at startup (by startup files, exec_lines, etc.)
154 154 be hidden from tools like %who?"""
155 155 )
156 156
157 157 exec_files = List(Unicode, config=True,
158 158 help="""List of files to run at IPython startup."""
159 159 )
160 160 exec_PYTHONSTARTUP = Bool(True, config=True,
161 161 help="""Run the file referenced by the PYTHONSTARTUP environment
162 162 variable at IPython startup."""
163 163 )
164 164 file_to_run = Unicode('', config=True,
165 165 help="""A file to be run""")
166 166
167 167 exec_lines = List(Unicode, config=True,
168 168 help="""lines of code to run at IPython startup."""
169 169 )
170 170 code_to_run = Unicode('', config=True,
171 171 help="Execute the given command string."
172 172 )
173 173 module_to_run = Unicode('', config=True,
174 174 help="Run the module as a script."
175 175 )
176 gui = CaselessStrEnum(gui_keys, config=True,
176 gui = CaselessStrEnum(gui_keys, config=True, allow_none=True,
177 177 help="Enable GUI event loop integration with any of {0}.".format(gui_keys)
178 178 )
179 matplotlib = CaselessStrEnum(backend_keys,
179 matplotlib = CaselessStrEnum(backend_keys, allow_none=True,
180 180 config=True,
181 181 help="""Configure matplotlib for interactive use with
182 182 the default matplotlib backend."""
183 183 )
184 pylab = CaselessStrEnum(backend_keys,
184 pylab = CaselessStrEnum(backend_keys, allow_none=True,
185 185 config=True,
186 186 help="""Pre-load matplotlib and numpy for interactive use,
187 187 selecting a particular matplotlib backend and loop integration.
188 188 """
189 189 )
190 190 pylab_import_all = Bool(True, config=True,
191 191 help="""If true, IPython will populate the user namespace with numpy, pylab, etc.
192 192 and an ``import *`` is done from numpy and pylab, when using pylab mode.
193 193
194 194 When False, pylab mode should not import any names into the user namespace.
195 195 """
196 196 )
197 197 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
198 198
199 199 user_ns = Instance(dict, args=None, allow_none=True)
200 200 def _user_ns_changed(self, name, old, new):
201 201 if self.shell is not None:
202 202 self.shell.user_ns = new
203 203 self.shell.init_user_ns()
204 204
205 205 def init_path(self):
206 206 """Add current working directory, '', to sys.path"""
207 207 if sys.path[0] != '':
208 208 sys.path.insert(0, '')
209 209
210 210 def init_shell(self):
211 211 raise NotImplementedError("Override in subclasses")
212 212
213 213 def init_gui_pylab(self):
214 214 """Enable GUI event loop integration, taking pylab into account."""
215 215 enable = False
216 216 shell = self.shell
217 217 if self.pylab:
218 218 enable = lambda key: shell.enable_pylab(key, import_all=self.pylab_import_all)
219 219 key = self.pylab
220 220 elif self.matplotlib:
221 221 enable = shell.enable_matplotlib
222 222 key = self.matplotlib
223 223 elif self.gui:
224 224 enable = shell.enable_gui
225 225 key = self.gui
226 226
227 227 if not enable:
228 228 return
229 229
230 230 try:
231 231 r = enable(key)
232 232 except ImportError:
233 233 self.log.warn("Eventloop or matplotlib integration failed. Is matplotlib installed?")
234 234 self.shell.showtraceback()
235 235 return
236 236 except Exception:
237 237 self.log.warn("GUI event loop or pylab initialization failed")
238 238 self.shell.showtraceback()
239 239 return
240 240
241 241 if isinstance(r, tuple):
242 242 gui, backend = r[:2]
243 243 self.log.info("Enabling GUI event loop integration, "
244 244 "eventloop=%s, matplotlib=%s", gui, backend)
245 245 if key == "auto":
246 246 print("Using matplotlib backend: %s" % backend)
247 247 else:
248 248 gui = r
249 249 self.log.info("Enabling GUI event loop integration, "
250 250 "eventloop=%s", gui)
251 251
252 252 def init_extensions(self):
253 253 """Load all IPython extensions in IPythonApp.extensions.
254 254
255 255 This uses the :meth:`ExtensionManager.load_extensions` to load all
256 256 the extensions listed in ``self.extensions``.
257 257 """
258 258 try:
259 259 self.log.debug("Loading IPython extensions...")
260 260 extensions = self.default_extensions + self.extensions
261 261 if self.extra_extension:
262 262 extensions.append(self.extra_extension)
263 263 for ext in extensions:
264 264 try:
265 265 self.log.info("Loading IPython extension: %s" % ext)
266 266 self.shell.extension_manager.load_extension(ext)
267 267 except:
268 268 if self.reraise_ipython_extension_failures:
269 269 raise
270 270 msg = ("Error in loading extension: {ext}\n"
271 271 "Check your config files in {location}".format(
272 272 ext=ext,
273 273 location=self.profile_dir.location
274 274 ))
275 275 self.log.warn(msg, exc_info=True)
276 276 except:
277 277 if self.reraise_ipython_extension_failures:
278 278 raise
279 279 self.log.warn("Unknown error in loading extensions:", exc_info=True)
280 280
281 281 def init_code(self):
282 282 """run the pre-flight code, specified via exec_lines"""
283 283 self._run_startup_files()
284 284 self._run_exec_lines()
285 285 self._run_exec_files()
286 286
287 287 # Hide variables defined here from %who etc.
288 288 if self.hide_initial_ns:
289 289 self.shell.user_ns_hidden.update(self.shell.user_ns)
290 290
291 291 # command-line execution (ipython -i script.py, ipython -m module)
292 292 # should *not* be excluded from %whos
293 293 self._run_cmd_line_code()
294 294 self._run_module()
295 295
296 296 # flush output, so itwon't be attached to the first cell
297 297 sys.stdout.flush()
298 298 sys.stderr.flush()
299 299
300 300 def _run_exec_lines(self):
301 301 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
302 302 if not self.exec_lines:
303 303 return
304 304 try:
305 305 self.log.debug("Running code from IPythonApp.exec_lines...")
306 306 for line in self.exec_lines:
307 307 try:
308 308 self.log.info("Running code in user namespace: %s" %
309 309 line)
310 310 self.shell.run_cell(line, store_history=False)
311 311 except:
312 312 self.log.warn("Error in executing line in user "
313 313 "namespace: %s" % line)
314 314 self.shell.showtraceback()
315 315 except:
316 316 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
317 317 self.shell.showtraceback()
318 318
319 319 def _exec_file(self, fname, shell_futures=False):
320 320 try:
321 321 full_filename = filefind(fname, [u'.', self.ipython_dir])
322 322 except IOError as e:
323 323 self.log.warn("File not found: %r"%fname)
324 324 return
325 325 # Make sure that the running script gets a proper sys.argv as if it
326 326 # were run from a system shell.
327 327 save_argv = sys.argv
328 328 sys.argv = [full_filename] + self.extra_args[1:]
329 329 # protect sys.argv from potential unicode strings on Python 2:
330 330 if not py3compat.PY3:
331 331 sys.argv = [ py3compat.cast_bytes(a) for a in sys.argv ]
332 332 try:
333 333 if os.path.isfile(full_filename):
334 334 self.log.info("Running file in user namespace: %s" %
335 335 full_filename)
336 336 # Ensure that __file__ is always defined to match Python
337 337 # behavior.
338 338 with preserve_keys(self.shell.user_ns, '__file__'):
339 339 self.shell.user_ns['__file__'] = fname
340 340 if full_filename.endswith('.ipy'):
341 341 self.shell.safe_execfile_ipy(full_filename,
342 342 shell_futures=shell_futures)
343 343 else:
344 344 # default to python, even without extension
345 345 self.shell.safe_execfile(full_filename,
346 346 self.shell.user_ns,
347 347 shell_futures=shell_futures)
348 348 finally:
349 349 sys.argv = save_argv
350 350
351 351 def _run_startup_files(self):
352 352 """Run files from profile startup directory"""
353 353 startup_dir = self.profile_dir.startup_dir
354 354 startup_files = []
355 355
356 356 if self.exec_PYTHONSTARTUP and os.environ.get('PYTHONSTARTUP', False) and \
357 357 not (self.file_to_run or self.code_to_run or self.module_to_run):
358 358 python_startup = os.environ['PYTHONSTARTUP']
359 359 self.log.debug("Running PYTHONSTARTUP file %s...", python_startup)
360 360 try:
361 361 self._exec_file(python_startup)
362 362 except:
363 363 self.log.warn("Unknown error in handling PYTHONSTARTUP file %s:", python_startup)
364 364 self.shell.showtraceback()
365 365 finally:
366 366 # Many PYTHONSTARTUP files set up the readline completions,
367 367 # but this is often at odds with IPython's own completions.
368 368 # Do not allow PYTHONSTARTUP to set up readline.
369 369 if self.shell.has_readline:
370 370 self.shell.set_readline_completer()
371 371
372 372 startup_files += glob.glob(os.path.join(startup_dir, '*.py'))
373 373 startup_files += glob.glob(os.path.join(startup_dir, '*.ipy'))
374 374 if not startup_files:
375 375 return
376 376
377 377 self.log.debug("Running startup files from %s...", startup_dir)
378 378 try:
379 379 for fname in sorted(startup_files):
380 380 self._exec_file(fname)
381 381 except:
382 382 self.log.warn("Unknown error in handling startup files:")
383 383 self.shell.showtraceback()
384 384
385 385 def _run_exec_files(self):
386 386 """Run files from IPythonApp.exec_files"""
387 387 if not self.exec_files:
388 388 return
389 389
390 390 self.log.debug("Running files in IPythonApp.exec_files...")
391 391 try:
392 392 for fname in self.exec_files:
393 393 self._exec_file(fname)
394 394 except:
395 395 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
396 396 self.shell.showtraceback()
397 397
398 398 def _run_cmd_line_code(self):
399 399 """Run code or file specified at the command-line"""
400 400 if self.code_to_run:
401 401 line = self.code_to_run
402 402 try:
403 403 self.log.info("Running code given at command line (c=): %s" %
404 404 line)
405 405 self.shell.run_cell(line, store_history=False)
406 406 except:
407 407 self.log.warn("Error in executing line in user namespace: %s" %
408 408 line)
409 409 self.shell.showtraceback()
410 410
411 411 # Like Python itself, ignore the second if the first of these is present
412 412 elif self.file_to_run:
413 413 fname = self.file_to_run
414 414 try:
415 415 self._exec_file(fname, shell_futures=True)
416 416 except:
417 417 self.log.warn("Error in executing file in user namespace: %s" %
418 418 fname)
419 419 self.shell.showtraceback()
420 420
421 421 def _run_module(self):
422 422 """Run module specified at the command-line."""
423 423 if self.module_to_run:
424 424 # Make sure that the module gets a proper sys.argv as if it were
425 425 # run using `python -m`.
426 426 save_argv = sys.argv
427 427 sys.argv = [sys.executable] + self.extra_args
428 428 try:
429 429 self.shell.safe_run_module(self.module_to_run,
430 430 self.shell.user_ns)
431 431 finally:
432 432 sys.argv = save_argv
@@ -1,468 +1,468 b''
1 1 """A base class for contents managers."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from fnmatch import fnmatch
7 7 import itertools
8 8 import json
9 9 import os
10 10 import re
11 11
12 12 from tornado.web import HTTPError
13 13
14 14 from .checkpoints import Checkpoints
15 15 from IPython.config.configurable import LoggingConfigurable
16 16 from IPython.nbformat import sign, validate, ValidationError
17 17 from IPython.nbformat.v4 import new_notebook
18 18 from IPython.utils.importstring import import_item
19 19 from IPython.utils.traitlets import (
20 20 Any,
21 21 Dict,
22 22 Instance,
23 23 List,
24 24 TraitError,
25 25 Type,
26 26 Unicode,
27 27 )
28 28 from IPython.utils.py3compat import string_types
29 29
30 30 copy_pat = re.compile(r'\-Copy\d*\.')
31 31
32 32
33 33 class ContentsManager(LoggingConfigurable):
34 34 """Base class for serving files and directories.
35 35
36 36 This serves any text or binary file,
37 37 as well as directories,
38 38 with special handling for JSON notebook documents.
39 39
40 40 Most APIs take a path argument,
41 41 which is always an API-style unicode path,
42 42 and always refers to a directory.
43 43
44 44 - unicode, not url-escaped
45 45 - '/'-separated
46 46 - leading and trailing '/' will be stripped
47 47 - if unspecified, path defaults to '',
48 48 indicating the root path.
49 49
50 50 """
51 51
52 52 notary = Instance(sign.NotebookNotary)
53 53 def _notary_default(self):
54 54 return sign.NotebookNotary(parent=self)
55 55
56 56 hide_globs = List(Unicode, [
57 57 u'__pycache__', '*.pyc', '*.pyo',
58 58 '.DS_Store', '*.so', '*.dylib', '*~',
59 59 ], config=True, help="""
60 60 Glob patterns to hide in file and directory listings.
61 61 """)
62 62
63 63 untitled_notebook = Unicode("Untitled", config=True,
64 64 help="The base name used when creating untitled notebooks."
65 65 )
66 66
67 67 untitled_file = Unicode("untitled", config=True,
68 68 help="The base name used when creating untitled files."
69 69 )
70 70
71 71 untitled_directory = Unicode("Untitled Folder", config=True,
72 72 help="The base name used when creating untitled directories."
73 73 )
74 74
75 75 pre_save_hook = Any(None, config=True,
76 76 help="""Python callable or importstring thereof
77 77
78 78 To be called on a contents model prior to save.
79 79
80 80 This can be used to process the structure,
81 81 such as removing notebook outputs or other side effects that
82 82 should not be saved.
83 83
84 84 It will be called as (all arguments passed by keyword)::
85 85
86 86 hook(path=path, model=model, contents_manager=self)
87 87
88 88 - model: the model to be saved. Includes file contents.
89 89 Modifying this dict will affect the file that is stored.
90 90 - path: the API path of the save destination
91 91 - contents_manager: this ContentsManager instance
92 92 """
93 93 )
94 94 def _pre_save_hook_changed(self, name, old, new):
95 95 if new and isinstance(new, string_types):
96 96 self.pre_save_hook = import_item(self.pre_save_hook)
97 97 elif new:
98 98 if not callable(new):
99 99 raise TraitError("pre_save_hook must be callable")
100 100
101 101 def run_pre_save_hook(self, model, path, **kwargs):
102 102 """Run the pre-save hook if defined, and log errors"""
103 103 if self.pre_save_hook:
104 104 try:
105 105 self.log.debug("Running pre-save hook on %s", path)
106 106 self.pre_save_hook(model=model, path=path, contents_manager=self, **kwargs)
107 107 except Exception:
108 108 self.log.error("Pre-save hook failed on %s", path, exc_info=True)
109 109
110 110 checkpoints_class = Type(Checkpoints, config=True)
111 111 checkpoints = Instance(Checkpoints, config=True)
112 checkpoints_kwargs = Dict(allow_none=False, config=True)
112 checkpoints_kwargs = Dict(config=True)
113 113
114 114 def _checkpoints_default(self):
115 115 return self.checkpoints_class(**self.checkpoints_kwargs)
116 116
117 117 def _checkpoints_kwargs_default(self):
118 118 return dict(
119 119 parent=self,
120 120 log=self.log,
121 121 )
122 122
123 123 # ContentsManager API part 1: methods that must be
124 124 # implemented in subclasses.
125 125
126 126 def dir_exists(self, path):
127 127 """Does the API-style path (directory) actually exist?
128 128
129 129 Like os.path.isdir
130 130
131 131 Override this method in subclasses.
132 132
133 133 Parameters
134 134 ----------
135 135 path : string
136 136 The path to check
137 137
138 138 Returns
139 139 -------
140 140 exists : bool
141 141 Whether the path does indeed exist.
142 142 """
143 143 raise NotImplementedError
144 144
145 145 def is_hidden(self, path):
146 146 """Does the API style path correspond to a hidden directory or file?
147 147
148 148 Parameters
149 149 ----------
150 150 path : string
151 151 The path to check. This is an API path (`/` separated,
152 152 relative to root dir).
153 153
154 154 Returns
155 155 -------
156 156 hidden : bool
157 157 Whether the path is hidden.
158 158
159 159 """
160 160 raise NotImplementedError
161 161
162 162 def file_exists(self, path=''):
163 163 """Does a file exist at the given path?
164 164
165 165 Like os.path.isfile
166 166
167 167 Override this method in subclasses.
168 168
169 169 Parameters
170 170 ----------
171 171 name : string
172 172 The name of the file you are checking.
173 173 path : string
174 174 The relative path to the file's directory (with '/' as separator)
175 175
176 176 Returns
177 177 -------
178 178 exists : bool
179 179 Whether the file exists.
180 180 """
181 181 raise NotImplementedError('must be implemented in a subclass')
182 182
183 183 def exists(self, path):
184 184 """Does a file or directory exist at the given path?
185 185
186 186 Like os.path.exists
187 187
188 188 Parameters
189 189 ----------
190 190 path : string
191 191 The relative path to the file's directory (with '/' as separator)
192 192
193 193 Returns
194 194 -------
195 195 exists : bool
196 196 Whether the target exists.
197 197 """
198 198 return self.file_exists(path) or self.dir_exists(path)
199 199
200 200 def get(self, path, content=True, type=None, format=None):
201 201 """Get the model of a file or directory with or without content."""
202 202 raise NotImplementedError('must be implemented in a subclass')
203 203
204 204 def save(self, model, path):
205 205 """Save the file or directory and return the model with no content.
206 206
207 207 Save implementations should call self.run_pre_save_hook(model=model, path=path)
208 208 prior to writing any data.
209 209 """
210 210 raise NotImplementedError('must be implemented in a subclass')
211 211
212 212 def delete_file(self, path):
213 213 """Delete file or directory by path."""
214 214 raise NotImplementedError('must be implemented in a subclass')
215 215
216 216 def rename_file(self, old_path, new_path):
217 217 """Rename a file."""
218 218 raise NotImplementedError('must be implemented in a subclass')
219 219
220 220 # ContentsManager API part 2: methods that have useable default
221 221 # implementations, but can be overridden in subclasses.
222 222
223 223 def delete(self, path):
224 224 """Delete a file/directory and any associated checkpoints."""
225 225 self.delete_file(path)
226 226 self.checkpoints.delete_all_checkpoints(path)
227 227
228 228 def rename(self, old_path, new_path):
229 229 """Rename a file and any checkpoints associated with that file."""
230 230 self.rename_file(old_path, new_path)
231 231 self.checkpoints.rename_all_checkpoints(old_path, new_path)
232 232
233 233 def update(self, model, path):
234 234 """Update the file's path
235 235
236 236 For use in PATCH requests, to enable renaming a file without
237 237 re-uploading its contents. Only used for renaming at the moment.
238 238 """
239 239 path = path.strip('/')
240 240 new_path = model.get('path', path).strip('/')
241 241 if path != new_path:
242 242 self.rename(path, new_path)
243 243 model = self.get(new_path, content=False)
244 244 return model
245 245
246 246 def info_string(self):
247 247 return "Serving contents"
248 248
249 249 def get_kernel_path(self, path, model=None):
250 250 """Return the API path for the kernel
251 251
252 252 KernelManagers can turn this value into a filesystem path,
253 253 or ignore it altogether.
254 254
255 255 The default value here will start kernels in the directory of the
256 256 notebook server. FileContentsManager overrides this to use the
257 257 directory containing the notebook.
258 258 """
259 259 return ''
260 260
261 261 def increment_filename(self, filename, path='', insert=''):
262 262 """Increment a filename until it is unique.
263 263
264 264 Parameters
265 265 ----------
266 266 filename : unicode
267 267 The name of a file, including extension
268 268 path : unicode
269 269 The API path of the target's directory
270 270
271 271 Returns
272 272 -------
273 273 name : unicode
274 274 A filename that is unique, based on the input filename.
275 275 """
276 276 path = path.strip('/')
277 277 basename, ext = os.path.splitext(filename)
278 278 for i in itertools.count():
279 279 if i:
280 280 insert_i = '{}{}'.format(insert, i)
281 281 else:
282 282 insert_i = ''
283 283 name = u'{basename}{insert}{ext}'.format(basename=basename,
284 284 insert=insert_i, ext=ext)
285 285 if not self.exists(u'{}/{}'.format(path, name)):
286 286 break
287 287 return name
288 288
289 289 def validate_notebook_model(self, model):
290 290 """Add failed-validation message to model"""
291 291 try:
292 292 validate(model['content'])
293 293 except ValidationError as e:
294 294 model['message'] = u'Notebook Validation failed: {}:\n{}'.format(
295 295 e.message, json.dumps(e.instance, indent=1, default=lambda obj: '<UNKNOWN>'),
296 296 )
297 297 return model
298 298
299 299 def new_untitled(self, path='', type='', ext=''):
300 300 """Create a new untitled file or directory in path
301 301
302 302 path must be a directory
303 303
304 304 File extension can be specified.
305 305
306 306 Use `new` to create files with a fully specified path (including filename).
307 307 """
308 308 path = path.strip('/')
309 309 if not self.dir_exists(path):
310 310 raise HTTPError(404, 'No such directory: %s' % path)
311 311
312 312 model = {}
313 313 if type:
314 314 model['type'] = type
315 315
316 316 if ext == '.ipynb':
317 317 model.setdefault('type', 'notebook')
318 318 else:
319 319 model.setdefault('type', 'file')
320 320
321 321 insert = ''
322 322 if model['type'] == 'directory':
323 323 untitled = self.untitled_directory
324 324 insert = ' '
325 325 elif model['type'] == 'notebook':
326 326 untitled = self.untitled_notebook
327 327 ext = '.ipynb'
328 328 elif model['type'] == 'file':
329 329 untitled = self.untitled_file
330 330 else:
331 331 raise HTTPError(400, "Unexpected model type: %r" % model['type'])
332 332
333 333 name = self.increment_filename(untitled + ext, path, insert=insert)
334 334 path = u'{0}/{1}'.format(path, name)
335 335 return self.new(model, path)
336 336
337 337 def new(self, model=None, path=''):
338 338 """Create a new file or directory and return its model with no content.
339 339
340 340 To create a new untitled entity in a directory, use `new_untitled`.
341 341 """
342 342 path = path.strip('/')
343 343 if model is None:
344 344 model = {}
345 345
346 346 if path.endswith('.ipynb'):
347 347 model.setdefault('type', 'notebook')
348 348 else:
349 349 model.setdefault('type', 'file')
350 350
351 351 # no content, not a directory, so fill out new-file model
352 352 if 'content' not in model and model['type'] != 'directory':
353 353 if model['type'] == 'notebook':
354 354 model['content'] = new_notebook()
355 355 model['format'] = 'json'
356 356 else:
357 357 model['content'] = ''
358 358 model['type'] = 'file'
359 359 model['format'] = 'text'
360 360
361 361 model = self.save(model, path)
362 362 return model
363 363
364 364 def copy(self, from_path, to_path=None):
365 365 """Copy an existing file and return its new model.
366 366
367 367 If to_path not specified, it will be the parent directory of from_path.
368 368 If to_path is a directory, filename will increment `from_path-Copy#.ext`.
369 369
370 370 from_path must be a full path to a file.
371 371 """
372 372 path = from_path.strip('/')
373 373 if to_path is not None:
374 374 to_path = to_path.strip('/')
375 375
376 376 if '/' in path:
377 377 from_dir, from_name = path.rsplit('/', 1)
378 378 else:
379 379 from_dir = ''
380 380 from_name = path
381 381
382 382 model = self.get(path)
383 383 model.pop('path', None)
384 384 model.pop('name', None)
385 385 if model['type'] == 'directory':
386 386 raise HTTPError(400, "Can't copy directories")
387 387
388 388 if to_path is None:
389 389 to_path = from_dir
390 390 if self.dir_exists(to_path):
391 391 name = copy_pat.sub(u'.', from_name)
392 392 to_name = self.increment_filename(name, to_path, insert='-Copy')
393 393 to_path = u'{0}/{1}'.format(to_path, to_name)
394 394
395 395 model = self.save(model, to_path)
396 396 return model
397 397
398 398 def log_info(self):
399 399 self.log.info(self.info_string())
400 400
401 401 def trust_notebook(self, path):
402 402 """Explicitly trust a notebook
403 403
404 404 Parameters
405 405 ----------
406 406 path : string
407 407 The path of a notebook
408 408 """
409 409 model = self.get(path)
410 410 nb = model['content']
411 411 self.log.warn("Trusting notebook %s", path)
412 412 self.notary.mark_cells(nb, True)
413 413 self.save(model, path)
414 414
415 415 def check_and_sign(self, nb, path=''):
416 416 """Check for trusted cells, and sign the notebook.
417 417
418 418 Called as a part of saving notebooks.
419 419
420 420 Parameters
421 421 ----------
422 422 nb : dict
423 423 The notebook dict
424 424 path : string
425 425 The notebook's path (for logging)
426 426 """
427 427 if self.notary.check_cells(nb):
428 428 self.notary.sign(nb)
429 429 else:
430 430 self.log.warn("Saving untrusted notebook %s", path)
431 431
432 432 def mark_trusted_cells(self, nb, path=''):
433 433 """Mark cells as trusted if the notebook signature matches.
434 434
435 435 Called as a part of loading notebooks.
436 436
437 437 Parameters
438 438 ----------
439 439 nb : dict
440 440 The notebook object (in current nbformat)
441 441 path : string
442 442 The notebook's path (for logging)
443 443 """
444 444 trusted = self.notary.check_signature(nb)
445 445 if not trusted:
446 446 self.log.warn("Notebook %s is not trusted", path)
447 447 self.notary.mark_cells(nb, trusted)
448 448
449 449 def should_list(self, name):
450 450 """Should this file/directory name be displayed in a listing?"""
451 451 return not any(fnmatch(name, glob) for glob in self.hide_globs)
452 452
453 453 # Part 3: Checkpoints API
454 454 def create_checkpoint(self, path):
455 455 """Create a checkpoint."""
456 456 return self.checkpoints.create_checkpoint(self, path)
457 457
458 458 def restore_checkpoint(self, checkpoint_id, path):
459 459 """
460 460 Restore a checkpoint.
461 461 """
462 462 self.checkpoints.restore_checkpoint(self, checkpoint_id, path)
463 463
464 464 def list_checkpoints(self, path):
465 465 return self.checkpoints.list_checkpoints(path)
466 466
467 467 def delete_checkpoint(self, checkpoint_id, path):
468 468 return self.checkpoints.delete_checkpoint(checkpoint_id, path)
@@ -1,489 +1,489 b''
1 1 """Base Widget class. Allows user to create widgets in the back-end that render
2 2 in the IPython notebook front-end.
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (c) 2013, the IPython Development Team.
6 6 #
7 7 # Distributed under the terms of the Modified BSD License.
8 8 #
9 9 # The full license is in the file COPYING.txt, distributed with this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15 from contextlib import contextmanager
16 16 import collections
17 17
18 18 from IPython.core.getipython import get_ipython
19 19 from IPython.kernel.comm import Comm
20 20 from IPython.config import LoggingConfigurable
21 21 from IPython.utils.importstring import import_item
22 22 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, \
23 23 CaselessStrEnum, Tuple, CUnicode, Int, Set
24 24 from IPython.utils.py3compat import string_types
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # Classes
28 28 #-----------------------------------------------------------------------------
29 29 class CallbackDispatcher(LoggingConfigurable):
30 30 """A structure for registering and running callbacks"""
31 31 callbacks = List()
32 32
33 33 def __call__(self, *args, **kwargs):
34 34 """Call all of the registered callbacks."""
35 35 value = None
36 36 for callback in self.callbacks:
37 37 try:
38 38 local_value = callback(*args, **kwargs)
39 39 except Exception as e:
40 40 ip = get_ipython()
41 41 if ip is None:
42 42 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
43 43 else:
44 44 ip.showtraceback()
45 45 else:
46 46 value = local_value if local_value is not None else value
47 47 return value
48 48
49 49 def register_callback(self, callback, remove=False):
50 50 """(Un)Register a callback
51 51
52 52 Parameters
53 53 ----------
54 54 callback: method handle
55 55 Method to be registered or unregistered.
56 56 remove=False: bool
57 57 Whether to unregister the callback."""
58 58
59 59 # (Un)Register the callback.
60 60 if remove and callback in self.callbacks:
61 61 self.callbacks.remove(callback)
62 62 elif not remove and callback not in self.callbacks:
63 63 self.callbacks.append(callback)
64 64
65 65 def _show_traceback(method):
66 66 """decorator for showing tracebacks in IPython"""
67 67 def m(self, *args, **kwargs):
68 68 try:
69 69 return(method(self, *args, **kwargs))
70 70 except Exception as e:
71 71 ip = get_ipython()
72 72 if ip is None:
73 73 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
74 74 else:
75 75 ip.showtraceback()
76 76 return m
77 77
78 78
79 79 def register(key=None):
80 80 """Returns a decorator registering a widget class in the widget registry.
81 81 If no key is provided, the class name is used as a key. A key is
82 82 provided for each core IPython widget so that the frontend can use
83 83 this key regardless of the language of the kernel"""
84 84 def wrap(widget):
85 85 l = key if key is not None else widget.__module__ + widget.__name__
86 86 Widget.widget_types[l] = widget
87 87 return widget
88 88 return wrap
89 89
90 90
91 91 class Widget(LoggingConfigurable):
92 92 #-------------------------------------------------------------------------
93 93 # Class attributes
94 94 #-------------------------------------------------------------------------
95 95 _widget_construction_callback = None
96 96 widgets = {}
97 97 widget_types = {}
98 98
99 99 @staticmethod
100 100 def on_widget_constructed(callback):
101 101 """Registers a callback to be called when a widget is constructed.
102 102
103 103 The callback must have the following signature:
104 104 callback(widget)"""
105 105 Widget._widget_construction_callback = callback
106 106
107 107 @staticmethod
108 108 def _call_widget_constructed(widget):
109 109 """Static method, called when a widget is constructed."""
110 110 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
111 111 Widget._widget_construction_callback(widget)
112 112
113 113 @staticmethod
114 114 def handle_comm_opened(comm, msg):
115 115 """Static method, called when a widget is constructed."""
116 116 widget_class = import_item(msg['content']['data']['widget_class'])
117 117 widget = widget_class(comm=comm)
118 118
119 119
120 120 #-------------------------------------------------------------------------
121 121 # Traits
122 122 #-------------------------------------------------------------------------
123 123 _model_module = Unicode(None, allow_none=True, help="""A requirejs module name
124 124 in which to find _model_name. If empty, look in the global registry.""")
125 125 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
126 126 registered in the front-end to create and sync this widget with.""")
127 127 _view_module = Unicode(help="""A requirejs module in which to find _view_name.
128 128 If empty, look in the global registry.""", sync=True)
129 129 _view_name = Unicode(None, allow_none=True, help="""Default view registered in the front-end
130 130 to use to represent the widget.""", sync=True)
131 131 comm = Instance('IPython.kernel.comm.Comm')
132 132
133 133 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
134 134 front-end can send before receiving an idle msg from the back-end.""")
135 135
136 136 version = Int(0, sync=True, help="""Widget's version""")
137 137 keys = List()
138 138 def _keys_default(self):
139 139 return [name for name in self.traits(sync=True)]
140 140
141 141 _property_lock = Tuple((None, None))
142 142 _send_state_lock = Int(0)
143 _states_to_send = Set(allow_none=False)
143 _states_to_send = Set()
144 144 _display_callbacks = Instance(CallbackDispatcher, ())
145 145 _msg_callbacks = Instance(CallbackDispatcher, ())
146 146
147 147 #-------------------------------------------------------------------------
148 148 # (Con/de)structor
149 149 #-------------------------------------------------------------------------
150 150 def __init__(self, **kwargs):
151 151 """Public constructor"""
152 152 self._model_id = kwargs.pop('model_id', None)
153 153 super(Widget, self).__init__(**kwargs)
154 154
155 155 Widget._call_widget_constructed(self)
156 156 self.open()
157 157
158 158 def __del__(self):
159 159 """Object disposal"""
160 160 self.close()
161 161
162 162 #-------------------------------------------------------------------------
163 163 # Properties
164 164 #-------------------------------------------------------------------------
165 165
166 166 def open(self):
167 167 """Open a comm to the frontend if one isn't already open."""
168 168 if self.comm is None:
169 169 args = dict(target_name='ipython.widget',
170 170 data={'model_name': self._model_name,
171 171 'model_module': self._model_module})
172 172 if self._model_id is not None:
173 173 args['comm_id'] = self._model_id
174 174 self.comm = Comm(**args)
175 175
176 176 def _comm_changed(self, name, new):
177 177 """Called when the comm is changed."""
178 178 if new is None:
179 179 return
180 180 self._model_id = self.model_id
181 181
182 182 self.comm.on_msg(self._handle_msg)
183 183 Widget.widgets[self.model_id] = self
184 184
185 185 # first update
186 186 self.send_state()
187 187
188 188 @property
189 189 def model_id(self):
190 190 """Gets the model id of this widget.
191 191
192 192 If a Comm doesn't exist yet, a Comm will be created automagically."""
193 193 return self.comm.comm_id
194 194
195 195 #-------------------------------------------------------------------------
196 196 # Methods
197 197 #-------------------------------------------------------------------------
198 198
199 199 def close(self):
200 200 """Close method.
201 201
202 202 Closes the underlying comm.
203 203 When the comm is closed, all of the widget views are automatically
204 204 removed from the front-end."""
205 205 if self.comm is not None:
206 206 Widget.widgets.pop(self.model_id, None)
207 207 self.comm.close()
208 208 self.comm = None
209 209
210 210 def send_state(self, key=None):
211 211 """Sends the widget state, or a piece of it, to the front-end.
212 212
213 213 Parameters
214 214 ----------
215 215 key : unicode, or iterable (optional)
216 216 A single property's name or iterable of property names to sync with the front-end.
217 217 """
218 218 self._send({
219 219 "method" : "update",
220 220 "state" : self.get_state(key=key)
221 221 })
222 222
223 223 def get_state(self, key=None):
224 224 """Gets the widget state, or a piece of it.
225 225
226 226 Parameters
227 227 ----------
228 228 key : unicode or iterable (optional)
229 229 A single property's name or iterable of property names to get.
230 230 """
231 231 if key is None:
232 232 keys = self.keys
233 233 elif isinstance(key, string_types):
234 234 keys = [key]
235 235 elif isinstance(key, collections.Iterable):
236 236 keys = key
237 237 else:
238 238 raise ValueError("key must be a string, an iterable of keys, or None")
239 239 state = {}
240 240 for k in keys:
241 241 f = self.trait_metadata(k, 'to_json', self._trait_to_json)
242 242 value = getattr(self, k)
243 243 state[k] = f(value)
244 244 return state
245 245
246 246 def set_state(self, sync_data):
247 247 """Called when a state is received from the front-end."""
248 248 for name in self.keys:
249 249 if name in sync_data:
250 250 json_value = sync_data[name]
251 251 from_json = self.trait_metadata(name, 'from_json', self._trait_from_json)
252 252 with self._lock_property(name, json_value):
253 253 setattr(self, name, from_json(json_value))
254 254
255 255 def send(self, content):
256 256 """Sends a custom msg to the widget model in the front-end.
257 257
258 258 Parameters
259 259 ----------
260 260 content : dict
261 261 Content of the message to send.
262 262 """
263 263 self._send({"method": "custom", "content": content})
264 264
265 265 def on_msg(self, callback, remove=False):
266 266 """(Un)Register a custom msg receive callback.
267 267
268 268 Parameters
269 269 ----------
270 270 callback: callable
271 271 callback will be passed two arguments when a message arrives::
272 272
273 273 callback(widget, content)
274 274
275 275 remove: bool
276 276 True if the callback should be unregistered."""
277 277 self._msg_callbacks.register_callback(callback, remove=remove)
278 278
279 279 def on_displayed(self, callback, remove=False):
280 280 """(Un)Register a widget displayed callback.
281 281
282 282 Parameters
283 283 ----------
284 284 callback: method handler
285 285 Must have a signature of::
286 286
287 287 callback(widget, **kwargs)
288 288
289 289 kwargs from display are passed through without modification.
290 290 remove: bool
291 291 True if the callback should be unregistered."""
292 292 self._display_callbacks.register_callback(callback, remove=remove)
293 293
294 294 #-------------------------------------------------------------------------
295 295 # Support methods
296 296 #-------------------------------------------------------------------------
297 297 @contextmanager
298 298 def _lock_property(self, key, value):
299 299 """Lock a property-value pair.
300 300
301 301 The value should be the JSON state of the property.
302 302
303 303 NOTE: This, in addition to the single lock for all state changes, is
304 304 flawed. In the future we may want to look into buffering state changes
305 305 back to the front-end."""
306 306 self._property_lock = (key, value)
307 307 try:
308 308 yield
309 309 finally:
310 310 self._property_lock = (None, None)
311 311
312 312 @contextmanager
313 313 def hold_sync(self):
314 314 """Hold syncing any state until the context manager is released"""
315 315 # We increment a value so that this can be nested. Syncing will happen when
316 316 # all levels have been released.
317 317 self._send_state_lock += 1
318 318 try:
319 319 yield
320 320 finally:
321 321 self._send_state_lock -=1
322 322 if self._send_state_lock == 0:
323 323 self.send_state(self._states_to_send)
324 324 self._states_to_send.clear()
325 325
326 326 def _should_send_property(self, key, value):
327 327 """Check the property lock (property_lock)"""
328 328 to_json = self.trait_metadata(key, 'to_json', self._trait_to_json)
329 329 if (key == self._property_lock[0]
330 330 and to_json(value) == self._property_lock[1]):
331 331 return False
332 332 elif self._send_state_lock > 0:
333 333 self._states_to_send.add(key)
334 334 return False
335 335 else:
336 336 return True
337 337
338 338 # Event handlers
339 339 @_show_traceback
340 340 def _handle_msg(self, msg):
341 341 """Called when a msg is received from the front-end"""
342 342 data = msg['content']['data']
343 343 method = data['method']
344 344
345 345 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
346 346 if method == 'backbone':
347 347 if 'sync_data' in data:
348 348 sync_data = data['sync_data']
349 349 self.set_state(sync_data) # handles all methods
350 350
351 351 # Handle a state request.
352 352 elif method == 'request_state':
353 353 self.send_state()
354 354
355 355 # Handle a custom msg from the front-end.
356 356 elif method == 'custom':
357 357 if 'content' in data:
358 358 self._handle_custom_msg(data['content'])
359 359
360 360 # Catch remainder.
361 361 else:
362 362 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
363 363
364 364 def _handle_custom_msg(self, content):
365 365 """Called when a custom msg is received."""
366 366 self._msg_callbacks(self, content)
367 367
368 368 def _notify_trait(self, name, old_value, new_value):
369 369 """Called when a property has been changed."""
370 370 # Trigger default traitlet callback machinery. This allows any user
371 371 # registered validation to be processed prior to allowing the widget
372 372 # machinery to handle the state.
373 373 LoggingConfigurable._notify_trait(self, name, old_value, new_value)
374 374
375 375 # Send the state after the user registered callbacks for trait changes
376 376 # have all fired (allows for user to validate values).
377 377 if self.comm is not None and name in self.keys:
378 378 # Make sure this isn't information that the front-end just sent us.
379 379 if self._should_send_property(name, new_value):
380 380 # Send new state to front-end
381 381 self.send_state(key=name)
382 382
383 383 def _handle_displayed(self, **kwargs):
384 384 """Called when a view has been displayed for this widget instance"""
385 385 self._display_callbacks(self, **kwargs)
386 386
387 387 def _trait_to_json(self, x):
388 388 """Convert a trait value to json
389 389
390 390 Traverse lists/tuples and dicts and serialize their values as well.
391 391 Replace any widgets with their model_id
392 392 """
393 393 if isinstance(x, dict):
394 394 return {k: self._trait_to_json(v) for k, v in x.items()}
395 395 elif isinstance(x, (list, tuple)):
396 396 return [self._trait_to_json(v) for v in x]
397 397 elif isinstance(x, Widget):
398 398 return "IPY_MODEL_" + x.model_id
399 399 else:
400 400 return x # Value must be JSON-able
401 401
402 402 def _trait_from_json(self, x):
403 403 """Convert json values to objects
404 404
405 405 Replace any strings representing valid model id values to Widget references.
406 406 """
407 407 if isinstance(x, dict):
408 408 return {k: self._trait_from_json(v) for k, v in x.items()}
409 409 elif isinstance(x, (list, tuple)):
410 410 return [self._trait_from_json(v) for v in x]
411 411 elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
412 412 # we want to support having child widgets at any level in a hierarchy
413 413 # trusting that a widget UUID will not appear out in the wild
414 414 return Widget.widgets[x[10:]]
415 415 else:
416 416 return x
417 417
418 418 def _ipython_display_(self, **kwargs):
419 419 """Called when `IPython.display.display` is called on the widget."""
420 420 # Show view.
421 421 if self._view_name is not None:
422 422 self._send({"method": "display"})
423 423 self._handle_displayed(**kwargs)
424 424
425 425 def _send(self, msg):
426 426 """Sends a message to the model in the front-end."""
427 427 self.comm.send(msg)
428 428
429 429
430 430 class DOMWidget(Widget):
431 431 visible = Bool(True, allow_none=True, help="Whether the widget is visible. False collapses the empty space, while None preserves the empty space.", sync=True)
432 432 _css = Tuple(sync=True, help="CSS property list: (selector, key, value)")
433 433 _dom_classes = Tuple(sync=True, help="DOM classes applied to widget.$el.")
434 434
435 435 width = CUnicode(sync=True)
436 436 height = CUnicode(sync=True)
437 437 # A default padding of 2.5 px makes the widgets look nice when displayed inline.
438 438 padding = CUnicode(sync=True)
439 439 margin = CUnicode(sync=True)
440 440
441 441 color = Unicode(sync=True)
442 442 background_color = Unicode(sync=True)
443 443 border_color = Unicode(sync=True)
444 444
445 445 border_width = CUnicode(sync=True)
446 446 border_radius = CUnicode(sync=True)
447 447 border_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_border-style.asp
448 448 'none',
449 449 'hidden',
450 450 'dotted',
451 451 'dashed',
452 452 'solid',
453 453 'double',
454 454 'groove',
455 455 'ridge',
456 456 'inset',
457 457 'outset',
458 458 'initial',
459 459 'inherit', ''],
460 460 default_value='', sync=True)
461 461
462 462 font_style = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_font-style.asp
463 463 'normal',
464 464 'italic',
465 465 'oblique',
466 466 'initial',
467 467 'inherit', ''],
468 468 default_value='', sync=True)
469 469 font_weight = CaselessStrEnum(values=[ # http://www.w3schools.com/cssref/pr_font_weight.asp
470 470 'normal',
471 471 'bold',
472 472 'bolder',
473 473 'lighter',
474 474 'initial',
475 475 'inherit', ''] + list(map(str, range(100,1000,100))),
476 476 default_value='', sync=True)
477 477 font_size = CUnicode(sync=True)
478 478 font_family = Unicode(sync=True)
479 479
480 480 def __init__(self, *pargs, **kwargs):
481 481 super(DOMWidget, self).__init__(*pargs, **kwargs)
482 482
483 483 def _validate_border(name, old, new):
484 484 if new is not None and new != '':
485 485 if name != 'border_width' and not self.border_width:
486 486 self.border_width = 1
487 487 if name != 'border_style' and self.border_style == '':
488 488 self.border_style = 'solid'
489 489 self.on_trait_change(_validate_border, ['border_width', 'border_style', 'border_color'])
@@ -1,80 +1,80 b''
1 1 """Box class.
2 2
3 3 Represents a container that can be used to group other widgets.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from .widget import DOMWidget, register
10 10 from IPython.utils.traitlets import Unicode, Tuple, TraitError, Int, CaselessStrEnum
11 11 from IPython.utils.warn import DeprecatedClass
12 12
13 13 @register('IPython.Box')
14 14 class Box(DOMWidget):
15 15 """Displays multiple widgets in a group."""
16 16 _view_name = Unicode('BoxView', sync=True)
17 17
18 18 # Child widgets in the container.
19 19 # Using a tuple here to force reassignment to update the list.
20 20 # When a proper notifying-list trait exists, that is what should be used here.
21 children = Tuple(sync=True, allow_none=False)
21 children = Tuple(sync=True)
22 22
23 23 _overflow_values = ['visible', 'hidden', 'scroll', 'auto', 'initial', 'inherit', '']
24 24 overflow_x = CaselessStrEnum(
25 25 values=_overflow_values,
26 default_value='', allow_none=False, sync=True, help="""Specifies what
26 default_value='', sync=True, help="""Specifies what
27 27 happens to content that is too large for the rendered region.""")
28 28 overflow_y = CaselessStrEnum(
29 29 values=_overflow_values,
30 default_value='', allow_none=False, sync=True, help="""Specifies what
30 default_value='', sync=True, help="""Specifies what
31 31 happens to content that is too large for the rendered region.""")
32 32
33 33 box_style = CaselessStrEnum(
34 34 values=['success', 'info', 'warning', 'danger', ''],
35 35 default_value='', allow_none=True, sync=True, help="""Use a
36 36 predefined styling for the box.""")
37 37
38 38 def __init__(self, children = (), **kwargs):
39 39 kwargs['children'] = children
40 40 super(Box, self).__init__(**kwargs)
41 41 self.on_displayed(Box._fire_children_displayed)
42 42
43 43 def _fire_children_displayed(self):
44 44 for child in self.children:
45 45 child._handle_displayed()
46 46
47 47
48 48 @register('IPython.FlexBox')
49 49 class FlexBox(Box):
50 50 """Displays multiple widgets using the flexible box model."""
51 51 _view_name = Unicode('FlexBoxView', sync=True)
52 52 orientation = CaselessStrEnum(values=['vertical', 'horizontal'], default_value='vertical', sync=True)
53 53 flex = Int(0, sync=True, help="""Specify the flexible-ness of the model.""")
54 54 def _flex_changed(self, name, old, new):
55 55 new = min(max(0, new), 2)
56 56 if self.flex != new:
57 57 self.flex = new
58 58
59 59 _locations = ['start', 'center', 'end', 'baseline', 'stretch']
60 60 pack = CaselessStrEnum(
61 61 values=_locations,
62 default_value='start', allow_none=False, sync=True)
62 default_value='start', sync=True)
63 63 align = CaselessStrEnum(
64 64 values=_locations,
65 default_value='start', allow_none=False, sync=True)
65 default_value='start', sync=True)
66 66
67 67
68 68 def VBox(*pargs, **kwargs):
69 69 """Displays multiple widgets vertically using the flexible box model."""
70 70 kwargs['orientation'] = 'vertical'
71 71 return FlexBox(*pargs, **kwargs)
72 72
73 73 def HBox(*pargs, **kwargs):
74 74 """Displays multiple widgets horizontally using the flexible box model."""
75 75 kwargs['orientation'] = 'horizontal'
76 76 return FlexBox(*pargs, **kwargs)
77 77
78 78
79 79 # Remove in IPython 4.0
80 80 ContainerWidget = DeprecatedClass(Box, 'ContainerWidget')
@@ -1,298 +1,296 b''
1 1 """Float class.
2 2
3 3 Represents an unbounded float using a widget.
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (c) 2013, the IPython Development Team.
7 7 #
8 8 # Distributed under the terms of the Modified BSD License.
9 9 #
10 10 # The full license is in the file COPYING.txt, distributed with this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from .widget import DOMWidget, register
17 17 from IPython.utils.traitlets import Unicode, CFloat, Bool, CaselessStrEnum, Tuple
18 18 from IPython.utils.warn import DeprecatedClass
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Classes
22 22 #-----------------------------------------------------------------------------
23 23 class _Float(DOMWidget):
24 24 value = CFloat(0.0, help="Float value", sync=True)
25 25 disabled = Bool(False, help="Enable or disable user changes", sync=True)
26 26 description = Unicode(help="Description of the value this widget represents", sync=True)
27 27
28 28 def __init__(self, value=None, **kwargs):
29 29 if value is not None:
30 30 kwargs['value'] = value
31 31 super(_Float, self).__init__(**kwargs)
32 32
33 33 class _BoundedFloat(_Float):
34 34 max = CFloat(100.0, help="Max value", sync=True)
35 35 min = CFloat(0.0, help="Min value", sync=True)
36 36 step = CFloat(0.1, help="Minimum step that the value can take (ignored by some views)", sync=True)
37 37
38 38 def __init__(self, *pargs, **kwargs):
39 39 """Constructor"""
40 40 super(_BoundedFloat, self).__init__(*pargs, **kwargs)
41 41 self._handle_value_changed('value', None, self.value)
42 42 self._handle_max_changed('max', None, self.max)
43 43 self._handle_min_changed('min', None, self.min)
44 44 self.on_trait_change(self._handle_value_changed, 'value')
45 45 self.on_trait_change(self._handle_max_changed, 'max')
46 46 self.on_trait_change(self._handle_min_changed, 'min')
47 47
48 48 def _handle_value_changed(self, name, old, new):
49 49 """Validate value."""
50 50 if self.min > new or new > self.max:
51 51 self.value = min(max(new, self.min), self.max)
52 52
53 53 def _handle_max_changed(self, name, old, new):
54 54 """Make sure the min is always <= the max."""
55 55 if new < self.min:
56 56 raise ValueError("setting max < min")
57 57 if new < self.value:
58 58 self.value = new
59 59
60 60 def _handle_min_changed(self, name, old, new):
61 61 """Make sure the max is always >= the min."""
62 62 if new > self.max:
63 63 raise ValueError("setting min > max")
64 64 if new > self.value:
65 65 self.value = new
66 66
67 67
68 68 @register('IPython.FloatText')
69 69 class FloatText(_Float):
70 70 """ Displays a float value within a textbox. For a textbox in
71 71 which the value must be within a specific range, use BoundedFloatText.
72 72
73 73 Parameters
74 74 ----------
75 75 value : float
76 76 value displayed
77 77 description : str
78 78 description displayed next to the textbox
79 79 color : str Unicode color code (eg. '#C13535'), optional
80 80 color of the value displayed
81 81 """
82 82 _view_name = Unicode('FloatTextView', sync=True)
83 83
84 84
85 85 @register('IPython.BoundedFloatText')
86 86 class BoundedFloatText(_BoundedFloat):
87 87 """ Displays a float value within a textbox. Value must be within the range specified.
88 88 For a textbox in which the value doesn't need to be within a specific range, use FloatText.
89 89
90 90 Parameters
91 91 ----------
92 92 value : float
93 93 value displayed
94 94 min : float
95 95 minimal value of the range of possible values displayed
96 96 max : float
97 97 maximal value of the range of possible values displayed
98 98 description : str
99 99 description displayed next to the textbox
100 100 color : str Unicode color code (eg. '#C13535'), optional
101 101 color of the value displayed
102 102 """
103 103 _view_name = Unicode('FloatTextView', sync=True)
104 104
105 105
106 106 @register('IPython.FloatSlider')
107 107 class FloatSlider(_BoundedFloat):
108 108 """ Slider/trackbar of floating values with the specified range.
109 109
110 110 Parameters
111 111 ----------
112 112 value : float
113 113 position of the slider
114 114 min : float
115 115 minimal position of the slider
116 116 max : float
117 117 maximal position of the slider
118 118 step : float
119 119 step of the trackbar
120 120 description : str
121 121 name of the slider
122 122 orientation : {'vertical', 'horizontal}, optional
123 123 default is horizontal
124 124 readout : {True, False}, optional
125 125 default is True, display the current value of the slider next to it
126 126 slider_color : str Unicode color code (eg. '#C13535'), optional
127 127 color of the slider
128 128 color : str Unicode color code (eg. '#C13535'), optional
129 129 color of the value displayed (if readout == True)
130 130 """
131 131 _view_name = Unicode('FloatSliderView', sync=True)
132 132 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
133 default_value='horizontal',
134 help="Vertical or horizontal.", allow_none=False, sync=True)
133 default_value='horizontal', help="Vertical or horizontal.", sync=True)
135 134 _range = Bool(False, help="Display a range selector", sync=True)
136 135 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
137 136 slider_color = Unicode(sync=True)
138 137
139 138
140 139 @register('IPython.FloatProgress')
141 140 class FloatProgress(_BoundedFloat):
142 141 """ Displays a progress bar.
143 142
144 143 Parameters
145 144 -----------
146 145 value : float
147 146 position within the range of the progress bar
148 147 min : float
149 148 minimal position of the slider
150 149 max : float
151 150 maximal position of the slider
152 151 step : float
153 152 step of the progress bar
154 153 description : str
155 154 name of the progress bar
156 155 bar_style: {'success', 'info', 'warning', 'danger', ''}, optional
157 156 color of the progress bar, default is '' (blue)
158 157 colors are: 'success'-green, 'info'-light blue, 'warning'-orange, 'danger'-red
159 158 """
160 159 _view_name = Unicode('ProgressView', sync=True)
161 160
162 161 bar_style = CaselessStrEnum(
163 162 values=['success', 'info', 'warning', 'danger', ''],
164 163 default_value='', allow_none=True, sync=True, help="""Use a
165 164 predefined styling for the progess bar.""")
166 165
167 166 class _FloatRange(_Float):
168 167 value = Tuple(CFloat, CFloat, default_value=(0.0, 1.0), help="Tuple of (lower, upper) bounds", sync=True)
169 168 lower = CFloat(0.0, help="Lower bound", sync=False)
170 169 upper = CFloat(1.0, help="Upper bound", sync=False)
171 170
172 171 def __init__(self, *pargs, **kwargs):
173 172 value_given = 'value' in kwargs
174 173 lower_given = 'lower' in kwargs
175 174 upper_given = 'upper' in kwargs
176 175 if value_given and (lower_given or upper_given):
177 176 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
178 177 if lower_given != upper_given:
179 178 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
180 179
181 180 DOMWidget.__init__(self, *pargs, **kwargs)
182 181
183 182 # ensure the traits match, preferring whichever (if any) was given in kwargs
184 183 if value_given:
185 184 self.lower, self.upper = self.value
186 185 else:
187 186 self.value = (self.lower, self.upper)
188 187
189 188 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
190 189
191 190 def _validate(self, name, old, new):
192 191 if name == 'value':
193 192 self.lower, self.upper = min(new), max(new)
194 193 elif name == 'lower':
195 194 self.value = (new, self.value[1])
196 195 elif name == 'upper':
197 196 self.value = (self.value[0], new)
198 197
199 198 class _BoundedFloatRange(_FloatRange):
200 199 step = CFloat(1.0, help="Minimum step that the value can take (ignored by some views)", sync=True)
201 200 max = CFloat(100.0, help="Max value", sync=True)
202 201 min = CFloat(0.0, help="Min value", sync=True)
203 202
204 203 def __init__(self, *pargs, **kwargs):
205 204 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
206 205 _FloatRange.__init__(self, *pargs, **kwargs)
207 206
208 207 # ensure a minimal amount of sanity
209 208 if self.min > self.max:
210 209 raise ValueError("min must be <= max")
211 210
212 211 if any_value_given:
213 212 # if a value was given, clamp it within (min, max)
214 213 self._validate("value", None, self.value)
215 214 else:
216 215 # otherwise, set it to 25-75% to avoid the handles overlapping
217 216 self.value = (0.75*self.min + 0.25*self.max,
218 217 0.25*self.min + 0.75*self.max)
219 218 # callback already set for 'value', 'lower', 'upper'
220 219 self.on_trait_change(self._validate, ['min', 'max'])
221 220
222 221
223 222 def _validate(self, name, old, new):
224 223 if name == "min":
225 224 if new > self.max:
226 225 raise ValueError("setting min > max")
227 226 self.min = new
228 227 elif name == "max":
229 228 if new < self.min:
230 229 raise ValueError("setting max < min")
231 230 self.max = new
232 231
233 232 low, high = self.value
234 233 if name == "value":
235 234 low, high = min(new), max(new)
236 235 elif name == "upper":
237 236 if new < self.lower:
238 237 raise ValueError("setting upper < lower")
239 238 high = new
240 239 elif name == "lower":
241 240 if new > self.upper:
242 241 raise ValueError("setting lower > upper")
243 242 low = new
244 243
245 244 low = max(self.min, min(low, self.max))
246 245 high = min(self.max, max(high, self.min))
247 246
248 247 # determine the order in which we should update the
249 248 # lower, upper traits to avoid a temporary inverted overlap
250 249 lower_first = high < self.lower
251 250
252 251 self.value = (low, high)
253 252 if lower_first:
254 253 self.lower = low
255 254 self.upper = high
256 255 else:
257 256 self.upper = high
258 257 self.lower = low
259 258
260 259
261 260 @register('IPython.FloatRangeSlider')
262 261 class FloatRangeSlider(_BoundedFloatRange):
263 262 """ Slider/trackbar for displaying a floating value range (within the specified range of values).
264 263
265 264 Parameters
266 265 ----------
267 266 value : float tuple
268 267 range of the slider displayed
269 268 min : float
270 269 minimal position of the slider
271 270 max : float
272 271 maximal position of the slider
273 272 step : float
274 273 step of the trackbar
275 274 description : str
276 275 name of the slider
277 276 orientation : {'vertical', 'horizontal}, optional
278 277 default is horizontal
279 278 readout : {True, False}, optional
280 279 default is True, display the current value of the slider next to it
281 280 slider_color : str Unicode color code (eg. '#C13535'), optional
282 281 color of the slider
283 282 color : str Unicode color code (eg. '#C13535'), optional
284 283 color of the value displayed (if readout == True)
285 284 """
286 285 _view_name = Unicode('FloatSliderView', sync=True)
287 286 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
288 default_value='horizontal', allow_none=False,
289 help="Vertical or horizontal.", sync=True)
287 default_value='horizontal', help="Vertical or horizontal.", sync=True)
290 288 _range = Bool(True, help="Display a range selector", sync=True)
291 289 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
292 290 slider_color = Unicode(sync=True)
293 291
294 292 # Remove in IPython 4.0
295 293 FloatTextWidget = DeprecatedClass(FloatText, 'FloatTextWidget')
296 294 BoundedFloatTextWidget = DeprecatedClass(BoundedFloatText, 'BoundedFloatTextWidget')
297 295 FloatSliderWidget = DeprecatedClass(FloatSlider, 'FloatSliderWidget')
298 296 FloatProgressWidget = DeprecatedClass(FloatProgress, 'FloatProgressWidget')
@@ -1,209 +1,207 b''
1 1 """Int class.
2 2
3 3 Represents an unbounded int using a widget.
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (c) 2013, the IPython Development Team.
7 7 #
8 8 # Distributed under the terms of the Modified BSD License.
9 9 #
10 10 # The full license is in the file COPYING.txt, distributed with this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from .widget import DOMWidget, register
17 17 from IPython.utils.traitlets import Unicode, CInt, Bool, CaselessStrEnum, Tuple
18 18 from IPython.utils.warn import DeprecatedClass
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Classes
22 22 #-----------------------------------------------------------------------------
23 23 class _Int(DOMWidget):
24 24 """Base class used to create widgets that represent an int."""
25 25 value = CInt(0, help="Int value", sync=True)
26 26 disabled = Bool(False, help="Enable or disable user changes", sync=True)
27 27 description = Unicode(help="Description of the value this widget represents", sync=True)
28 28
29 29 def __init__(self, value=None, **kwargs):
30 30 if value is not None:
31 31 kwargs['value'] = value
32 32 super(_Int, self).__init__(**kwargs)
33 33
34 34 class _BoundedInt(_Int):
35 35 """Base class used to create widgets that represent a int that is bounded
36 36 by a minium and maximum."""
37 37 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
38 38 max = CInt(100, help="Max value", sync=True)
39 39 min = CInt(0, help="Min value", sync=True)
40 40
41 41 def __init__(self, *pargs, **kwargs):
42 42 """Constructor"""
43 43 super(_BoundedInt, self).__init__(*pargs, **kwargs)
44 44 self._handle_value_changed('value', None, self.value)
45 45 self._handle_max_changed('max', None, self.max)
46 46 self._handle_min_changed('min', None, self.min)
47 47 self.on_trait_change(self._handle_value_changed, 'value')
48 48 self.on_trait_change(self._handle_max_changed, 'max')
49 49 self.on_trait_change(self._handle_min_changed, 'min')
50 50
51 51 def _handle_value_changed(self, name, old, new):
52 52 """Validate value."""
53 53 if self.min > new or new > self.max:
54 54 self.value = min(max(new, self.min), self.max)
55 55
56 56 def _handle_max_changed(self, name, old, new):
57 57 """Make sure the min is always <= the max."""
58 58 if new < self.min:
59 59 raise ValueError("setting max < min")
60 60 if new < self.value:
61 61 self.value = new
62 62
63 63 def _handle_min_changed(self, name, old, new):
64 64 """Make sure the max is always >= the min."""
65 65 if new > self.max:
66 66 raise ValueError("setting min > max")
67 67 if new > self.value:
68 68 self.value = new
69 69
70 70 @register('IPython.IntText')
71 71 class IntText(_Int):
72 72 """Textbox widget that represents a int."""
73 73 _view_name = Unicode('IntTextView', sync=True)
74 74
75 75
76 76 @register('IPython.BoundedIntText')
77 77 class BoundedIntText(_BoundedInt):
78 78 """Textbox widget that represents a int bounded by a minimum and maximum value."""
79 79 _view_name = Unicode('IntTextView', sync=True)
80 80
81 81
82 82 @register('IPython.IntSlider')
83 83 class IntSlider(_BoundedInt):
84 84 """Slider widget that represents a int bounded by a minimum and maximum value."""
85 85 _view_name = Unicode('IntSliderView', sync=True)
86 86 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
87 default_value='horizontal', allow_none=False,
88 help="Vertical or horizontal.", sync=True)
87 default_value='horizontal', help="Vertical or horizontal.", sync=True)
89 88 _range = Bool(False, help="Display a range selector", sync=True)
90 89 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
91 90 slider_color = Unicode(sync=True)
92 91
93 92
94 93 @register('IPython.IntProgress')
95 94 class IntProgress(_BoundedInt):
96 95 """Progress bar that represents a int bounded by a minimum and maximum value."""
97 96 _view_name = Unicode('ProgressView', sync=True)
98 97
99 98 bar_style = CaselessStrEnum(
100 99 values=['success', 'info', 'warning', 'danger', ''],
101 100 default_value='', allow_none=True, sync=True, help="""Use a
102 101 predefined styling for the progess bar.""")
103 102
104 103 class _IntRange(_Int):
105 104 value = Tuple(CInt, CInt, default_value=(0, 1), help="Tuple of (lower, upper) bounds", sync=True)
106 105 lower = CInt(0, help="Lower bound", sync=False)
107 106 upper = CInt(1, help="Upper bound", sync=False)
108 107
109 108 def __init__(self, *pargs, **kwargs):
110 109 value_given = 'value' in kwargs
111 110 lower_given = 'lower' in kwargs
112 111 upper_given = 'upper' in kwargs
113 112 if value_given and (lower_given or upper_given):
114 113 raise ValueError("Cannot specify both 'value' and 'lower'/'upper' for range widget")
115 114 if lower_given != upper_given:
116 115 raise ValueError("Must specify both 'lower' and 'upper' for range widget")
117 116
118 117 super(_IntRange, self).__init__(*pargs, **kwargs)
119 118
120 119 # ensure the traits match, preferring whichever (if any) was given in kwargs
121 120 if value_given:
122 121 self.lower, self.upper = self.value
123 122 else:
124 123 self.value = (self.lower, self.upper)
125 124
126 125 self.on_trait_change(self._validate, ['value', 'upper', 'lower'])
127 126
128 127 def _validate(self, name, old, new):
129 128 if name == 'value':
130 129 self.lower, self.upper = min(new), max(new)
131 130 elif name == 'lower':
132 131 self.value = (new, self.value[1])
133 132 elif name == 'upper':
134 133 self.value = (self.value[0], new)
135 134
136 135 class _BoundedIntRange(_IntRange):
137 136 step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True)
138 137 max = CInt(100, help="Max value", sync=True)
139 138 min = CInt(0, help="Min value", sync=True)
140 139
141 140 def __init__(self, *pargs, **kwargs):
142 141 any_value_given = 'value' in kwargs or 'upper' in kwargs or 'lower' in kwargs
143 142 _IntRange.__init__(self, *pargs, **kwargs)
144 143
145 144 # ensure a minimal amount of sanity
146 145 if self.min > self.max:
147 146 raise ValueError("min must be <= max")
148 147
149 148 if any_value_given:
150 149 # if a value was given, clamp it within (min, max)
151 150 self._validate("value", None, self.value)
152 151 else:
153 152 # otherwise, set it to 25-75% to avoid the handles overlapping
154 153 self.value = (0.75*self.min + 0.25*self.max,
155 154 0.25*self.min + 0.75*self.max)
156 155 # callback already set for 'value', 'lower', 'upper'
157 156 self.on_trait_change(self._validate, ['min', 'max'])
158 157
159 158 def _validate(self, name, old, new):
160 159 if name == "min":
161 160 if new > self.max:
162 161 raise ValueError("setting min > max")
163 162 elif name == "max":
164 163 if new < self.min:
165 164 raise ValueError("setting max < min")
166 165
167 166 low, high = self.value
168 167 if name == "value":
169 168 low, high = min(new), max(new)
170 169 elif name == "upper":
171 170 if new < self.lower:
172 171 raise ValueError("setting upper < lower")
173 172 high = new
174 173 elif name == "lower":
175 174 if new > self.upper:
176 175 raise ValueError("setting lower > upper")
177 176 low = new
178 177
179 178 low = max(self.min, min(low, self.max))
180 179 high = min(self.max, max(high, self.min))
181 180
182 181 # determine the order in which we should update the
183 182 # lower, upper traits to avoid a temporary inverted overlap
184 183 lower_first = high < self.lower
185 184
186 185 self.value = (low, high)
187 186 if lower_first:
188 187 self.lower = low
189 188 self.upper = high
190 189 else:
191 190 self.upper = high
192 191 self.lower = low
193 192
194 193 @register('IPython.IntRangeSlider')
195 194 class IntRangeSlider(_BoundedIntRange):
196 195 """Slider widget that represents a pair of ints between a minimum and maximum value."""
197 196 _view_name = Unicode('IntSliderView', sync=True)
198 197 orientation = CaselessStrEnum(values=['horizontal', 'vertical'],
199 default_value='horizontal', allow_none=False,
200 help="Vertical or horizontal.", sync=True)
198 default_value='horizontal', help="Vertical or horizontal.", sync=True)
201 199 _range = Bool(True, help="Display a range selector", sync=True)
202 200 readout = Bool(True, help="Display the current value of the slider next to it.", sync=True)
203 201 slider_color = Unicode(sync=True)
204 202
205 203 # Remove in IPython 4.0
206 204 IntTextWidget = DeprecatedClass(IntText, 'IntTextWidget')
207 205 BoundedIntTextWidget = DeprecatedClass(BoundedIntText, 'BoundedIntTextWidget')
208 206 IntSliderWidget = DeprecatedClass(IntSlider, 'IntSliderWidget')
209 207 IntProgressWidget = DeprecatedClass(IntProgress, 'IntProgressWidget')
@@ -1,496 +1,496 b''
1 1 """Test suite for our zeromq-based message specification."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import re
7 7 import sys
8 8 from distutils.version import LooseVersion as V
9 9 try:
10 10 from queue import Empty # Py 3
11 11 except ImportError:
12 12 from Queue import Empty # Py 2
13 13
14 14 import nose.tools as nt
15 15
16 16 from IPython.utils.traitlets import (
17 17 HasTraits, TraitError, Bool, Unicode, Dict, Integer, List, Enum,
18 18 )
19 19 from IPython.utils.py3compat import string_types, iteritems
20 20
21 21 from .utils import TIMEOUT, start_global_kernel, flush_channels, execute
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Globals
25 25 #-----------------------------------------------------------------------------
26 26 KC = None
27 27
28 28 def setup():
29 29 global KC
30 30 KC = start_global_kernel()
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Message Spec References
34 34 #-----------------------------------------------------------------------------
35 35
36 36 class Reference(HasTraits):
37 37
38 38 """
39 39 Base class for message spec specification testing.
40 40
41 41 This class is the core of the message specification test. The
42 42 idea is that child classes implement trait attributes for each
43 43 message keys, so that message keys can be tested against these
44 44 traits using :meth:`check` method.
45 45
46 46 """
47 47
48 48 def check(self, d):
49 49 """validate a dict against our traits"""
50 50 for key in self.trait_names():
51 51 nt.assert_in(key, d)
52 52 # FIXME: always allow None, probably not a good idea
53 53 if d[key] is None:
54 54 continue
55 55 try:
56 56 setattr(self, key, d[key])
57 57 except TraitError as e:
58 58 assert False, str(e)
59 59
60 60
61 61 class Version(Unicode):
62 62 def __init__(self, *args, **kwargs):
63 63 self.min = kwargs.pop('min', None)
64 64 self.max = kwargs.pop('max', None)
65 65 kwargs['default_value'] = self.min
66 66 super(Version, self).__init__(*args, **kwargs)
67 67
68 68 def validate(self, obj, value):
69 69 if self.min and V(value) < V(self.min):
70 70 raise TraitError("bad version: %s < %s" % (value, self.min))
71 71 if self.max and (V(value) > V(self.max)):
72 72 raise TraitError("bad version: %s > %s" % (value, self.max))
73 73
74 74
75 75 class RMessage(Reference):
76 76 msg_id = Unicode()
77 77 msg_type = Unicode()
78 78 header = Dict()
79 79 parent_header = Dict()
80 80 content = Dict()
81 81
82 82 def check(self, d):
83 83 super(RMessage, self).check(d)
84 84 RHeader().check(self.header)
85 85 if self.parent_header:
86 86 RHeader().check(self.parent_header)
87 87
88 88 class RHeader(Reference):
89 89 msg_id = Unicode()
90 90 msg_type = Unicode()
91 91 session = Unicode()
92 92 username = Unicode()
93 93 version = Version(min='5.0')
94 94
95 95 mime_pat = re.compile(r'^[\w\-\+\.]+/[\w\-\+\.]+$')
96 96
97 97 class MimeBundle(Reference):
98 98 metadata = Dict()
99 99 data = Dict()
100 100 def _data_changed(self, name, old, new):
101 101 for k,v in iteritems(new):
102 102 assert mime_pat.match(k)
103 103 nt.assert_is_instance(v, string_types)
104 104
105 105 # shell replies
106 106
107 107 class ExecuteReply(Reference):
108 108 execution_count = Integer()
109 status = Enum((u'ok', u'error'))
109 status = Enum((u'ok', u'error'), default_value=u'ok')
110 110
111 111 def check(self, d):
112 112 Reference.check(self, d)
113 113 if d['status'] == 'ok':
114 114 ExecuteReplyOkay().check(d)
115 115 elif d['status'] == 'error':
116 116 ExecuteReplyError().check(d)
117 117
118 118
119 119 class ExecuteReplyOkay(Reference):
120 120 payload = List(Dict)
121 121 user_expressions = Dict()
122 122
123 123
124 124 class ExecuteReplyError(Reference):
125 125 ename = Unicode()
126 126 evalue = Unicode()
127 127 traceback = List(Unicode)
128 128
129 129
130 130 class InspectReply(MimeBundle):
131 131 found = Bool()
132 132
133 133
134 134 class ArgSpec(Reference):
135 135 args = List(Unicode)
136 136 varargs = Unicode()
137 137 varkw = Unicode()
138 138 defaults = List()
139 139
140 140
141 141 class Status(Reference):
142 execution_state = Enum((u'busy', u'idle', u'starting'))
142 execution_state = Enum((u'busy', u'idle', u'starting'), default_value=u'busy')
143 143
144 144
145 145 class CompleteReply(Reference):
146 146 matches = List(Unicode)
147 147 cursor_start = Integer()
148 148 cursor_end = Integer()
149 149 status = Unicode()
150 150
151 151 class LanguageInfo(Reference):
152 152 name = Unicode('python')
153 153 version = Unicode(sys.version.split()[0])
154 154
155 155 class KernelInfoReply(Reference):
156 156 protocol_version = Version(min='5.0')
157 157 implementation = Unicode('ipython')
158 158 implementation_version = Version(min='2.1')
159 159 language_info = Dict()
160 160 banner = Unicode()
161 161
162 162 def check(self, d):
163 163 Reference.check(self, d)
164 164 LanguageInfo().check(d['language_info'])
165 165
166 166
167 167 class IsCompleteReply(Reference):
168 status = Enum((u'complete', u'incomplete', u'invalid', u'unknown'))
168 status = Enum((u'complete', u'incomplete', u'invalid', u'unknown'), default_value=u'complete')
169 169
170 170 def check(self, d):
171 171 Reference.check(self, d)
172 172 if d['status'] == 'incomplete':
173 173 IsCompleteReplyIncomplete().check(d)
174 174
175 175 class IsCompleteReplyIncomplete(Reference):
176 176 indent = Unicode()
177 177
178 178
179 179 # IOPub messages
180 180
181 181 class ExecuteInput(Reference):
182 182 code = Unicode()
183 183 execution_count = Integer()
184 184
185 185
186 186 Error = ExecuteReplyError
187 187
188 188
189 189 class Stream(Reference):
190 name = Enum((u'stdout', u'stderr'))
190 name = Enum((u'stdout', u'stderr'), default_value=u'stdout')
191 191 text = Unicode()
192 192
193 193
194 194 class DisplayData(MimeBundle):
195 195 pass
196 196
197 197
198 198 class ExecuteResult(MimeBundle):
199 199 execution_count = Integer()
200 200
201 201 class HistoryReply(Reference):
202 202 history = List(List())
203 203
204 204
205 205 references = {
206 206 'execute_reply' : ExecuteReply(),
207 207 'inspect_reply' : InspectReply(),
208 208 'status' : Status(),
209 209 'complete_reply' : CompleteReply(),
210 210 'kernel_info_reply': KernelInfoReply(),
211 211 'is_complete_reply': IsCompleteReply(),
212 212 'execute_input' : ExecuteInput(),
213 213 'execute_result' : ExecuteResult(),
214 214 'history_reply' : HistoryReply(),
215 215 'error' : Error(),
216 216 'stream' : Stream(),
217 217 'display_data' : DisplayData(),
218 218 'header' : RHeader(),
219 219 }
220 220 """
221 221 Specifications of `content` part of the reply messages.
222 222 """
223 223
224 224
225 225 def validate_message(msg, msg_type=None, parent=None):
226 226 """validate a message
227 227
228 228 This is a generator, and must be iterated through to actually
229 229 trigger each test.
230 230
231 231 If msg_type and/or parent are given, the msg_type and/or parent msg_id
232 232 are compared with the given values.
233 233 """
234 234 RMessage().check(msg)
235 235 if msg_type:
236 236 nt.assert_equal(msg['msg_type'], msg_type)
237 237 if parent:
238 238 nt.assert_equal(msg['parent_header']['msg_id'], parent)
239 239 content = msg['content']
240 240 ref = references[msg['msg_type']]
241 241 ref.check(content)
242 242
243 243
244 244 #-----------------------------------------------------------------------------
245 245 # Tests
246 246 #-----------------------------------------------------------------------------
247 247
248 248 # Shell channel
249 249
250 250 def test_execute():
251 251 flush_channels()
252 252
253 253 msg_id = KC.execute(code='x=1')
254 254 reply = KC.get_shell_msg(timeout=TIMEOUT)
255 255 validate_message(reply, 'execute_reply', msg_id)
256 256
257 257
258 258 def test_execute_silent():
259 259 flush_channels()
260 260 msg_id, reply = execute(code='x=1', silent=True)
261 261
262 262 # flush status=idle
263 263 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
264 264 validate_message(status, 'status', msg_id)
265 265 nt.assert_equal(status['content']['execution_state'], 'idle')
266 266
267 267 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
268 268 count = reply['execution_count']
269 269
270 270 msg_id, reply = execute(code='x=2', silent=True)
271 271
272 272 # flush status=idle
273 273 status = KC.iopub_channel.get_msg(timeout=TIMEOUT)
274 274 validate_message(status, 'status', msg_id)
275 275 nt.assert_equal(status['content']['execution_state'], 'idle')
276 276
277 277 nt.assert_raises(Empty, KC.iopub_channel.get_msg, timeout=0.1)
278 278 count_2 = reply['execution_count']
279 279 nt.assert_equal(count_2, count)
280 280
281 281
282 282 def test_execute_error():
283 283 flush_channels()
284 284
285 285 msg_id, reply = execute(code='1/0')
286 286 nt.assert_equal(reply['status'], 'error')
287 287 nt.assert_equal(reply['ename'], 'ZeroDivisionError')
288 288
289 289 error = KC.iopub_channel.get_msg(timeout=TIMEOUT)
290 290 validate_message(error, 'error', msg_id)
291 291
292 292
293 293 def test_execute_inc():
294 294 """execute request should increment execution_count"""
295 295 flush_channels()
296 296
297 297 msg_id, reply = execute(code='x=1')
298 298 count = reply['execution_count']
299 299
300 300 flush_channels()
301 301
302 302 msg_id, reply = execute(code='x=2')
303 303 count_2 = reply['execution_count']
304 304 nt.assert_equal(count_2, count+1)
305 305
306 306 def test_execute_stop_on_error():
307 307 """execute request should not abort execution queue with stop_on_error False"""
308 308 flush_channels()
309 309
310 310 fail = '\n'.join([
311 311 # sleep to ensure subsequent message is waiting in the queue to be aborted
312 312 'import time',
313 313 'time.sleep(0.5)',
314 314 'raise ValueError',
315 315 ])
316 316 KC.execute(code=fail)
317 317 msg_id = KC.execute(code='print("Hello")')
318 318 KC.get_shell_msg(timeout=TIMEOUT)
319 319 reply = KC.get_shell_msg(timeout=TIMEOUT)
320 320 nt.assert_equal(reply['content']['status'], 'aborted')
321 321
322 322 flush_channels()
323 323
324 324 KC.execute(code=fail, stop_on_error=False)
325 325 msg_id = KC.execute(code='print("Hello")')
326 326 KC.get_shell_msg(timeout=TIMEOUT)
327 327 reply = KC.get_shell_msg(timeout=TIMEOUT)
328 328 nt.assert_equal(reply['content']['status'], 'ok')
329 329
330 330
331 331 def test_user_expressions():
332 332 flush_channels()
333 333
334 334 msg_id, reply = execute(code='x=1', user_expressions=dict(foo='x+1'))
335 335 user_expressions = reply['user_expressions']
336 336 nt.assert_equal(user_expressions, {u'foo': {
337 337 u'status': u'ok',
338 338 u'data': {u'text/plain': u'2'},
339 339 u'metadata': {},
340 340 }})
341 341
342 342
343 343 def test_user_expressions_fail():
344 344 flush_channels()
345 345
346 346 msg_id, reply = execute(code='x=0', user_expressions=dict(foo='nosuchname'))
347 347 user_expressions = reply['user_expressions']
348 348 foo = user_expressions['foo']
349 349 nt.assert_equal(foo['status'], 'error')
350 350 nt.assert_equal(foo['ename'], 'NameError')
351 351
352 352
353 353 def test_oinfo():
354 354 flush_channels()
355 355
356 356 msg_id = KC.inspect('a')
357 357 reply = KC.get_shell_msg(timeout=TIMEOUT)
358 358 validate_message(reply, 'inspect_reply', msg_id)
359 359
360 360
361 361 def test_oinfo_found():
362 362 flush_channels()
363 363
364 364 msg_id, reply = execute(code='a=5')
365 365
366 366 msg_id = KC.inspect('a')
367 367 reply = KC.get_shell_msg(timeout=TIMEOUT)
368 368 validate_message(reply, 'inspect_reply', msg_id)
369 369 content = reply['content']
370 370 assert content['found']
371 371 text = content['data']['text/plain']
372 372 nt.assert_in('Type:', text)
373 373 nt.assert_in('Docstring:', text)
374 374
375 375
376 376 def test_oinfo_detail():
377 377 flush_channels()
378 378
379 379 msg_id, reply = execute(code='ip=get_ipython()')
380 380
381 381 msg_id = KC.inspect('ip.object_inspect', cursor_pos=10, detail_level=1)
382 382 reply = KC.get_shell_msg(timeout=TIMEOUT)
383 383 validate_message(reply, 'inspect_reply', msg_id)
384 384 content = reply['content']
385 385 assert content['found']
386 386 text = content['data']['text/plain']
387 387 nt.assert_in('Signature:', text)
388 388 nt.assert_in('Source:', text)
389 389
390 390
391 391 def test_oinfo_not_found():
392 392 flush_channels()
393 393
394 394 msg_id = KC.inspect('dne')
395 395 reply = KC.get_shell_msg(timeout=TIMEOUT)
396 396 validate_message(reply, 'inspect_reply', msg_id)
397 397 content = reply['content']
398 398 nt.assert_false(content['found'])
399 399
400 400
401 401 def test_complete():
402 402 flush_channels()
403 403
404 404 msg_id, reply = execute(code="alpha = albert = 5")
405 405
406 406 msg_id = KC.complete('al', 2)
407 407 reply = KC.get_shell_msg(timeout=TIMEOUT)
408 408 validate_message(reply, 'complete_reply', msg_id)
409 409 matches = reply['content']['matches']
410 410 for name in ('alpha', 'albert'):
411 411 nt.assert_in(name, matches)
412 412
413 413
414 414 def test_kernel_info_request():
415 415 flush_channels()
416 416
417 417 msg_id = KC.kernel_info()
418 418 reply = KC.get_shell_msg(timeout=TIMEOUT)
419 419 validate_message(reply, 'kernel_info_reply', msg_id)
420 420
421 421
422 422 def test_single_payload():
423 423 flush_channels()
424 424 msg_id, reply = execute(code="for i in range(3):\n"+
425 425 " x=range?\n")
426 426 payload = reply['payload']
427 427 next_input_pls = [pl for pl in payload if pl["source"] == "set_next_input"]
428 428 nt.assert_equal(len(next_input_pls), 1)
429 429
430 430 def test_is_complete():
431 431 flush_channels()
432 432
433 433 msg_id = KC.is_complete("a = 1")
434 434 reply = KC.get_shell_msg(timeout=TIMEOUT)
435 435 validate_message(reply, 'is_complete_reply', msg_id)
436 436
437 437 def test_history_range():
438 438 flush_channels()
439 439
440 440 msg_id_exec = KC.execute(code='x=1', store_history = True)
441 441 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
442 442
443 443 msg_id = KC.history(hist_access_type = 'range', raw = True, output = True, start = 1, stop = 2, session = 0)
444 444 reply = KC.get_shell_msg(timeout=TIMEOUT)
445 445 validate_message(reply, 'history_reply', msg_id)
446 446 content = reply['content']
447 447 nt.assert_equal(len(content['history']), 1)
448 448
449 449 def test_history_tail():
450 450 flush_channels()
451 451
452 452 msg_id_exec = KC.execute(code='x=1', store_history = True)
453 453 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
454 454
455 455 msg_id = KC.history(hist_access_type = 'tail', raw = True, output = True, n = 1, session = 0)
456 456 reply = KC.get_shell_msg(timeout=TIMEOUT)
457 457 validate_message(reply, 'history_reply', msg_id)
458 458 content = reply['content']
459 459 nt.assert_equal(len(content['history']), 1)
460 460
461 461 def test_history_search():
462 462 flush_channels()
463 463
464 464 msg_id_exec = KC.execute(code='x=1', store_history = True)
465 465 reply_exec = KC.get_shell_msg(timeout=TIMEOUT)
466 466
467 467 msg_id = KC.history(hist_access_type = 'search', raw = True, output = True, n = 1, pattern = '*', session = 0)
468 468 reply = KC.get_shell_msg(timeout=TIMEOUT)
469 469 validate_message(reply, 'history_reply', msg_id)
470 470 content = reply['content']
471 471 nt.assert_equal(len(content['history']), 1)
472 472
473 473 # IOPub channel
474 474
475 475
476 476 def test_stream():
477 477 flush_channels()
478 478
479 479 msg_id, reply = execute("print('hi')")
480 480
481 481 stdout = KC.iopub_channel.get_msg(timeout=TIMEOUT)
482 482 validate_message(stdout, 'stream', msg_id)
483 483 content = stdout['content']
484 484 nt.assert_equal(content['text'], u'hi\n')
485 485
486 486
487 487 def test_display_data():
488 488 flush_channels()
489 489
490 490 msg_id, reply = execute("from IPython.core.display import display; display(1)")
491 491
492 492 display = KC.iopub_channel.get_msg(timeout=TIMEOUT)
493 493 validate_message(display, 'display_data', parent=msg_id)
494 494 data = display['content']['data']
495 495 nt.assert_equal(data['text/plain'], u'1')
496 496
@@ -1,849 +1,849 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 # Copyright (c) IPython Development Team.
9 9 # Distributed under the terms of the Modified BSD License.
10 10
11 11 import logging
12 12 import sys
13 13 import time
14 14
15 15 from collections import deque
16 16 from datetime import datetime
17 17 from random import randint, random
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import numpy
22 22 except ImportError:
23 23 numpy = None
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop, zmqstream
27 27
28 28 # local imports
29 29 from IPython.external.decorator import decorator
30 30 from IPython.config.application import Application
31 31 from IPython.config.loader import Config
32 32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 33 from IPython.utils.py3compat import cast_bytes
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import SessionFactory
37 37 from IPython.parallel.util import connect_logger, local_logger
38 38
39 39 from .dependency import Dependency
40 40
41 41 @decorator
42 42 def logged(f,self,*args,**kwargs):
43 43 # print ("#--------------------")
44 44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
45 45 # print ("#--")
46 46 return f(self,*args, **kwargs)
47 47
48 48 #----------------------------------------------------------------------
49 49 # Chooser functions
50 50 #----------------------------------------------------------------------
51 51
52 52 def plainrandom(loads):
53 53 """Plain random pick."""
54 54 n = len(loads)
55 55 return randint(0,n-1)
56 56
57 57 def lru(loads):
58 58 """Always pick the front of the line.
59 59
60 60 The content of `loads` is ignored.
61 61
62 62 Assumes LRU ordering of loads, with oldest first.
63 63 """
64 64 return 0
65 65
66 66 def twobin(loads):
67 67 """Pick two at random, use the LRU of the two.
68 68
69 69 The content of loads is ignored.
70 70
71 71 Assumes LRU ordering of loads, with oldest first.
72 72 """
73 73 n = len(loads)
74 74 a = randint(0,n-1)
75 75 b = randint(0,n-1)
76 76 return min(a,b)
77 77
78 78 def weighted(loads):
79 79 """Pick two at random using inverse load as weight.
80 80
81 81 Return the less loaded of the two.
82 82 """
83 83 # weight 0 a million times more than 1:
84 84 weights = 1./(1e-6+numpy.array(loads))
85 85 sums = weights.cumsum()
86 86 t = sums[-1]
87 87 x = random()*t
88 88 y = random()*t
89 89 idx = 0
90 90 idy = 0
91 91 while sums[idx] < x:
92 92 idx += 1
93 93 while sums[idy] < y:
94 94 idy += 1
95 95 if weights[idy] > weights[idx]:
96 96 return idy
97 97 else:
98 98 return idx
99 99
100 100 def leastload(loads):
101 101 """Always choose the lowest load.
102 102
103 103 If the lowest load occurs more than once, the first
104 104 occurance will be used. If loads has LRU ordering, this means
105 105 the LRU of those with the lowest load is chosen.
106 106 """
107 107 return loads.index(min(loads))
108 108
109 109 #---------------------------------------------------------------------
110 110 # Classes
111 111 #---------------------------------------------------------------------
112 112
113 113
114 114 # store empty default dependency:
115 115 MET = Dependency([])
116 116
117 117
118 118 class Job(object):
119 119 """Simple container for a job"""
120 120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
121 121 targets, after, follow, timeout):
122 122 self.msg_id = msg_id
123 123 self.raw_msg = raw_msg
124 124 self.idents = idents
125 125 self.msg = msg
126 126 self.header = header
127 127 self.metadata = metadata
128 128 self.targets = targets
129 129 self.after = after
130 130 self.follow = follow
131 131 self.timeout = timeout
132 132
133 133 self.removed = False # used for lazy-delete from sorted queue
134 134 self.timestamp = time.time()
135 135 self.timeout_id = 0
136 136 self.blacklist = set()
137 137
138 138 def __lt__(self, other):
139 139 return self.timestamp < other.timestamp
140 140
141 141 def __cmp__(self, other):
142 142 return cmp(self.timestamp, other.timestamp)
143 143
144 144 @property
145 145 def dependents(self):
146 146 return self.follow.union(self.after)
147 147
148 148
149 149 class TaskScheduler(SessionFactory):
150 150 """Python TaskScheduler object.
151 151
152 152 This is the simplest object that supports msg_id based
153 153 DAG dependencies. *Only* task msg_ids are checked, not
154 154 msg_ids of jobs submitted via the MUX queue.
155 155
156 156 """
157 157
158 158 hwm = Integer(1, config=True,
159 159 help="""specify the High Water Mark (HWM) for the downstream
160 160 socket in the Task scheduler. This is the maximum number
161 161 of allowed outstanding tasks on each engine.
162 162
163 163 The default (1) means that only one task can be outstanding on each
164 164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 165 engines continue to be assigned tasks while they are working,
166 166 effectively hiding network latency behind computation, but can result
167 167 in an imbalance of work when submitting many heterogenous tasks all at
168 168 once. Any positive value greater than one is a compromise between the
169 169 two.
170 170
171 171 """
172 172 )
173 173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 'leastload', config=True, allow_none=False,
174 'leastload', config=True,
175 175 help="""select the task scheduler scheme [default: Python LRU]
176 176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 177 )
178 178 def _scheme_name_changed(self, old, new):
179 179 self.log.debug("Using scheme %r"%new)
180 180 self.scheme = globals()[new]
181 181
182 182 # input arguments:
183 183 scheme = Instance(FunctionType) # function for determining the destination
184 184 def _scheme_default(self):
185 185 return leastload
186 186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191 191
192 192 # internals:
193 193 queue = Instance(deque) # sorted list of Jobs
194 194 def _queue_default(self):
195 195 return deque()
196 196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 200 pending = Dict() # dict by engine_uuid of submitted tasks
201 201 completed = Dict() # dict by engine_uuid of completed tasks
202 202 failed = Dict() # dict by engine_uuid of failed tasks
203 203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 204 clients = Dict() # dict by msg_id for who submitted the task
205 205 targets = List() # list of target IDENTs
206 206 loads = List() # list of engine loads
207 207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 208 all_completed = Set() # set of all completed tasks
209 209 all_failed = Set() # set of all failed tasks
210 210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 211 all_ids = Set() # set of all submitted task IDs
212 212
213 213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 214 # but ensure Bytes
215 215 def _ident_default(self):
216 216 return self.session.bsession
217 217
218 218 def start(self):
219 219 self.query_stream.on_recv(self.dispatch_query_reply)
220 220 self.session.send(self.query_stream, "connection_request", {})
221 221
222 222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
223 223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
224 224
225 225 self._notification_handlers = dict(
226 226 registration_notification = self._register_engine,
227 227 unregistration_notification = self._unregister_engine
228 228 )
229 229 self.notifier_stream.on_recv(self.dispatch_notification)
230 230 self.log.info("Scheduler started [%s]" % self.scheme_name)
231 231
232 232 def resume_receiving(self):
233 233 """Resume accepting jobs."""
234 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 235
236 236 def stop_receiving(self):
237 237 """Stop accepting jobs while there are no engines.
238 238 Leave them in the ZMQ queue."""
239 239 self.client_stream.on_recv(None)
240 240
241 241 #-----------------------------------------------------------------------
242 242 # [Un]Registration Handling
243 243 #-----------------------------------------------------------------------
244 244
245 245
246 246 def dispatch_query_reply(self, msg):
247 247 """handle reply to our initial connection request"""
248 248 try:
249 249 idents,msg = self.session.feed_identities(msg)
250 250 except ValueError:
251 251 self.log.warn("task::Invalid Message: %r",msg)
252 252 return
253 253 try:
254 254 msg = self.session.deserialize(msg)
255 255 except ValueError:
256 256 self.log.warn("task::Unauthorized message from: %r"%idents)
257 257 return
258 258
259 259 content = msg['content']
260 260 for uuid in content.get('engines', {}).values():
261 261 self._register_engine(cast_bytes(uuid))
262 262
263 263
264 264 @util.log_errors
265 265 def dispatch_notification(self, msg):
266 266 """dispatch register/unregister events."""
267 267 try:
268 268 idents,msg = self.session.feed_identities(msg)
269 269 except ValueError:
270 270 self.log.warn("task::Invalid Message: %r",msg)
271 271 return
272 272 try:
273 273 msg = self.session.deserialize(msg)
274 274 except ValueError:
275 275 self.log.warn("task::Unauthorized message from: %r"%idents)
276 276 return
277 277
278 278 msg_type = msg['header']['msg_type']
279 279
280 280 handler = self._notification_handlers.get(msg_type, None)
281 281 if handler is None:
282 282 self.log.error("Unhandled message type: %r"%msg_type)
283 283 else:
284 284 try:
285 285 handler(cast_bytes(msg['content']['uuid']))
286 286 except Exception:
287 287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288 288
289 289 def _register_engine(self, uid):
290 290 """New engine with ident `uid` became available."""
291 291 # head of the line:
292 292 self.targets.insert(0,uid)
293 293 self.loads.insert(0,0)
294 294
295 295 # initialize sets
296 296 self.completed[uid] = set()
297 297 self.failed[uid] = set()
298 298 self.pending[uid] = {}
299 299
300 300 # rescan the graph:
301 301 self.update_graph(None)
302 302
303 303 def _unregister_engine(self, uid):
304 304 """Existing engine with ident `uid` became unavailable."""
305 305 if len(self.targets) == 1:
306 306 # this was our only engine
307 307 pass
308 308
309 309 # handle any potentially finished tasks:
310 310 self.engine_stream.flush()
311 311
312 312 # don't pop destinations, because they might be used later
313 313 # map(self.destinations.pop, self.completed.pop(uid))
314 314 # map(self.destinations.pop, self.failed.pop(uid))
315 315
316 316 # prevent this engine from receiving work
317 317 idx = self.targets.index(uid)
318 318 self.targets.pop(idx)
319 319 self.loads.pop(idx)
320 320
321 321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 322 # still be incoming
323 323 if self.pending[uid]:
324 324 self.loop.add_timeout(self.loop.time() + 5,
325 325 lambda : self.handle_stranded_tasks(uid),
326 326 )
327 327 else:
328 328 self.completed.pop(uid)
329 329 self.failed.pop(uid)
330 330
331 331
332 332 def handle_stranded_tasks(self, engine):
333 333 """Deal with jobs resident in an engine that died."""
334 334 lost = self.pending[engine]
335 335 for msg_id in lost.keys():
336 336 if msg_id not in self.pending[engine]:
337 337 # prevent double-handling of messages
338 338 continue
339 339
340 340 raw_msg = lost[msg_id].raw_msg
341 341 idents,msg = self.session.feed_identities(raw_msg, copy=False)
342 342 parent = self.session.unpack(msg[1].bytes)
343 343 idents = [engine, idents[0]]
344 344
345 345 # build fake error reply
346 346 try:
347 347 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
348 348 except:
349 349 content = error.wrap_exception()
350 350 # build fake metadata
351 351 md = dict(
352 352 status=u'error',
353 353 engine=engine.decode('ascii'),
354 354 date=datetime.now(),
355 355 )
356 356 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
357 357 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
358 358 # and dispatch it
359 359 self.dispatch_result(raw_reply)
360 360
361 361 # finally scrub completed/failed lists
362 362 self.completed.pop(engine)
363 363 self.failed.pop(engine)
364 364
365 365
366 366 #-----------------------------------------------------------------------
367 367 # Job Submission
368 368 #-----------------------------------------------------------------------
369 369
370 370
371 371 @util.log_errors
372 372 def dispatch_submission(self, raw_msg):
373 373 """Dispatch job submission to appropriate handlers."""
374 374 # ensure targets up to date:
375 375 self.notifier_stream.flush()
376 376 try:
377 377 idents, msg = self.session.feed_identities(raw_msg, copy=False)
378 378 msg = self.session.deserialize(msg, content=False, copy=False)
379 379 except Exception:
380 380 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
381 381 return
382 382
383 383
384 384 # send to monitor
385 385 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
386 386
387 387 header = msg['header']
388 388 md = msg['metadata']
389 389 msg_id = header['msg_id']
390 390 self.all_ids.add(msg_id)
391 391
392 392 # get targets as a set of bytes objects
393 393 # from a list of unicode objects
394 394 targets = md.get('targets', [])
395 395 targets = set(map(cast_bytes, targets))
396 396
397 397 retries = md.get('retries', 0)
398 398 self.retries[msg_id] = retries
399 399
400 400 # time dependencies
401 401 after = md.get('after', None)
402 402 if after:
403 403 after = Dependency(after)
404 404 if after.all:
405 405 if after.success:
406 406 after = Dependency(after.difference(self.all_completed),
407 407 success=after.success,
408 408 failure=after.failure,
409 409 all=after.all,
410 410 )
411 411 if after.failure:
412 412 after = Dependency(after.difference(self.all_failed),
413 413 success=after.success,
414 414 failure=after.failure,
415 415 all=after.all,
416 416 )
417 417 if after.check(self.all_completed, self.all_failed):
418 418 # recast as empty set, if `after` already met,
419 419 # to prevent unnecessary set comparisons
420 420 after = MET
421 421 else:
422 422 after = MET
423 423
424 424 # location dependencies
425 425 follow = Dependency(md.get('follow', []))
426 426
427 427 timeout = md.get('timeout', None)
428 428 if timeout:
429 429 timeout = float(timeout)
430 430
431 431 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
432 432 header=header, targets=targets, after=after, follow=follow,
433 433 timeout=timeout, metadata=md,
434 434 )
435 435 # validate and reduce dependencies:
436 436 for dep in after,follow:
437 437 if not dep: # empty dependency
438 438 continue
439 439 # check valid:
440 440 if msg_id in dep or dep.difference(self.all_ids):
441 441 self.queue_map[msg_id] = job
442 442 return self.fail_unreachable(msg_id, error.InvalidDependency)
443 443 # check if unreachable:
444 444 if dep.unreachable(self.all_completed, self.all_failed):
445 445 self.queue_map[msg_id] = job
446 446 return self.fail_unreachable(msg_id)
447 447
448 448 if after.check(self.all_completed, self.all_failed):
449 449 # time deps already met, try to run
450 450 if not self.maybe_run(job):
451 451 # can't run yet
452 452 if msg_id not in self.all_failed:
453 453 # could have failed as unreachable
454 454 self.save_unmet(job)
455 455 else:
456 456 self.save_unmet(job)
457 457
458 458 def job_timeout(self, job, timeout_id):
459 459 """callback for a job's timeout.
460 460
461 461 The job may or may not have been run at this point.
462 462 """
463 463 if job.timeout_id != timeout_id:
464 464 # not the most recent call
465 465 return
466 466 now = time.time()
467 467 if job.timeout >= (now + 1):
468 468 self.log.warn("task %s timeout fired prematurely: %s > %s",
469 469 job.msg_id, job.timeout, now
470 470 )
471 471 if job.msg_id in self.queue_map:
472 472 # still waiting, but ran out of time
473 473 self.log.info("task %r timed out", job.msg_id)
474 474 self.fail_unreachable(job.msg_id, error.TaskTimeout)
475 475
476 476 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
477 477 """a task has become unreachable, send a reply with an ImpossibleDependency
478 478 error."""
479 479 if msg_id not in self.queue_map:
480 480 self.log.error("task %r already failed!", msg_id)
481 481 return
482 482 job = self.queue_map.pop(msg_id)
483 483 # lazy-delete from the queue
484 484 job.removed = True
485 485 for mid in job.dependents:
486 486 if mid in self.graph:
487 487 self.graph[mid].remove(msg_id)
488 488
489 489 try:
490 490 raise why()
491 491 except:
492 492 content = error.wrap_exception()
493 493 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
494 494
495 495 self.all_done.add(msg_id)
496 496 self.all_failed.add(msg_id)
497 497
498 498 msg = self.session.send(self.client_stream, 'apply_reply', content,
499 499 parent=job.header, ident=job.idents)
500 500 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
501 501
502 502 self.update_graph(msg_id, success=False)
503 503
504 504 def available_engines(self):
505 505 """return a list of available engine indices based on HWM"""
506 506 if not self.hwm:
507 507 return list(range(len(self.targets)))
508 508 available = []
509 509 for idx in range(len(self.targets)):
510 510 if self.loads[idx] < self.hwm:
511 511 available.append(idx)
512 512 return available
513 513
514 514 def maybe_run(self, job):
515 515 """check location dependencies, and run if they are met."""
516 516 msg_id = job.msg_id
517 517 self.log.debug("Attempting to assign task %s", msg_id)
518 518 available = self.available_engines()
519 519 if not available:
520 520 # no engines, definitely can't run
521 521 return False
522 522
523 523 if job.follow or job.targets or job.blacklist or self.hwm:
524 524 # we need a can_run filter
525 525 def can_run(idx):
526 526 # check hwm
527 527 if self.hwm and self.loads[idx] == self.hwm:
528 528 return False
529 529 target = self.targets[idx]
530 530 # check blacklist
531 531 if target in job.blacklist:
532 532 return False
533 533 # check targets
534 534 if job.targets and target not in job.targets:
535 535 return False
536 536 # check follow
537 537 return job.follow.check(self.completed[target], self.failed[target])
538 538
539 539 indices = list(filter(can_run, available))
540 540
541 541 if not indices:
542 542 # couldn't run
543 543 if job.follow.all:
544 544 # check follow for impossibility
545 545 dests = set()
546 546 relevant = set()
547 547 if job.follow.success:
548 548 relevant = self.all_completed
549 549 if job.follow.failure:
550 550 relevant = relevant.union(self.all_failed)
551 551 for m in job.follow.intersection(relevant):
552 552 dests.add(self.destinations[m])
553 553 if len(dests) > 1:
554 554 self.queue_map[msg_id] = job
555 555 self.fail_unreachable(msg_id)
556 556 return False
557 557 if job.targets:
558 558 # check blacklist+targets for impossibility
559 559 job.targets.difference_update(job.blacklist)
560 560 if not job.targets or not job.targets.intersection(self.targets):
561 561 self.queue_map[msg_id] = job
562 562 self.fail_unreachable(msg_id)
563 563 return False
564 564 return False
565 565 else:
566 566 indices = None
567 567
568 568 self.submit_task(job, indices)
569 569 return True
570 570
571 571 def save_unmet(self, job):
572 572 """Save a message for later submission when its dependencies are met."""
573 573 msg_id = job.msg_id
574 574 self.log.debug("Adding task %s to the queue", msg_id)
575 575 self.queue_map[msg_id] = job
576 576 self.queue.append(job)
577 577 # track the ids in follow or after, but not those already finished
578 578 for dep_id in job.after.union(job.follow).difference(self.all_done):
579 579 if dep_id not in self.graph:
580 580 self.graph[dep_id] = set()
581 581 self.graph[dep_id].add(msg_id)
582 582
583 583 # schedule timeout callback
584 584 if job.timeout:
585 585 timeout_id = job.timeout_id = job.timeout_id + 1
586 586 self.loop.add_timeout(time.time() + job.timeout,
587 587 lambda : self.job_timeout(job, timeout_id)
588 588 )
589 589
590 590
591 591 def submit_task(self, job, indices=None):
592 592 """Submit a task to any of a subset of our targets."""
593 593 if indices:
594 594 loads = [self.loads[i] for i in indices]
595 595 else:
596 596 loads = self.loads
597 597 idx = self.scheme(loads)
598 598 if indices:
599 599 idx = indices[idx]
600 600 target = self.targets[idx]
601 601 # print (target, map(str, msg[:3]))
602 602 # send job to the engine
603 603 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
604 604 self.engine_stream.send_multipart(job.raw_msg, copy=False)
605 605 # update load
606 606 self.add_job(idx)
607 607 self.pending[target][job.msg_id] = job
608 608 # notify Hub
609 609 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
610 610 self.session.send(self.mon_stream, 'task_destination', content=content,
611 611 ident=[b'tracktask',self.ident])
612 612
613 613
614 614 #-----------------------------------------------------------------------
615 615 # Result Handling
616 616 #-----------------------------------------------------------------------
617 617
618 618
619 619 @util.log_errors
620 620 def dispatch_result(self, raw_msg):
621 621 """dispatch method for result replies"""
622 622 try:
623 623 idents,msg = self.session.feed_identities(raw_msg, copy=False)
624 624 msg = self.session.deserialize(msg, content=False, copy=False)
625 625 engine = idents[0]
626 626 try:
627 627 idx = self.targets.index(engine)
628 628 except ValueError:
629 629 pass # skip load-update for dead engines
630 630 else:
631 631 self.finish_job(idx)
632 632 except Exception:
633 633 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
634 634 return
635 635
636 636 md = msg['metadata']
637 637 parent = msg['parent_header']
638 638 if md.get('dependencies_met', True):
639 639 success = (md['status'] == 'ok')
640 640 msg_id = parent['msg_id']
641 641 retries = self.retries[msg_id]
642 642 if not success and retries > 0:
643 643 # failed
644 644 self.retries[msg_id] = retries - 1
645 645 self.handle_unmet_dependency(idents, parent)
646 646 else:
647 647 del self.retries[msg_id]
648 648 # relay to client and update graph
649 649 self.handle_result(idents, parent, raw_msg, success)
650 650 # send to Hub monitor
651 651 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
652 652 else:
653 653 self.handle_unmet_dependency(idents, parent)
654 654
655 655 def handle_result(self, idents, parent, raw_msg, success=True):
656 656 """handle a real task result, either success or failure"""
657 657 # first, relay result to client
658 658 engine = idents[0]
659 659 client = idents[1]
660 660 # swap_ids for ROUTER-ROUTER mirror
661 661 raw_msg[:2] = [client,engine]
662 662 # print (map(str, raw_msg[:4]))
663 663 self.client_stream.send_multipart(raw_msg, copy=False)
664 664 # now, update our data structures
665 665 msg_id = parent['msg_id']
666 666 self.pending[engine].pop(msg_id)
667 667 if success:
668 668 self.completed[engine].add(msg_id)
669 669 self.all_completed.add(msg_id)
670 670 else:
671 671 self.failed[engine].add(msg_id)
672 672 self.all_failed.add(msg_id)
673 673 self.all_done.add(msg_id)
674 674 self.destinations[msg_id] = engine
675 675
676 676 self.update_graph(msg_id, success)
677 677
678 678 def handle_unmet_dependency(self, idents, parent):
679 679 """handle an unmet dependency"""
680 680 engine = idents[0]
681 681 msg_id = parent['msg_id']
682 682
683 683 job = self.pending[engine].pop(msg_id)
684 684 job.blacklist.add(engine)
685 685
686 686 if job.blacklist == job.targets:
687 687 self.queue_map[msg_id] = job
688 688 self.fail_unreachable(msg_id)
689 689 elif not self.maybe_run(job):
690 690 # resubmit failed
691 691 if msg_id not in self.all_failed:
692 692 # put it back in our dependency tree
693 693 self.save_unmet(job)
694 694
695 695 if self.hwm:
696 696 try:
697 697 idx = self.targets.index(engine)
698 698 except ValueError:
699 699 pass # skip load-update for dead engines
700 700 else:
701 701 if self.loads[idx] == self.hwm-1:
702 702 self.update_graph(None)
703 703
704 704 def update_graph(self, dep_id=None, success=True):
705 705 """dep_id just finished. Update our dependency
706 706 graph and submit any jobs that just became runnable.
707 707
708 708 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
709 709 """
710 710 # print ("\n\n***********")
711 711 # pprint (dep_id)
712 712 # pprint (self.graph)
713 713 # pprint (self.queue_map)
714 714 # pprint (self.all_completed)
715 715 # pprint (self.all_failed)
716 716 # print ("\n\n***********\n\n")
717 717 # update any jobs that depended on the dependency
718 718 msg_ids = self.graph.pop(dep_id, [])
719 719
720 720 # recheck *all* jobs if
721 721 # a) we have HWM and an engine just become no longer full
722 722 # or b) dep_id was given as None
723 723
724 724 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
725 725 jobs = self.queue
726 726 using_queue = True
727 727 else:
728 728 using_queue = False
729 729 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
730 730
731 731 to_restore = []
732 732 while jobs:
733 733 job = jobs.popleft()
734 734 if job.removed:
735 735 continue
736 736 msg_id = job.msg_id
737 737
738 738 put_it_back = True
739 739
740 740 if job.after.unreachable(self.all_completed, self.all_failed)\
741 741 or job.follow.unreachable(self.all_completed, self.all_failed):
742 742 self.fail_unreachable(msg_id)
743 743 put_it_back = False
744 744
745 745 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
746 746 if self.maybe_run(job):
747 747 put_it_back = False
748 748 self.queue_map.pop(msg_id)
749 749 for mid in job.dependents:
750 750 if mid in self.graph:
751 751 self.graph[mid].remove(msg_id)
752 752
753 753 # abort the loop if we just filled up all of our engines.
754 754 # avoids an O(N) operation in situation of full queue,
755 755 # where graph update is triggered as soon as an engine becomes
756 756 # non-full, and all tasks after the first are checked,
757 757 # even though they can't run.
758 758 if not self.available_engines():
759 759 break
760 760
761 761 if using_queue and put_it_back:
762 762 # popped a job from the queue but it neither ran nor failed,
763 763 # so we need to put it back when we are done
764 764 # make sure to_restore preserves the same ordering
765 765 to_restore.append(job)
766 766
767 767 # put back any tasks we popped but didn't run
768 768 if using_queue:
769 769 self.queue.extendleft(to_restore)
770 770
771 771 #----------------------------------------------------------------------
772 772 # methods to be overridden by subclasses
773 773 #----------------------------------------------------------------------
774 774
775 775 def add_job(self, idx):
776 776 """Called after self.targets[idx] just got the job with header.
777 777 Override with subclasses. The default ordering is simple LRU.
778 778 The default loads are the number of outstanding jobs."""
779 779 self.loads[idx] += 1
780 780 for lis in (self.targets, self.loads):
781 781 lis.append(lis.pop(idx))
782 782
783 783
784 784 def finish_job(self, idx):
785 785 """Called after self.targets[idx] just finished a job.
786 786 Override with subclasses."""
787 787 self.loads[idx] -= 1
788 788
789 789
790 790
791 791 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
792 792 logname='root', log_url=None, loglevel=logging.DEBUG,
793 793 identity=b'task', in_thread=False):
794 794
795 795 ZMQStream = zmqstream.ZMQStream
796 796
797 797 if config:
798 798 # unwrap dict back into Config
799 799 config = Config(config)
800 800
801 801 if in_thread:
802 802 # use instance() to get the same Context/Loop as our parent
803 803 ctx = zmq.Context.instance()
804 804 loop = ioloop.IOLoop.instance()
805 805 else:
806 806 # in a process, don't use instance()
807 807 # for safety with multiprocessing
808 808 ctx = zmq.Context()
809 809 loop = ioloop.IOLoop()
810 810 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
811 811 util.set_hwm(ins, 0)
812 812 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
813 813 ins.bind(in_addr)
814 814
815 815 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
816 816 util.set_hwm(outs, 0)
817 817 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
818 818 outs.bind(out_addr)
819 819 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
820 820 util.set_hwm(mons, 0)
821 821 mons.connect(mon_addr)
822 822 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
823 823 nots.setsockopt(zmq.SUBSCRIBE, b'')
824 824 nots.connect(not_addr)
825 825
826 826 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
827 827 querys.connect(reg_addr)
828 828
829 829 # setup logging.
830 830 if in_thread:
831 831 log = Application.instance().log
832 832 else:
833 833 if log_url:
834 834 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
835 835 else:
836 836 log = local_logger(logname, loglevel)
837 837
838 838 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
839 839 mon_stream=mons, notifier_stream=nots,
840 840 query_stream=querys,
841 841 loop=loop, log=log,
842 842 config=config)
843 843 scheduler.start()
844 844 if not in_thread:
845 845 try:
846 846 loop.start()
847 847 except KeyboardInterrupt:
848 848 scheduler.log.critical("Interrupted, exiting...")
849 849
@@ -1,578 +1,578 b''
1 1 # -*- coding: utf-8 -*-
2 2 """terminal client to the IPython kernel"""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 from __future__ import print_function
8 8
9 9 import base64
10 10 import bdb
11 11 import signal
12 12 import os
13 13 import sys
14 14 import time
15 15 import subprocess
16 16 from getpass import getpass
17 17 from io import BytesIO
18 18
19 19 try:
20 20 from queue import Empty # Py 3
21 21 except ImportError:
22 22 from Queue import Empty # Py 2
23 23
24 24 from IPython.core import page
25 25 from IPython.core import release
26 26 from IPython.terminal.console.zmqhistory import ZMQHistoryManager
27 27 from IPython.utils.warn import warn, error
28 28 from IPython.utils import io
29 29 from IPython.utils.py3compat import string_types, input
30 30 from IPython.utils.traitlets import List, Enum, Any, Instance, Unicode, Float, Bool
31 31 from IPython.utils.tempdir import NamedFileInTemporaryDirectory
32 32
33 33 from IPython.terminal.interactiveshell import TerminalInteractiveShell
34 34 from IPython.terminal.console.completer import ZMQCompleter
35 35
36 36 class ZMQTerminalInteractiveShell(TerminalInteractiveShell):
37 37 """A subclass of TerminalInteractiveShell that uses the 0MQ kernel"""
38 38 _executing = False
39 39 _execution_state = Unicode('')
40 40 _pending_clearoutput = False
41 41 kernel_banner = Unicode('')
42 42 kernel_timeout = Float(60, config=True,
43 43 help="""Timeout for giving up on a kernel (in seconds).
44 44
45 45 On first connect and restart, the console tests whether the
46 46 kernel is running and responsive by sending kernel_info_requests.
47 47 This sets the timeout in seconds for how long the kernel can take
48 48 before being presumed dead.
49 49 """
50 50 )
51 51
52 52 image_handler = Enum(('PIL', 'stream', 'tempfile', 'callable'),
53 config=True, help=
53 config=True, allow_none=True, help=
54 54 """
55 55 Handler for image type output. This is useful, for example,
56 56 when connecting to the kernel in which pylab inline backend is
57 57 activated. There are four handlers defined. 'PIL': Use
58 58 Python Imaging Library to popup image; 'stream': Use an
59 59 external program to show the image. Image will be fed into
60 60 the STDIN of the program. You will need to configure
61 61 `stream_image_handler`; 'tempfile': Use an external program to
62 62 show the image. Image will be saved in a temporally file and
63 63 the program is called with the temporally file. You will need
64 64 to configure `tempfile_image_handler`; 'callable': You can set
65 65 any Python callable which is called with the image data. You
66 66 will need to configure `callable_image_handler`.
67 67 """
68 68 )
69 69
70 70 stream_image_handler = List(config=True, help=
71 71 """
72 72 Command to invoke an image viewer program when you are using
73 73 'stream' image handler. This option is a list of string where
74 74 the first element is the command itself and reminders are the
75 75 options for the command. Raw image data is given as STDIN to
76 76 the program.
77 77 """
78 78 )
79 79
80 80 tempfile_image_handler = List(config=True, help=
81 81 """
82 82 Command to invoke an image viewer program when you are using
83 83 'tempfile' image handler. This option is a list of string
84 84 where the first element is the command itself and reminders
85 85 are the options for the command. You can use {file} and
86 86 {format} in the string to represent the location of the
87 87 generated image file and image format.
88 88 """
89 89 )
90 90
91 91 callable_image_handler = Any(config=True, help=
92 92 """
93 93 Callable object called via 'callable' image handler with one
94 94 argument, `data`, which is `msg["content"]["data"]` where
95 95 `msg` is the message from iopub channel. For exmaple, you can
96 96 find base64 encoded PNG data as `data['image/png']`.
97 97 """
98 98 )
99 99
100 100 mime_preference = List(
101 101 default_value=['image/png', 'image/jpeg', 'image/svg+xml'],
102 config=True, allow_none=False, help=
102 config=True, help=
103 103 """
104 104 Preferred object representation MIME type in order. First
105 105 matched MIME type will be used.
106 106 """
107 107 )
108 108
109 109 manager = Instance('IPython.kernel.KernelManager')
110 110 client = Instance('IPython.kernel.KernelClient')
111 111 def _client_changed(self, name, old, new):
112 112 self.session_id = new.session.session
113 113 session_id = Unicode()
114 114
115 115 def init_completer(self):
116 116 """Initialize the completion machinery.
117 117
118 118 This creates completion machinery that can be used by client code,
119 119 either interactively in-process (typically triggered by the readline
120 120 library), programmatically (such as in test suites) or out-of-process
121 121 (typically over the network by remote frontends).
122 122 """
123 123 from IPython.core.completerlib import (module_completer,
124 124 magic_run_completer, cd_completer)
125 125
126 126 self.Completer = ZMQCompleter(self, self.client, config=self.config)
127 127
128 128
129 129 self.set_hook('complete_command', module_completer, str_key = 'import')
130 130 self.set_hook('complete_command', module_completer, str_key = 'from')
131 131 self.set_hook('complete_command', magic_run_completer, str_key = '%run')
132 132 self.set_hook('complete_command', cd_completer, str_key = '%cd')
133 133
134 134 # Only configure readline if we truly are using readline. IPython can
135 135 # do tab-completion over the network, in GUIs, etc, where readline
136 136 # itself may be absent
137 137 if self.has_readline:
138 138 self.set_readline_completer()
139 139
140 140 def run_cell(self, cell, store_history=True):
141 141 """Run a complete IPython cell.
142 142
143 143 Parameters
144 144 ----------
145 145 cell : str
146 146 The code (including IPython code such as %magic functions) to run.
147 147 store_history : bool
148 148 If True, the raw and translated cell will be stored in IPython's
149 149 history. For user code calling back into IPython's machinery, this
150 150 should be set to False.
151 151 """
152 152 if (not cell) or cell.isspace():
153 153 # pressing enter flushes any pending display
154 154 self.handle_iopub()
155 155 return
156 156
157 157 # flush stale replies, which could have been ignored, due to missed heartbeats
158 158 while self.client.shell_channel.msg_ready():
159 159 self.client.shell_channel.get_msg()
160 160 # execute takes 'hidden', which is the inverse of store_hist
161 161 msg_id = self.client.execute(cell, not store_history)
162 162
163 163 # first thing is wait for any side effects (output, stdin, etc.)
164 164 self._executing = True
165 165 self._execution_state = "busy"
166 166 while self._execution_state != 'idle' and self.client.is_alive():
167 167 try:
168 168 self.handle_input_request(msg_id, timeout=0.05)
169 169 except Empty:
170 170 # display intermediate print statements, etc.
171 171 self.handle_iopub(msg_id)
172 172
173 173 # after all of that is done, wait for the execute reply
174 174 while self.client.is_alive():
175 175 try:
176 176 self.handle_execute_reply(msg_id, timeout=0.05)
177 177 except Empty:
178 178 pass
179 179 else:
180 180 break
181 181 self._executing = False
182 182
183 183 #-----------------
184 184 # message handlers
185 185 #-----------------
186 186
187 187 def handle_execute_reply(self, msg_id, timeout=None):
188 188 msg = self.client.shell_channel.get_msg(block=False, timeout=timeout)
189 189 if msg["parent_header"].get("msg_id", None) == msg_id:
190 190
191 191 self.handle_iopub(msg_id)
192 192
193 193 content = msg["content"]
194 194 status = content['status']
195 195
196 196 if status == 'aborted':
197 197 self.write('Aborted\n')
198 198 return
199 199 elif status == 'ok':
200 200 # handle payloads
201 201 for item in content["payload"]:
202 202 source = item['source']
203 203 if source == 'page':
204 204 page.page(item['data']['text/plain'])
205 205 elif source == 'set_next_input':
206 206 self.set_next_input(item['text'])
207 207 elif source == 'ask_exit':
208 208 self.ask_exit()
209 209
210 210 elif status == 'error':
211 211 for frame in content["traceback"]:
212 212 print(frame, file=io.stderr)
213 213
214 214 self.execution_count = int(content["execution_count"] + 1)
215 215
216 216 include_other_output = Bool(False, config=True,
217 217 help="""Whether to include output from clients
218 218 other than this one sharing the same kernel.
219 219
220 220 Outputs are not displayed until enter is pressed.
221 221 """
222 222 )
223 223 other_output_prefix = Unicode("[remote] ", config=True,
224 224 help="""Prefix to add to outputs coming from clients other than this one.
225 225
226 226 Only relevant if include_other_output is True.
227 227 """
228 228 )
229 229
230 230 def from_here(self, msg):
231 231 """Return whether a message is from this session"""
232 232 return msg['parent_header'].get("session", self.session_id) == self.session_id
233 233
234 234 def include_output(self, msg):
235 235 """Return whether we should include a given output message"""
236 236 from_here = self.from_here(msg)
237 237 if msg['msg_type'] == 'execute_input':
238 238 # only echo inputs not from here
239 239 return self.include_other_output and not from_here
240 240
241 241 if self.include_other_output:
242 242 return True
243 243 else:
244 244 return from_here
245 245
246 246 def handle_iopub(self, msg_id=''):
247 247 """Process messages on the IOPub channel
248 248
249 249 This method consumes and processes messages on the IOPub channel,
250 250 such as stdout, stderr, execute_result and status.
251 251
252 252 It only displays output that is caused by this session.
253 253 """
254 254 while self.client.iopub_channel.msg_ready():
255 255 sub_msg = self.client.iopub_channel.get_msg()
256 256 msg_type = sub_msg['header']['msg_type']
257 257 parent = sub_msg["parent_header"]
258 258
259 259 if self.include_output(sub_msg):
260 260 if msg_type == 'status':
261 261 self._execution_state = sub_msg["content"]["execution_state"]
262 262 elif msg_type == 'stream':
263 263 if sub_msg["content"]["name"] == "stdout":
264 264 if self._pending_clearoutput:
265 265 print("\r", file=io.stdout, end="")
266 266 self._pending_clearoutput = False
267 267 print(sub_msg["content"]["text"], file=io.stdout, end="")
268 268 io.stdout.flush()
269 269 elif sub_msg["content"]["name"] == "stderr":
270 270 if self._pending_clearoutput:
271 271 print("\r", file=io.stderr, end="")
272 272 self._pending_clearoutput = False
273 273 print(sub_msg["content"]["text"], file=io.stderr, end="")
274 274 io.stderr.flush()
275 275
276 276 elif msg_type == 'execute_result':
277 277 if self._pending_clearoutput:
278 278 print("\r", file=io.stdout, end="")
279 279 self._pending_clearoutput = False
280 280 self.execution_count = int(sub_msg["content"]["execution_count"])
281 281 if not self.from_here(sub_msg):
282 282 sys.stdout.write(self.other_output_prefix)
283 283 format_dict = sub_msg["content"]["data"]
284 284 self.handle_rich_data(format_dict)
285 285
286 286 # taken from DisplayHook.__call__:
287 287 hook = self.displayhook
288 288 hook.start_displayhook()
289 289 hook.write_output_prompt()
290 290 hook.write_format_data(format_dict)
291 291 hook.log_output(format_dict)
292 292 hook.finish_displayhook()
293 293
294 294 elif msg_type == 'display_data':
295 295 data = sub_msg["content"]["data"]
296 296 handled = self.handle_rich_data(data)
297 297 if not handled:
298 298 if not self.from_here(sub_msg):
299 299 sys.stdout.write(self.other_output_prefix)
300 300 # if it was an image, we handled it by now
301 301 if 'text/plain' in data:
302 302 print(data['text/plain'])
303 303
304 304 elif msg_type == 'execute_input':
305 305 content = sub_msg['content']
306 306 self.execution_count = content['execution_count']
307 307 if not self.from_here(sub_msg):
308 308 sys.stdout.write(self.other_output_prefix)
309 309 sys.stdout.write(self.prompt_manager.render('in'))
310 310 sys.stdout.write(content['code'])
311 311
312 312 elif msg_type == 'clear_output':
313 313 if sub_msg["content"]["wait"]:
314 314 self._pending_clearoutput = True
315 315 else:
316 316 print("\r", file=io.stdout, end="")
317 317
318 318 _imagemime = {
319 319 'image/png': 'png',
320 320 'image/jpeg': 'jpeg',
321 321 'image/svg+xml': 'svg',
322 322 }
323 323
324 324 def handle_rich_data(self, data):
325 325 for mime in self.mime_preference:
326 326 if mime in data and mime in self._imagemime:
327 327 self.handle_image(data, mime)
328 328 return True
329 329
330 330 def handle_image(self, data, mime):
331 331 handler = getattr(
332 332 self, 'handle_image_{0}'.format(self.image_handler), None)
333 333 if handler:
334 334 handler(data, mime)
335 335
336 336 def handle_image_PIL(self, data, mime):
337 337 if mime not in ('image/png', 'image/jpeg'):
338 338 return
339 339 import PIL.Image
340 340 raw = base64.decodestring(data[mime].encode('ascii'))
341 341 img = PIL.Image.open(BytesIO(raw))
342 342 img.show()
343 343
344 344 def handle_image_stream(self, data, mime):
345 345 raw = base64.decodestring(data[mime].encode('ascii'))
346 346 imageformat = self._imagemime[mime]
347 347 fmt = dict(format=imageformat)
348 348 args = [s.format(**fmt) for s in self.stream_image_handler]
349 349 with open(os.devnull, 'w') as devnull:
350 350 proc = subprocess.Popen(
351 351 args, stdin=subprocess.PIPE,
352 352 stdout=devnull, stderr=devnull)
353 353 proc.communicate(raw)
354 354
355 355 def handle_image_tempfile(self, data, mime):
356 356 raw = base64.decodestring(data[mime].encode('ascii'))
357 357 imageformat = self._imagemime[mime]
358 358 filename = 'tmp.{0}'.format(imageformat)
359 359 with NamedFileInTemporaryDirectory(filename) as f, \
360 360 open(os.devnull, 'w') as devnull:
361 361 f.write(raw)
362 362 f.flush()
363 363 fmt = dict(file=f.name, format=imageformat)
364 364 args = [s.format(**fmt) for s in self.tempfile_image_handler]
365 365 subprocess.call(args, stdout=devnull, stderr=devnull)
366 366
367 367 def handle_image_callable(self, data, mime):
368 368 self.callable_image_handler(data)
369 369
370 370 def handle_input_request(self, msg_id, timeout=0.1):
371 371 """ Method to capture raw_input
372 372 """
373 373 req = self.client.stdin_channel.get_msg(timeout=timeout)
374 374 # in case any iopub came while we were waiting:
375 375 self.handle_iopub(msg_id)
376 376 if msg_id == req["parent_header"].get("msg_id"):
377 377 # wrap SIGINT handler
378 378 real_handler = signal.getsignal(signal.SIGINT)
379 379 def double_int(sig,frame):
380 380 # call real handler (forwards sigint to kernel),
381 381 # then raise local interrupt, stopping local raw_input
382 382 real_handler(sig,frame)
383 383 raise KeyboardInterrupt
384 384 signal.signal(signal.SIGINT, double_int)
385 385 content = req['content']
386 386 read = getpass if content.get('password', False) else input
387 387 try:
388 388 raw_data = read(content["prompt"])
389 389 except EOFError:
390 390 # turn EOFError into EOF character
391 391 raw_data = '\x04'
392 392 except KeyboardInterrupt:
393 393 sys.stdout.write('\n')
394 394 return
395 395 finally:
396 396 # restore SIGINT handler
397 397 signal.signal(signal.SIGINT, real_handler)
398 398
399 399 # only send stdin reply if there *was not* another request
400 400 # or execution finished while we were reading.
401 401 if not (self.client.stdin_channel.msg_ready() or self.client.shell_channel.msg_ready()):
402 402 self.client.input(raw_data)
403 403
404 404 def mainloop(self, display_banner=False):
405 405 while True:
406 406 try:
407 407 self.interact(display_banner=display_banner)
408 408 #self.interact_with_readline()
409 409 # XXX for testing of a readline-decoupled repl loop, call
410 410 # interact_with_readline above
411 411 break
412 412 except KeyboardInterrupt:
413 413 # this should not be necessary, but KeyboardInterrupt
414 414 # handling seems rather unpredictable...
415 415 self.write("\nKeyboardInterrupt in interact()\n")
416 416
417 417 self.client.shutdown()
418 418
419 419 def _banner1_default(self):
420 420 return "IPython Console {version}\n".format(version=release.version)
421 421
422 422 def compute_banner(self):
423 423 super(ZMQTerminalInteractiveShell, self).compute_banner()
424 424 if self.client and not self.kernel_banner:
425 425 msg_id = self.client.kernel_info()
426 426 while True:
427 427 try:
428 428 reply = self.client.get_shell_msg(timeout=1)
429 429 except Empty:
430 430 break
431 431 else:
432 432 if reply['parent_header'].get('msg_id') == msg_id:
433 433 self.kernel_banner = reply['content'].get('banner', '')
434 434 break
435 435 self.banner += self.kernel_banner
436 436
437 437 def wait_for_kernel(self, timeout=None):
438 438 """method to wait for a kernel to be ready"""
439 439 tic = time.time()
440 440 self.client.hb_channel.unpause()
441 441 while True:
442 442 msg_id = self.client.kernel_info()
443 443 reply = None
444 444 while True:
445 445 try:
446 446 reply = self.client.get_shell_msg(timeout=1)
447 447 except Empty:
448 448 break
449 449 else:
450 450 if reply['parent_header'].get('msg_id') == msg_id:
451 451 return True
452 452 if timeout is not None \
453 453 and (time.time() - tic) > timeout \
454 454 and not self.client.hb_channel.is_beating():
455 455 # heart failed
456 456 return False
457 457 return True
458 458
459 459 def interact(self, display_banner=None):
460 460 """Closely emulate the interactive Python console."""
461 461
462 462 # batch run -> do not interact
463 463 if self.exit_now:
464 464 return
465 465
466 466 if display_banner is None:
467 467 display_banner = self.display_banner
468 468
469 469 if isinstance(display_banner, string_types):
470 470 self.show_banner(display_banner)
471 471 elif display_banner:
472 472 self.show_banner()
473 473
474 474 more = False
475 475
476 476 # run a non-empty no-op, so that we don't get a prompt until
477 477 # we know the kernel is ready. This keeps the connection
478 478 # message above the first prompt.
479 479 if not self.wait_for_kernel(self.kernel_timeout):
480 480 error("Kernel did not respond\n")
481 481 return
482 482
483 483 if self.has_readline:
484 484 self.readline_startup_hook(self.pre_readline)
485 485 hlen_b4_cell = self.readline.get_current_history_length()
486 486 else:
487 487 hlen_b4_cell = 0
488 488 # exit_now is set by a call to %Exit or %Quit, through the
489 489 # ask_exit callback.
490 490
491 491 while not self.exit_now:
492 492 if not self.client.is_alive():
493 493 # kernel died, prompt for action or exit
494 494
495 495 action = "restart" if self.manager else "wait for restart"
496 496 ans = self.ask_yes_no("kernel died, %s ([y]/n)?" % action, default='y')
497 497 if ans:
498 498 if self.manager:
499 499 self.manager.restart_kernel(True)
500 500 self.wait_for_kernel(self.kernel_timeout)
501 501 else:
502 502 self.exit_now = True
503 503 continue
504 504 try:
505 505 # protect prompt block from KeyboardInterrupt
506 506 # when sitting on ctrl-C
507 507 self.hooks.pre_prompt_hook()
508 508 if more:
509 509 try:
510 510 prompt = self.prompt_manager.render('in2')
511 511 except Exception:
512 512 self.showtraceback()
513 513 if self.autoindent:
514 514 self.rl_do_indent = True
515 515
516 516 else:
517 517 try:
518 518 prompt = self.separate_in + self.prompt_manager.render('in')
519 519 except Exception:
520 520 self.showtraceback()
521 521
522 522 line = self.raw_input(prompt)
523 523 if self.exit_now:
524 524 # quick exit on sys.std[in|out] close
525 525 break
526 526 if self.autoindent:
527 527 self.rl_do_indent = False
528 528
529 529 except KeyboardInterrupt:
530 530 #double-guard against keyboardinterrupts during kbdint handling
531 531 try:
532 532 self.write('\n' + self.get_exception_only())
533 533 source_raw = self.input_splitter.raw_reset()
534 534 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
535 535 more = False
536 536 except KeyboardInterrupt:
537 537 pass
538 538 except EOFError:
539 539 if self.autoindent:
540 540 self.rl_do_indent = False
541 541 if self.has_readline:
542 542 self.readline_startup_hook(None)
543 543 self.write('\n')
544 544 self.exit()
545 545 except bdb.BdbQuit:
546 546 warn('The Python debugger has exited with a BdbQuit exception.\n'
547 547 'Because of how pdb handles the stack, it is impossible\n'
548 548 'for IPython to properly format this particular exception.\n'
549 549 'IPython will resume normal operation.')
550 550 except:
551 551 # exceptions here are VERY RARE, but they can be triggered
552 552 # asynchronously by signal handlers, for example.
553 553 self.showtraceback()
554 554 else:
555 555 try:
556 556 self.input_splitter.push(line)
557 557 more = self.input_splitter.push_accepts_more()
558 558 except SyntaxError:
559 559 # Run the code directly - run_cell takes care of displaying
560 560 # the exception.
561 561 more = False
562 562 if (self.SyntaxTB.last_syntax_error and
563 563 self.autoedit_syntax):
564 564 self.edit_syntax_error()
565 565 if not more:
566 566 source_raw = self.input_splitter.raw_reset()
567 567 hlen_b4_cell = self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
568 568 self.run_cell(source_raw)
569 569
570 570
571 571 # Turn off the exit flag, so the mainloop can be restarted if desired
572 572 self.exit_now = False
573 573
574 574 def init_history(self):
575 575 """Sets up the command history. """
576 576 self.history_manager = ZMQHistoryManager(client=self.client)
577 577 self.configurables.append(self.history_manager)
578 578
@@ -1,1468 +1,1468 b''
1 1 # encoding: utf-8
2 2 """Tests for IPython.utils.traitlets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6 #
7 7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
8 8 # also under the terms of the Modified BSD License.
9 9
10 10 import pickle
11 11 import re
12 12 import sys
13 13 from unittest import TestCase
14 14
15 15 import nose.tools as nt
16 16 from nose import SkipTest
17 17
18 18 from IPython.utils.traitlets import (
19 19 HasTraits, MetaHasTraits, TraitType, Any, Bool, CBytes, Dict, Enum,
20 20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
21 21 Union, Undefined, Type, This, Instance, TCPAddress, List, Tuple,
22 22 ObjectName, DottedObjectName, CRegExp, link, directional_link,
23 23 EventfulList, EventfulDict, ForwardDeclaredType, ForwardDeclaredInstance,
24 24 )
25 25 from IPython.utils import py3compat
26 26 from IPython.testing.decorators import skipif
27 27
28 28 #-----------------------------------------------------------------------------
29 29 # Helper classes for testing
30 30 #-----------------------------------------------------------------------------
31 31
32 32
33 33 class HasTraitsStub(HasTraits):
34 34
35 35 def _notify_trait(self, name, old, new):
36 36 self._notify_name = name
37 37 self._notify_old = old
38 38 self._notify_new = new
39 39
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Test classes
43 43 #-----------------------------------------------------------------------------
44 44
45 45
46 46 class TestTraitType(TestCase):
47 47
48 48 def test_get_undefined(self):
49 49 class A(HasTraits):
50 50 a = TraitType
51 51 a = A()
52 52 self.assertEqual(a.a, Undefined)
53 53
54 54 def test_set(self):
55 55 class A(HasTraitsStub):
56 56 a = TraitType
57 57
58 58 a = A()
59 59 a.a = 10
60 60 self.assertEqual(a.a, 10)
61 61 self.assertEqual(a._notify_name, 'a')
62 62 self.assertEqual(a._notify_old, Undefined)
63 63 self.assertEqual(a._notify_new, 10)
64 64
65 65 def test_validate(self):
66 66 class MyTT(TraitType):
67 67 def validate(self, inst, value):
68 68 return -1
69 69 class A(HasTraitsStub):
70 70 tt = MyTT
71 71
72 72 a = A()
73 73 a.tt = 10
74 74 self.assertEqual(a.tt, -1)
75 75
76 76 def test_default_validate(self):
77 77 class MyIntTT(TraitType):
78 78 def validate(self, obj, value):
79 79 if isinstance(value, int):
80 80 return value
81 81 self.error(obj, value)
82 82 class A(HasTraits):
83 83 tt = MyIntTT(10)
84 84 a = A()
85 85 self.assertEqual(a.tt, 10)
86 86
87 87 # Defaults are validated when the HasTraits is instantiated
88 88 class B(HasTraits):
89 89 tt = MyIntTT('bad default')
90 90 self.assertRaises(TraitError, B)
91 91
92 92 def test_info(self):
93 93 class A(HasTraits):
94 94 tt = TraitType
95 95 a = A()
96 96 self.assertEqual(A.tt.info(), 'any value')
97 97
98 98 def test_error(self):
99 99 class A(HasTraits):
100 100 tt = TraitType
101 101 a = A()
102 102 self.assertRaises(TraitError, A.tt.error, a, 10)
103 103
104 104 def test_dynamic_initializer(self):
105 105 class A(HasTraits):
106 106 x = Int(10)
107 107 def _x_default(self):
108 108 return 11
109 109 class B(A):
110 110 x = Int(20)
111 111 class C(A):
112 112 def _x_default(self):
113 113 return 21
114 114
115 115 a = A()
116 116 self.assertEqual(a._trait_values, {})
117 117 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
118 118 self.assertEqual(a.x, 11)
119 119 self.assertEqual(a._trait_values, {'x': 11})
120 120 b = B()
121 121 self.assertEqual(b._trait_values, {'x': 20})
122 122 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
123 123 self.assertEqual(b.x, 20)
124 124 c = C()
125 125 self.assertEqual(c._trait_values, {})
126 126 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
127 127 self.assertEqual(c.x, 21)
128 128 self.assertEqual(c._trait_values, {'x': 21})
129 129 # Ensure that the base class remains unmolested when the _default
130 130 # initializer gets overridden in a subclass.
131 131 a = A()
132 132 c = C()
133 133 self.assertEqual(a._trait_values, {})
134 134 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
135 135 self.assertEqual(a.x, 11)
136 136 self.assertEqual(a._trait_values, {'x': 11})
137 137
138 138
139 139
140 140 class TestHasTraitsMeta(TestCase):
141 141
142 142 def test_metaclass(self):
143 143 self.assertEqual(type(HasTraits), MetaHasTraits)
144 144
145 145 class A(HasTraits):
146 146 a = Int
147 147
148 148 a = A()
149 149 self.assertEqual(type(a.__class__), MetaHasTraits)
150 150 self.assertEqual(a.a,0)
151 151 a.a = 10
152 152 self.assertEqual(a.a,10)
153 153
154 154 class B(HasTraits):
155 155 b = Int()
156 156
157 157 b = B()
158 158 self.assertEqual(b.b,0)
159 159 b.b = 10
160 160 self.assertEqual(b.b,10)
161 161
162 162 class C(HasTraits):
163 163 c = Int(30)
164 164
165 165 c = C()
166 166 self.assertEqual(c.c,30)
167 167 c.c = 10
168 168 self.assertEqual(c.c,10)
169 169
170 170 def test_this_class(self):
171 171 class A(HasTraits):
172 172 t = This()
173 173 tt = This()
174 174 class B(A):
175 175 tt = This()
176 176 ttt = This()
177 177 self.assertEqual(A.t.this_class, A)
178 178 self.assertEqual(B.t.this_class, A)
179 179 self.assertEqual(B.tt.this_class, B)
180 180 self.assertEqual(B.ttt.this_class, B)
181 181
182 182 class TestHasTraitsNotify(TestCase):
183 183
184 184 def setUp(self):
185 185 self._notify1 = []
186 186 self._notify2 = []
187 187
188 188 def notify1(self, name, old, new):
189 189 self._notify1.append((name, old, new))
190 190
191 191 def notify2(self, name, old, new):
192 192 self._notify2.append((name, old, new))
193 193
194 194 def test_notify_all(self):
195 195
196 196 class A(HasTraits):
197 197 a = Int
198 198 b = Float
199 199
200 200 a = A()
201 201 a.on_trait_change(self.notify1)
202 202 a.a = 0
203 203 self.assertEqual(len(self._notify1),0)
204 204 a.b = 0.0
205 205 self.assertEqual(len(self._notify1),0)
206 206 a.a = 10
207 207 self.assertTrue(('a',0,10) in self._notify1)
208 208 a.b = 10.0
209 209 self.assertTrue(('b',0.0,10.0) in self._notify1)
210 210 self.assertRaises(TraitError,setattr,a,'a','bad string')
211 211 self.assertRaises(TraitError,setattr,a,'b','bad string')
212 212 self._notify1 = []
213 213 a.on_trait_change(self.notify1,remove=True)
214 214 a.a = 20
215 215 a.b = 20.0
216 216 self.assertEqual(len(self._notify1),0)
217 217
218 218 def test_notify_one(self):
219 219
220 220 class A(HasTraits):
221 221 a = Int
222 222 b = Float
223 223
224 224 a = A()
225 225 a.on_trait_change(self.notify1, 'a')
226 226 a.a = 0
227 227 self.assertEqual(len(self._notify1),0)
228 228 a.a = 10
229 229 self.assertTrue(('a',0,10) in self._notify1)
230 230 self.assertRaises(TraitError,setattr,a,'a','bad string')
231 231
232 232 def test_subclass(self):
233 233
234 234 class A(HasTraits):
235 235 a = Int
236 236
237 237 class B(A):
238 238 b = Float
239 239
240 240 b = B()
241 241 self.assertEqual(b.a,0)
242 242 self.assertEqual(b.b,0.0)
243 243 b.a = 100
244 244 b.b = 100.0
245 245 self.assertEqual(b.a,100)
246 246 self.assertEqual(b.b,100.0)
247 247
248 248 def test_notify_subclass(self):
249 249
250 250 class A(HasTraits):
251 251 a = Int
252 252
253 253 class B(A):
254 254 b = Float
255 255
256 256 b = B()
257 257 b.on_trait_change(self.notify1, 'a')
258 258 b.on_trait_change(self.notify2, 'b')
259 259 b.a = 0
260 260 b.b = 0.0
261 261 self.assertEqual(len(self._notify1),0)
262 262 self.assertEqual(len(self._notify2),0)
263 263 b.a = 10
264 264 b.b = 10.0
265 265 self.assertTrue(('a',0,10) in self._notify1)
266 266 self.assertTrue(('b',0.0,10.0) in self._notify2)
267 267
268 268 def test_static_notify(self):
269 269
270 270 class A(HasTraits):
271 271 a = Int
272 272 _notify1 = []
273 273 def _a_changed(self, name, old, new):
274 274 self._notify1.append((name, old, new))
275 275
276 276 a = A()
277 277 a.a = 0
278 278 # This is broken!!!
279 279 self.assertEqual(len(a._notify1),0)
280 280 a.a = 10
281 281 self.assertTrue(('a',0,10) in a._notify1)
282 282
283 283 class B(A):
284 284 b = Float
285 285 _notify2 = []
286 286 def _b_changed(self, name, old, new):
287 287 self._notify2.append((name, old, new))
288 288
289 289 b = B()
290 290 b.a = 10
291 291 b.b = 10.0
292 292 self.assertTrue(('a',0,10) in b._notify1)
293 293 self.assertTrue(('b',0.0,10.0) in b._notify2)
294 294
295 295 def test_notify_args(self):
296 296
297 297 def callback0():
298 298 self.cb = ()
299 299 def callback1(name):
300 300 self.cb = (name,)
301 301 def callback2(name, new):
302 302 self.cb = (name, new)
303 303 def callback3(name, old, new):
304 304 self.cb = (name, old, new)
305 305
306 306 class A(HasTraits):
307 307 a = Int
308 308
309 309 a = A()
310 310 a.on_trait_change(callback0, 'a')
311 311 a.a = 10
312 312 self.assertEqual(self.cb,())
313 313 a.on_trait_change(callback0, 'a', remove=True)
314 314
315 315 a.on_trait_change(callback1, 'a')
316 316 a.a = 100
317 317 self.assertEqual(self.cb,('a',))
318 318 a.on_trait_change(callback1, 'a', remove=True)
319 319
320 320 a.on_trait_change(callback2, 'a')
321 321 a.a = 1000
322 322 self.assertEqual(self.cb,('a',1000))
323 323 a.on_trait_change(callback2, 'a', remove=True)
324 324
325 325 a.on_trait_change(callback3, 'a')
326 326 a.a = 10000
327 327 self.assertEqual(self.cb,('a',1000,10000))
328 328 a.on_trait_change(callback3, 'a', remove=True)
329 329
330 330 self.assertEqual(len(a._trait_notifiers['a']),0)
331 331
332 332 def test_notify_only_once(self):
333 333
334 334 class A(HasTraits):
335 335 listen_to = ['a']
336 336
337 337 a = Int(0)
338 338 b = 0
339 339
340 340 def __init__(self, **kwargs):
341 341 super(A, self).__init__(**kwargs)
342 342 self.on_trait_change(self.listener1, ['a'])
343 343
344 344 def listener1(self, name, old, new):
345 345 self.b += 1
346 346
347 347 class B(A):
348 348
349 349 c = 0
350 350 d = 0
351 351
352 352 def __init__(self, **kwargs):
353 353 super(B, self).__init__(**kwargs)
354 354 self.on_trait_change(self.listener2)
355 355
356 356 def listener2(self, name, old, new):
357 357 self.c += 1
358 358
359 359 def _a_changed(self, name, old, new):
360 360 self.d += 1
361 361
362 362 b = B()
363 363 b.a += 1
364 364 self.assertEqual(b.b, b.c)
365 365 self.assertEqual(b.b, b.d)
366 366 b.a += 1
367 367 self.assertEqual(b.b, b.c)
368 368 self.assertEqual(b.b, b.d)
369 369
370 370
371 371 class TestHasTraits(TestCase):
372 372
373 373 def test_trait_names(self):
374 374 class A(HasTraits):
375 375 i = Int
376 376 f = Float
377 377 a = A()
378 378 self.assertEqual(sorted(a.trait_names()),['f','i'])
379 379 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
380 380
381 381 def test_trait_metadata(self):
382 382 class A(HasTraits):
383 383 i = Int(config_key='MY_VALUE')
384 384 a = A()
385 385 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
386 386
387 387 def test_trait_metadata_default(self):
388 388 class A(HasTraits):
389 389 i = Int()
390 390 a = A()
391 391 self.assertEqual(a.trait_metadata('i', 'config_key'), None)
392 392 self.assertEqual(a.trait_metadata('i', 'config_key', 'default'), 'default')
393 393
394 394 def test_traits(self):
395 395 class A(HasTraits):
396 396 i = Int
397 397 f = Float
398 398 a = A()
399 399 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
400 400 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
401 401
402 402 def test_traits_metadata(self):
403 403 class A(HasTraits):
404 404 i = Int(config_key='VALUE1', other_thing='VALUE2')
405 405 f = Float(config_key='VALUE3', other_thing='VALUE2')
406 406 j = Int(0)
407 407 a = A()
408 408 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
409 409 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
410 410 self.assertEqual(traits, dict(i=A.i))
411 411
412 412 # This passes, but it shouldn't because I am replicating a bug in
413 413 # traits.
414 414 traits = a.traits(config_key=lambda v: True)
415 415 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
416 416
417 417 def test_init(self):
418 418 class A(HasTraits):
419 419 i = Int()
420 420 x = Float()
421 421 a = A(i=1, x=10.0)
422 422 self.assertEqual(a.i, 1)
423 423 self.assertEqual(a.x, 10.0)
424 424
425 425 def test_positional_args(self):
426 426 class A(HasTraits):
427 427 i = Int(0)
428 428 def __init__(self, i):
429 429 super(A, self).__init__()
430 430 self.i = i
431 431
432 432 a = A(5)
433 433 self.assertEqual(a.i, 5)
434 434 # should raise TypeError if no positional arg given
435 435 self.assertRaises(TypeError, A)
436 436
437 437 #-----------------------------------------------------------------------------
438 438 # Tests for specific trait types
439 439 #-----------------------------------------------------------------------------
440 440
441 441
442 442 class TestType(TestCase):
443 443
444 444 def test_default(self):
445 445
446 446 class B(object): pass
447 447 class A(HasTraits):
448 448 klass = Type
449 449
450 450 a = A()
451 451 self.assertEqual(a.klass, None)
452 452
453 453 a.klass = B
454 454 self.assertEqual(a.klass, B)
455 455 self.assertRaises(TraitError, setattr, a, 'klass', 10)
456 456
457 457 def test_value(self):
458 458
459 459 class B(object): pass
460 460 class C(object): pass
461 461 class A(HasTraits):
462 462 klass = Type(B)
463 463
464 464 a = A()
465 465 self.assertEqual(a.klass, B)
466 466 self.assertRaises(TraitError, setattr, a, 'klass', C)
467 467 self.assertRaises(TraitError, setattr, a, 'klass', object)
468 468 a.klass = B
469 469
470 470 def test_allow_none(self):
471 471
472 472 class B(object): pass
473 473 class C(B): pass
474 474 class A(HasTraits):
475 475 klass = Type(B, allow_none=False)
476 476
477 477 a = A()
478 478 self.assertEqual(a.klass, B)
479 479 self.assertRaises(TraitError, setattr, a, 'klass', None)
480 480 a.klass = C
481 481 self.assertEqual(a.klass, C)
482 482
483 483 def test_validate_klass(self):
484 484
485 485 class A(HasTraits):
486 486 klass = Type('no strings allowed')
487 487
488 488 self.assertRaises(ImportError, A)
489 489
490 490 class A(HasTraits):
491 491 klass = Type('rub.adub.Duck')
492 492
493 493 self.assertRaises(ImportError, A)
494 494
495 495 def test_validate_default(self):
496 496
497 497 class B(object): pass
498 498 class A(HasTraits):
499 499 klass = Type('bad default', B)
500 500
501 501 self.assertRaises(ImportError, A)
502 502
503 503 class C(HasTraits):
504 504 klass = Type(None, B, allow_none=False)
505 505
506 506 self.assertRaises(TraitError, C)
507 507
508 508 def test_str_klass(self):
509 509
510 510 class A(HasTraits):
511 511 klass = Type('IPython.utils.ipstruct.Struct')
512 512
513 513 from IPython.utils.ipstruct import Struct
514 514 a = A()
515 515 a.klass = Struct
516 516 self.assertEqual(a.klass, Struct)
517 517
518 518 self.assertRaises(TraitError, setattr, a, 'klass', 10)
519 519
520 520 def test_set_str_klass(self):
521 521
522 522 class A(HasTraits):
523 523 klass = Type()
524 524
525 525 a = A(klass='IPython.utils.ipstruct.Struct')
526 526 from IPython.utils.ipstruct import Struct
527 527 self.assertEqual(a.klass, Struct)
528 528
529 529 class TestInstance(TestCase):
530 530
531 531 def test_basic(self):
532 532 class Foo(object): pass
533 533 class Bar(Foo): pass
534 534 class Bah(object): pass
535 535
536 536 class A(HasTraits):
537 537 inst = Instance(Foo)
538 538
539 539 a = A()
540 540 self.assertTrue(a.inst is None)
541 541 a.inst = Foo()
542 542 self.assertTrue(isinstance(a.inst, Foo))
543 543 a.inst = Bar()
544 544 self.assertTrue(isinstance(a.inst, Foo))
545 545 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
546 546 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
547 547 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
548 548
549 549 def test_default_klass(self):
550 550 class Foo(object): pass
551 551 class Bar(Foo): pass
552 552 class Bah(object): pass
553 553
554 554 class FooInstance(Instance):
555 555 klass = Foo
556 556
557 557 class A(HasTraits):
558 558 inst = FooInstance()
559 559
560 560 a = A()
561 561 self.assertTrue(a.inst is None)
562 562 a.inst = Foo()
563 563 self.assertTrue(isinstance(a.inst, Foo))
564 564 a.inst = Bar()
565 565 self.assertTrue(isinstance(a.inst, Foo))
566 566 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
567 567 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
568 568 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
569 569
570 570 def test_unique_default_value(self):
571 571 class Foo(object): pass
572 572 class A(HasTraits):
573 573 inst = Instance(Foo,(),{})
574 574
575 575 a = A()
576 576 b = A()
577 577 self.assertTrue(a.inst is not b.inst)
578 578
579 579 def test_args_kw(self):
580 580 class Foo(object):
581 581 def __init__(self, c): self.c = c
582 582 class Bar(object): pass
583 583 class Bah(object):
584 584 def __init__(self, c, d):
585 585 self.c = c; self.d = d
586 586
587 587 class A(HasTraits):
588 588 inst = Instance(Foo, (10,))
589 589 a = A()
590 590 self.assertEqual(a.inst.c, 10)
591 591
592 592 class B(HasTraits):
593 593 inst = Instance(Bah, args=(10,), kw=dict(d=20))
594 594 b = B()
595 595 self.assertEqual(b.inst.c, 10)
596 596 self.assertEqual(b.inst.d, 20)
597 597
598 598 class C(HasTraits):
599 599 inst = Instance(Foo)
600 600 c = C()
601 601 self.assertTrue(c.inst is None)
602 602
603 603 def test_bad_default(self):
604 604 class Foo(object): pass
605 605
606 606 class A(HasTraits):
607 607 inst = Instance(Foo, allow_none=False)
608 608
609 609 self.assertRaises(TraitError, A)
610 610
611 611 def test_instance(self):
612 612 class Foo(object): pass
613 613
614 614 def inner():
615 615 class A(HasTraits):
616 616 inst = Instance(Foo())
617 617
618 618 self.assertRaises(TraitError, inner)
619 619
620 620
621 621 class TestThis(TestCase):
622 622
623 623 def test_this_class(self):
624 624 class Foo(HasTraits):
625 625 this = This
626 626
627 627 f = Foo()
628 628 self.assertEqual(f.this, None)
629 629 g = Foo()
630 630 f.this = g
631 631 self.assertEqual(f.this, g)
632 632 self.assertRaises(TraitError, setattr, f, 'this', 10)
633 633
634 634 def test_this_inst(self):
635 635 class Foo(HasTraits):
636 636 this = This()
637 637
638 638 f = Foo()
639 639 f.this = Foo()
640 640 self.assertTrue(isinstance(f.this, Foo))
641 641
642 642 def test_subclass(self):
643 643 class Foo(HasTraits):
644 644 t = This()
645 645 class Bar(Foo):
646 646 pass
647 647 f = Foo()
648 648 b = Bar()
649 649 f.t = b
650 650 b.t = f
651 651 self.assertEqual(f.t, b)
652 652 self.assertEqual(b.t, f)
653 653
654 654 def test_subclass_override(self):
655 655 class Foo(HasTraits):
656 656 t = This()
657 657 class Bar(Foo):
658 658 t = This()
659 659 f = Foo()
660 660 b = Bar()
661 661 f.t = b
662 662 self.assertEqual(f.t, b)
663 663 self.assertRaises(TraitError, setattr, b, 't', f)
664 664
665 665 def test_this_in_container(self):
666 666
667 667 class Tree(HasTraits):
668 668 value = Unicode()
669 669 leaves = List(This())
670 670
671 671 tree = Tree(
672 672 value='foo',
673 673 leaves=[Tree('bar'), Tree('buzz')]
674 674 )
675 675
676 676 with self.assertRaises(TraitError):
677 677 tree.leaves = [1, 2]
678 678
679 679 class TraitTestBase(TestCase):
680 680 """A best testing class for basic trait types."""
681 681
682 682 def assign(self, value):
683 683 self.obj.value = value
684 684
685 685 def coerce(self, value):
686 686 return value
687 687
688 688 def test_good_values(self):
689 689 if hasattr(self, '_good_values'):
690 690 for value in self._good_values:
691 691 self.assign(value)
692 692 self.assertEqual(self.obj.value, self.coerce(value))
693 693
694 694 def test_bad_values(self):
695 695 if hasattr(self, '_bad_values'):
696 696 for value in self._bad_values:
697 697 try:
698 698 self.assertRaises(TraitError, self.assign, value)
699 699 except AssertionError:
700 700 assert False, value
701 701
702 702 def test_default_value(self):
703 703 if hasattr(self, '_default_value'):
704 704 self.assertEqual(self._default_value, self.obj.value)
705 705
706 706 def test_allow_none(self):
707 707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
708 708 None in self._bad_values):
709 709 trait=self.obj.traits()['value']
710 710 try:
711 711 trait.allow_none = True
712 712 self._bad_values.remove(None)
713 713 #skip coerce. Allow None casts None to None.
714 714 self.assign(None)
715 715 self.assertEqual(self.obj.value,None)
716 716 self.test_good_values()
717 717 self.test_bad_values()
718 718 finally:
719 719 #tear down
720 720 trait.allow_none = False
721 721 self._bad_values.append(None)
722 722
723 723 def tearDown(self):
724 724 # restore default value after tests, if set
725 725 if hasattr(self, '_default_value'):
726 726 self.obj.value = self._default_value
727 727
728 728
729 729 class AnyTrait(HasTraits):
730 730
731 731 value = Any
732 732
733 733 class AnyTraitTest(TraitTestBase):
734 734
735 735 obj = AnyTrait()
736 736
737 737 _default_value = None
738 738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
739 739 _bad_values = []
740 740
741 741 class UnionTrait(HasTraits):
742 742
743 743 value = Union([Type(), Bool()])
744 744
745 745 class UnionTraitTest(TraitTestBase):
746 746
747 747 obj = UnionTrait(value='IPython.utils.ipstruct.Struct')
748 748 _good_values = [int, float, True]
749 749 _bad_values = [[], (0,), 1j]
750 750
751 751 class OrTrait(HasTraits):
752 752
753 753 value = Bool() | Unicode()
754 754
755 755 class OrTraitTest(TraitTestBase):
756 756
757 757 obj = OrTrait()
758 758 _good_values = [True, False, 'ten']
759 759 _bad_values = [[], (0,), 1j]
760 760
761 761 class IntTrait(HasTraits):
762 762
763 763 value = Int(99)
764 764
765 765 class TestInt(TraitTestBase):
766 766
767 767 obj = IntTrait()
768 768 _default_value = 99
769 769 _good_values = [10, -10]
770 770 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
771 771 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
772 772 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
773 773 if not py3compat.PY3:
774 774 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
775 775
776 776
777 777 class LongTrait(HasTraits):
778 778
779 779 value = Long(99 if py3compat.PY3 else long(99))
780 780
781 781 class TestLong(TraitTestBase):
782 782
783 783 obj = LongTrait()
784 784
785 785 _default_value = 99 if py3compat.PY3 else long(99)
786 786 _good_values = [10, -10]
787 787 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
788 788 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
789 789 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
790 790 u'-10.1']
791 791 if not py3compat.PY3:
792 792 # maxint undefined on py3, because int == long
793 793 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
794 794 _bad_values.extend([[long(10)], (long(10),)])
795 795
796 796 @skipif(py3compat.PY3, "not relevant on py3")
797 797 def test_cast_small(self):
798 798 """Long casts ints to long"""
799 799 self.obj.value = 10
800 800 self.assertEqual(type(self.obj.value), long)
801 801
802 802
803 803 class IntegerTrait(HasTraits):
804 804 value = Integer(1)
805 805
806 806 class TestInteger(TestLong):
807 807 obj = IntegerTrait()
808 808 _default_value = 1
809 809
810 810 def coerce(self, n):
811 811 return int(n)
812 812
813 813 @skipif(py3compat.PY3, "not relevant on py3")
814 814 def test_cast_small(self):
815 815 """Integer casts small longs to int"""
816 816 if py3compat.PY3:
817 817 raise SkipTest("not relevant on py3")
818 818
819 819 self.obj.value = long(100)
820 820 self.assertEqual(type(self.obj.value), int)
821 821
822 822
823 823 class FloatTrait(HasTraits):
824 824
825 825 value = Float(99.0)
826 826
827 827 class TestFloat(TraitTestBase):
828 828
829 829 obj = FloatTrait()
830 830
831 831 _default_value = 99.0
832 832 _good_values = [10, -10, 10.1, -10.1]
833 833 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
834 834 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
835 835 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
836 836 if not py3compat.PY3:
837 837 _bad_values.extend([long(10), long(-10)])
838 838
839 839
840 840 class ComplexTrait(HasTraits):
841 841
842 842 value = Complex(99.0-99.0j)
843 843
844 844 class TestComplex(TraitTestBase):
845 845
846 846 obj = ComplexTrait()
847 847
848 848 _default_value = 99.0-99.0j
849 849 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
850 850 10.1j, 10.1+10.1j, 10.1-10.1j]
851 851 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
852 852 if not py3compat.PY3:
853 853 _bad_values.extend([long(10), long(-10)])
854 854
855 855
856 856 class BytesTrait(HasTraits):
857 857
858 858 value = Bytes(b'string')
859 859
860 860 class TestBytes(TraitTestBase):
861 861
862 862 obj = BytesTrait()
863 863
864 864 _default_value = b'string'
865 865 _good_values = [b'10', b'-10', b'10L',
866 866 b'-10L', b'10.1', b'-10.1', b'string']
867 867 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
868 868 ['ten'],{'ten': 10},(10,), None, u'string']
869 869 if not py3compat.PY3:
870 870 _bad_values.extend([long(10), long(-10)])
871 871
872 872
873 873 class UnicodeTrait(HasTraits):
874 874
875 875 value = Unicode(u'unicode')
876 876
877 877 class TestUnicode(TraitTestBase):
878 878
879 879 obj = UnicodeTrait()
880 880
881 881 _default_value = u'unicode'
882 882 _good_values = ['10', '-10', '10L', '-10L', '10.1',
883 883 '-10.1', '', u'', 'string', u'string', u"€"]
884 884 _bad_values = [10, -10, 10.1, -10.1, 1j,
885 885 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
886 886 if not py3compat.PY3:
887 887 _bad_values.extend([long(10), long(-10)])
888 888
889 889
890 890 class ObjectNameTrait(HasTraits):
891 891 value = ObjectName("abc")
892 892
893 893 class TestObjectName(TraitTestBase):
894 894 obj = ObjectNameTrait()
895 895
896 896 _default_value = "abc"
897 897 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
898 898 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
899 899 None, object(), object]
900 900 if sys.version_info[0] < 3:
901 901 _bad_values.append(u"ΓΎ")
902 902 else:
903 903 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
904 904
905 905
906 906 class DottedObjectNameTrait(HasTraits):
907 907 value = DottedObjectName("a.b")
908 908
909 909 class TestDottedObjectName(TraitTestBase):
910 910 obj = DottedObjectNameTrait()
911 911
912 912 _default_value = "a.b"
913 913 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
914 914 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
915 915 if sys.version_info[0] < 3:
916 916 _bad_values.append(u"t.ΓΎ")
917 917 else:
918 918 _good_values.append(u"t.ΓΎ")
919 919
920 920
921 921 class TCPAddressTrait(HasTraits):
922 922
923 923 value = TCPAddress()
924 924
925 925 class TestTCPAddress(TraitTestBase):
926 926
927 927 obj = TCPAddressTrait()
928 928
929 929 _default_value = ('127.0.0.1',0)
930 930 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
931 931 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
932 932
933 933 class ListTrait(HasTraits):
934 934
935 935 value = List(Int)
936 936
937 937 class TestList(TraitTestBase):
938 938
939 939 obj = ListTrait()
940 940
941 941 _default_value = []
942 942 _good_values = [[], [1], list(range(10)), (1,2)]
943 943 _bad_values = [10, [1,'a'], 'a']
944 944
945 945 def coerce(self, value):
946 946 if value is not None:
947 947 value = list(value)
948 948 return value
949 949
950 950 class Foo(object):
951 951 pass
952 952
953 953 class NoneInstanceListTrait(HasTraits):
954 954
955 955 value = List(Instance(Foo, allow_none=False))
956 956
957 957 class TestNoneInstanceList(TraitTestBase):
958 958
959 959 obj = NoneInstanceListTrait()
960 960
961 961 _default_value = []
962 962 _good_values = [[Foo(), Foo()], []]
963 963 _bad_values = [[None], [Foo(), None]]
964 964
965 965
966 966 class InstanceListTrait(HasTraits):
967 967
968 968 value = List(Instance(__name__+'.Foo'))
969 969
970 970 class TestInstanceList(TraitTestBase):
971 971
972 972 obj = InstanceListTrait()
973 973
974 974 def test_klass(self):
975 975 """Test that the instance klass is properly assigned."""
976 976 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
977 977
978 978 _default_value = []
979 _good_values = [[Foo(), Foo(), None], None]
980 _bad_values = [['1', 2,], '1', [Foo]]
979 _good_values = [[Foo(), Foo(), None], []]
980 _bad_values = [['1', 2,], '1', [Foo], None]
981 981
982 982 class LenListTrait(HasTraits):
983 983
984 984 value = List(Int, [0], minlen=1, maxlen=2)
985 985
986 986 class TestLenList(TraitTestBase):
987 987
988 988 obj = LenListTrait()
989 989
990 990 _default_value = [0]
991 991 _good_values = [[1], [1,2], (1,2)]
992 992 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
993 993
994 994 def coerce(self, value):
995 995 if value is not None:
996 996 value = list(value)
997 997 return value
998 998
999 999 class TupleTrait(HasTraits):
1000 1000
1001 1001 value = Tuple(Int(allow_none=True))
1002 1002
1003 1003 class TestTupleTrait(TraitTestBase):
1004 1004
1005 1005 obj = TupleTrait()
1006 1006
1007 1007 _default_value = None
1008 1008 _good_values = [(1,), None, (0,), [1], (None,)]
1009 1009 _bad_values = [10, (1,2), ('a'), ()]
1010 1010
1011 1011 def coerce(self, value):
1012 1012 if value is not None:
1013 1013 value = tuple(value)
1014 1014 return value
1015 1015
1016 1016 def test_invalid_args(self):
1017 1017 self.assertRaises(TypeError, Tuple, 5)
1018 1018 self.assertRaises(TypeError, Tuple, default_value='hello')
1019 1019 t = Tuple(Int, CBytes, default_value=(1,5))
1020 1020
1021 1021 class LooseTupleTrait(HasTraits):
1022 1022
1023 1023 value = Tuple((1,2,3))
1024 1024
1025 1025 class TestLooseTupleTrait(TraitTestBase):
1026 1026
1027 1027 obj = LooseTupleTrait()
1028 1028
1029 1029 _default_value = (1,2,3)
1030 1030 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
1031 1031 _bad_values = [10, 'hello', {}]
1032 1032
1033 1033 def coerce(self, value):
1034 1034 if value is not None:
1035 1035 value = tuple(value)
1036 1036 return value
1037 1037
1038 1038 def test_invalid_args(self):
1039 1039 self.assertRaises(TypeError, Tuple, 5)
1040 1040 self.assertRaises(TypeError, Tuple, default_value='hello')
1041 1041 t = Tuple(Int, CBytes, default_value=(1,5))
1042 1042
1043 1043
1044 1044 class MultiTupleTrait(HasTraits):
1045 1045
1046 1046 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1047 1047
1048 1048 class TestMultiTuple(TraitTestBase):
1049 1049
1050 1050 obj = MultiTupleTrait()
1051 1051
1052 1052 _default_value = (99,b'bottles')
1053 1053 _good_values = [(1,b'a'), (2,b'b')]
1054 1054 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1055 1055
1056 1056 class CRegExpTrait(HasTraits):
1057 1057
1058 1058 value = CRegExp(r'')
1059 1059
1060 1060 class TestCRegExp(TraitTestBase):
1061 1061
1062 1062 def coerce(self, value):
1063 1063 return re.compile(value)
1064 1064
1065 1065 obj = CRegExpTrait()
1066 1066
1067 1067 _default_value = re.compile(r'')
1068 1068 _good_values = [r'\d+', re.compile(r'\d+')]
1069 1069 _bad_values = ['(', None, ()]
1070 1070
1071 1071 class DictTrait(HasTraits):
1072 1072 value = Dict()
1073 1073
1074 1074 def test_dict_assignment():
1075 1075 d = dict()
1076 1076 c = DictTrait()
1077 1077 c.value = d
1078 1078 d['a'] = 5
1079 1079 nt.assert_equal(d, c.value)
1080 1080 nt.assert_true(c.value is d)
1081 1081
1082 1082 def test_dict_default_value():
1083 1083 """Check that the `{}` default value of the Dict traitlet constructor is
1084 1084 actually copied."""
1085 1085
1086 1086 d1, d2 = Dict(), Dict()
1087 1087 nt.assert_false(d1.get_default_value() is d2.get_default_value())
1088 1088
1089 1089
1090 1090 class TestValidationHook(TestCase):
1091 1091
1092 1092 def test_parity_trait(self):
1093 1093 """Verify that the early validation hook is effective"""
1094 1094
1095 1095 class Parity(HasTraits):
1096 1096
1097 1097 value = Int(0)
1098 1098 parity = Enum(['odd', 'even'], default_value='even', allow_none=False)
1099 1099
1100 1100 def _value_validate(self, value, trait):
1101 1101 if self.parity == 'even' and value % 2:
1102 1102 raise TraitError('Expected an even number')
1103 1103 if self.parity == 'odd' and (value % 2 == 0):
1104 1104 raise TraitError('Expected an odd number')
1105 1105 return value
1106 1106
1107 1107 u = Parity()
1108 1108 u.parity = 'odd'
1109 1109 u.value = 1 # OK
1110 1110 with self.assertRaises(TraitError):
1111 1111 u.value = 2 # Trait Error
1112 1112
1113 1113 u.parity = 'even'
1114 1114 u.value = 2 # OK
1115 1115
1116 1116
1117 1117 class TestLink(TestCase):
1118 1118
1119 1119 def test_connect_same(self):
1120 1120 """Verify two traitlets of the same type can be linked together using link."""
1121 1121
1122 1122 # Create two simple classes with Int traitlets.
1123 1123 class A(HasTraits):
1124 1124 value = Int()
1125 1125 a = A(value=9)
1126 1126 b = A(value=8)
1127 1127
1128 1128 # Conenct the two classes.
1129 1129 c = link((a, 'value'), (b, 'value'))
1130 1130
1131 1131 # Make sure the values are the same at the point of linking.
1132 1132 self.assertEqual(a.value, b.value)
1133 1133
1134 1134 # Change one of the values to make sure they stay in sync.
1135 1135 a.value = 5
1136 1136 self.assertEqual(a.value, b.value)
1137 1137 b.value = 6
1138 1138 self.assertEqual(a.value, b.value)
1139 1139
1140 1140 def test_link_different(self):
1141 1141 """Verify two traitlets of different types can be linked together using link."""
1142 1142
1143 1143 # Create two simple classes with Int traitlets.
1144 1144 class A(HasTraits):
1145 1145 value = Int()
1146 1146 class B(HasTraits):
1147 1147 count = Int()
1148 1148 a = A(value=9)
1149 1149 b = B(count=8)
1150 1150
1151 1151 # Conenct the two classes.
1152 1152 c = link((a, 'value'), (b, 'count'))
1153 1153
1154 1154 # Make sure the values are the same at the point of linking.
1155 1155 self.assertEqual(a.value, b.count)
1156 1156
1157 1157 # Change one of the values to make sure they stay in sync.
1158 1158 a.value = 5
1159 1159 self.assertEqual(a.value, b.count)
1160 1160 b.count = 4
1161 1161 self.assertEqual(a.value, b.count)
1162 1162
1163 1163 def test_unlink(self):
1164 1164 """Verify two linked traitlets can be unlinked."""
1165 1165
1166 1166 # Create two simple classes with Int traitlets.
1167 1167 class A(HasTraits):
1168 1168 value = Int()
1169 1169 a = A(value=9)
1170 1170 b = A(value=8)
1171 1171
1172 1172 # Connect the two classes.
1173 1173 c = link((a, 'value'), (b, 'value'))
1174 1174 a.value = 4
1175 1175 c.unlink()
1176 1176
1177 1177 # Change one of the values to make sure they don't stay in sync.
1178 1178 a.value = 5
1179 1179 self.assertNotEqual(a.value, b.value)
1180 1180
1181 1181 def test_callbacks(self):
1182 1182 """Verify two linked traitlets have their callbacks called once."""
1183 1183
1184 1184 # Create two simple classes with Int traitlets.
1185 1185 class A(HasTraits):
1186 1186 value = Int()
1187 1187 class B(HasTraits):
1188 1188 count = Int()
1189 1189 a = A(value=9)
1190 1190 b = B(count=8)
1191 1191
1192 1192 # Register callbacks that count.
1193 1193 callback_count = []
1194 1194 def a_callback(name, old, new):
1195 1195 callback_count.append('a')
1196 1196 a.on_trait_change(a_callback, 'value')
1197 1197 def b_callback(name, old, new):
1198 1198 callback_count.append('b')
1199 1199 b.on_trait_change(b_callback, 'count')
1200 1200
1201 1201 # Connect the two classes.
1202 1202 c = link((a, 'value'), (b, 'count'))
1203 1203
1204 1204 # Make sure b's count was set to a's value once.
1205 1205 self.assertEqual(''.join(callback_count), 'b')
1206 1206 del callback_count[:]
1207 1207
1208 1208 # Make sure a's value was set to b's count once.
1209 1209 b.count = 5
1210 1210 self.assertEqual(''.join(callback_count), 'ba')
1211 1211 del callback_count[:]
1212 1212
1213 1213 # Make sure b's count was set to a's value once.
1214 1214 a.value = 4
1215 1215 self.assertEqual(''.join(callback_count), 'ab')
1216 1216 del callback_count[:]
1217 1217
1218 1218 class TestDirectionalLink(TestCase):
1219 1219 def test_connect_same(self):
1220 1220 """Verify two traitlets of the same type can be linked together using directional_link."""
1221 1221
1222 1222 # Create two simple classes with Int traitlets.
1223 1223 class A(HasTraits):
1224 1224 value = Int()
1225 1225 a = A(value=9)
1226 1226 b = A(value=8)
1227 1227
1228 1228 # Conenct the two classes.
1229 1229 c = directional_link((a, 'value'), (b, 'value'))
1230 1230
1231 1231 # Make sure the values are the same at the point of linking.
1232 1232 self.assertEqual(a.value, b.value)
1233 1233
1234 1234 # Change one the value of the source and check that it synchronizes the target.
1235 1235 a.value = 5
1236 1236 self.assertEqual(b.value, 5)
1237 1237 # Change one the value of the target and check that it has no impact on the source
1238 1238 b.value = 6
1239 1239 self.assertEqual(a.value, 5)
1240 1240
1241 1241 def test_link_different(self):
1242 1242 """Verify two traitlets of different types can be linked together using link."""
1243 1243
1244 1244 # Create two simple classes with Int traitlets.
1245 1245 class A(HasTraits):
1246 1246 value = Int()
1247 1247 class B(HasTraits):
1248 1248 count = Int()
1249 1249 a = A(value=9)
1250 1250 b = B(count=8)
1251 1251
1252 1252 # Conenct the two classes.
1253 1253 c = directional_link((a, 'value'), (b, 'count'))
1254 1254
1255 1255 # Make sure the values are the same at the point of linking.
1256 1256 self.assertEqual(a.value, b.count)
1257 1257
1258 1258 # Change one the value of the source and check that it synchronizes the target.
1259 1259 a.value = 5
1260 1260 self.assertEqual(b.count, 5)
1261 1261 # Change one the value of the target and check that it has no impact on the source
1262 1262 b.value = 6
1263 1263 self.assertEqual(a.value, 5)
1264 1264
1265 1265 def test_unlink(self):
1266 1266 """Verify two linked traitlets can be unlinked."""
1267 1267
1268 1268 # Create two simple classes with Int traitlets.
1269 1269 class A(HasTraits):
1270 1270 value = Int()
1271 1271 a = A(value=9)
1272 1272 b = A(value=8)
1273 1273
1274 1274 # Connect the two classes.
1275 1275 c = directional_link((a, 'value'), (b, 'value'))
1276 1276 a.value = 4
1277 1277 c.unlink()
1278 1278
1279 1279 # Change one of the values to make sure they don't stay in sync.
1280 1280 a.value = 5
1281 1281 self.assertNotEqual(a.value, b.value)
1282 1282
1283 1283 class Pickleable(HasTraits):
1284 1284 i = Int()
1285 1285 j = Int()
1286 1286
1287 1287 def _i_default(self):
1288 1288 return 1
1289 1289
1290 1290 def _i_changed(self, name, old, new):
1291 1291 self.j = new
1292 1292
1293 1293 def test_pickle_hastraits():
1294 1294 c = Pickleable()
1295 1295 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1296 1296 p = pickle.dumps(c, protocol)
1297 1297 c2 = pickle.loads(p)
1298 1298 nt.assert_equal(c2.i, c.i)
1299 1299 nt.assert_equal(c2.j, c.j)
1300 1300
1301 1301 c.i = 5
1302 1302 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1303 1303 p = pickle.dumps(c, protocol)
1304 1304 c2 = pickle.loads(p)
1305 1305 nt.assert_equal(c2.i, c.i)
1306 1306 nt.assert_equal(c2.j, c.j)
1307 1307
1308 1308 class TestEventful(TestCase):
1309 1309
1310 1310 def test_list(self):
1311 1311 """Does the EventfulList work?"""
1312 1312 event_cache = []
1313 1313
1314 1314 class A(HasTraits):
1315 1315 x = EventfulList([c for c in 'abc'])
1316 1316 a = A()
1317 1317 a.x.on_events(lambda i, x: event_cache.append('insert'), \
1318 1318 lambda i, x: event_cache.append('set'), \
1319 1319 lambda i: event_cache.append('del'), \
1320 1320 lambda: event_cache.append('reverse'), \
1321 1321 lambda *p, **k: event_cache.append('sort'))
1322 1322
1323 1323 a.x.remove('c')
1324 1324 # ab
1325 1325 a.x.insert(0, 'z')
1326 1326 # zab
1327 1327 del a.x[1]
1328 1328 # zb
1329 1329 a.x.reverse()
1330 1330 # bz
1331 1331 a.x[1] = 'o'
1332 1332 # bo
1333 1333 a.x.append('a')
1334 1334 # boa
1335 1335 a.x.sort()
1336 1336 # abo
1337 1337
1338 1338 # Were the correct events captured?
1339 1339 self.assertEqual(event_cache, ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort'])
1340 1340
1341 1341 # Is the output correct?
1342 1342 self.assertEqual(a.x, [c for c in 'abo'])
1343 1343
1344 1344 def test_dict(self):
1345 1345 """Does the EventfulDict work?"""
1346 1346 event_cache = []
1347 1347
1348 1348 class A(HasTraits):
1349 1349 x = EventfulDict({c: c for c in 'abc'})
1350 1350 a = A()
1351 1351 a.x.on_events(lambda k, v: event_cache.append('add'), \
1352 1352 lambda k, v: event_cache.append('set'), \
1353 1353 lambda k: event_cache.append('del'))
1354 1354
1355 1355 del a.x['c']
1356 1356 # ab
1357 1357 a.x['z'] = 1
1358 1358 # abz
1359 1359 a.x['z'] = 'z'
1360 1360 # abz
1361 1361 a.x.pop('a')
1362 1362 # bz
1363 1363
1364 1364 # Were the correct events captured?
1365 1365 self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])
1366 1366
1367 1367 # Is the output correct?
1368 1368 self.assertEqual(a.x, {c: c for c in 'bz'})
1369 1369
1370 1370 ###
1371 1371 # Traits for Forward Declaration Tests
1372 1372 ###
1373 1373 class ForwardDeclaredInstanceTrait(HasTraits):
1374 1374
1375 1375 value = ForwardDeclaredInstance('ForwardDeclaredBar')
1376 1376
1377 1377 class ForwardDeclaredTypeTrait(HasTraits):
1378 1378
1379 1379 value = ForwardDeclaredType('ForwardDeclaredBar')
1380 1380
1381 1381 class ForwardDeclaredInstanceListTrait(HasTraits):
1382 1382
1383 1383 value = List(ForwardDeclaredInstance('ForwardDeclaredBar'))
1384 1384
1385 1385 class ForwardDeclaredTypeListTrait(HasTraits):
1386 1386
1387 1387 value = List(ForwardDeclaredType('ForwardDeclaredBar'))
1388 1388 ###
1389 1389 # End Traits for Forward Declaration Tests
1390 1390 ###
1391 1391
1392 1392 ###
1393 1393 # Classes for Forward Declaration Tests
1394 1394 ###
1395 1395 class ForwardDeclaredBar(object):
1396 1396 pass
1397 1397
1398 1398 class ForwardDeclaredBarSub(ForwardDeclaredBar):
1399 1399 pass
1400 1400 ###
1401 1401 # End Classes for Forward Declaration Tests
1402 1402 ###
1403 1403
1404 1404 ###
1405 1405 # Forward Declaration Tests
1406 1406 ###
1407 1407 class TestForwardDeclaredInstanceTrait(TraitTestBase):
1408 1408
1409 1409 obj = ForwardDeclaredInstanceTrait()
1410 1410 _default_value = None
1411 1411 _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1412 1412 _bad_values = ['foo', 3, ForwardDeclaredBar, ForwardDeclaredBarSub]
1413 1413
1414 1414 class TestForwardDeclaredTypeTrait(TraitTestBase):
1415 1415
1416 1416 obj = ForwardDeclaredTypeTrait()
1417 1417 _default_value = None
1418 1418 _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub]
1419 1419 _bad_values = ['foo', 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()]
1420 1420
1421 1421 class TestForwardDeclaredInstanceList(TraitTestBase):
1422 1422
1423 1423 obj = ForwardDeclaredInstanceListTrait()
1424 1424
1425 1425 def test_klass(self):
1426 1426 """Test that the instance klass is properly assigned."""
1427 1427 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1428 1428
1429 1429 _default_value = []
1430 1430 _good_values = [
1431 1431 [ForwardDeclaredBar(), ForwardDeclaredBarSub(), None],
1432 1432 [None],
1433 1433 [],
1434 None,
1435 1434 ]
1436 1435 _bad_values = [
1437 1436 ForwardDeclaredBar(),
1438 1437 [ForwardDeclaredBar(), 3],
1439 1438 '1',
1440 1439 # Note that this is the type, not an instance.
1441 [ForwardDeclaredBar]
1440 [ForwardDeclaredBar],
1441 None,
1442 1442 ]
1443 1443
1444 1444 class TestForwardDeclaredTypeList(TraitTestBase):
1445 1445
1446 1446 obj = ForwardDeclaredTypeListTrait()
1447 1447
1448 1448 def test_klass(self):
1449 1449 """Test that the instance klass is properly assigned."""
1450 1450 self.assertIs(self.obj.traits()['value']._trait.klass, ForwardDeclaredBar)
1451 1451
1452 1452 _default_value = []
1453 1453 _good_values = [
1454 1454 [ForwardDeclaredBar, ForwardDeclaredBarSub, None],
1455 1455 [],
1456 1456 [None],
1457 None,
1458 1457 ]
1459 1458 _bad_values = [
1460 1459 ForwardDeclaredBar,
1461 1460 [ForwardDeclaredBar, 3],
1462 1461 '1',
1463 1462 # Note that this is an instance, not the type.
1464 [ForwardDeclaredBar()]
1463 [ForwardDeclaredBar()],
1464 None,
1465 1465 ]
1466 1466 ###
1467 1467 # End Forward Declaration Tests
1468 1468 ###
@@ -1,1734 +1,1732 b''
1 1 # encoding: utf-8
2 2 """
3 3 A lightweight Traits like module.
4 4
5 5 This is designed to provide a lightweight, simple, pure Python version of
6 6 many of the capabilities of enthought.traits. This includes:
7 7
8 8 * Validation
9 9 * Type specification with defaults
10 10 * Static and dynamic notification
11 11 * Basic predefined types
12 12 * An API that is similar to enthought.traits
13 13
14 14 We don't support:
15 15
16 16 * Delegation
17 17 * Automatic GUI generation
18 18 * A full set of trait types. Most importantly, we don't provide container
19 19 traits (list, dict, tuple) that can trigger notifications if their
20 20 contents change.
21 21 * API compatibility with enthought.traits
22 22
23 23 There are also some important difference in our design:
24 24
25 25 * enthought.traits does not validate default values. We do.
26 26
27 27 We choose to create this module because we need these capabilities, but
28 28 we need them to be pure Python so they work in all Python implementations,
29 29 including Jython and IronPython.
30 30
31 31 Inheritance diagram:
32 32
33 33 .. inheritance-diagram:: IPython.utils.traitlets
34 34 :parts: 3
35 35 """
36 36
37 37 # Copyright (c) IPython Development Team.
38 38 # Distributed under the terms of the Modified BSD License.
39 39 #
40 40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 41 # also under the terms of the Modified BSD License.
42 42
43 43 import contextlib
44 44 import inspect
45 45 import re
46 46 import sys
47 47 import types
48 48 from types import FunctionType
49 49 try:
50 50 from types import ClassType, InstanceType
51 51 ClassTypes = (ClassType, type)
52 52 except:
53 53 ClassTypes = (type,)
54 54
55 55 from .importstring import import_item
56 56 from IPython.utils import py3compat
57 57 from IPython.utils import eventful
58 58 from IPython.utils.py3compat import iteritems, string_types
59 59 from IPython.testing.skipdoctest import skip_doctest
60 60
61 61 SequenceTypes = (list, tuple, set, frozenset)
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Basic classes
65 65 #-----------------------------------------------------------------------------
66 66
67 67
68 68 class NoDefaultSpecified ( object ): pass
69 69 NoDefaultSpecified = NoDefaultSpecified()
70 70
71 71
72 72 class Undefined ( object ): pass
73 73 Undefined = Undefined()
74 74
75 75 class TraitError(Exception):
76 76 pass
77 77
78 78 #-----------------------------------------------------------------------------
79 79 # Utilities
80 80 #-----------------------------------------------------------------------------
81 81
82 82
83 83 def class_of ( object ):
84 84 """ Returns a string containing the class name of an object with the
85 85 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
86 86 'a PlotValue').
87 87 """
88 88 if isinstance( object, py3compat.string_types ):
89 89 return add_article( object )
90 90
91 91 return add_article( object.__class__.__name__ )
92 92
93 93
94 94 def add_article ( name ):
95 95 """ Returns a string containing the correct indefinite article ('a' or 'an')
96 96 prefixed to the specified string.
97 97 """
98 98 if name[:1].lower() in 'aeiou':
99 99 return 'an ' + name
100 100
101 101 return 'a ' + name
102 102
103 103
104 104 def repr_type(obj):
105 105 """ Return a string representation of a value and its type for readable
106 106 error messages.
107 107 """
108 108 the_type = type(obj)
109 109 if (not py3compat.PY3) and the_type is InstanceType:
110 110 # Old-style class.
111 111 the_type = obj.__class__
112 112 msg = '%r %r' % (obj, the_type)
113 113 return msg
114 114
115 115
116 116 def is_trait(t):
117 117 """ Returns whether the given value is an instance or subclass of TraitType.
118 118 """
119 119 return (isinstance(t, TraitType) or
120 120 (isinstance(t, type) and issubclass(t, TraitType)))
121 121
122 122
123 123 def parse_notifier_name(name):
124 124 """Convert the name argument to a list of names.
125 125
126 126 Examples
127 127 --------
128 128
129 129 >>> parse_notifier_name('a')
130 130 ['a']
131 131 >>> parse_notifier_name(['a','b'])
132 132 ['a', 'b']
133 133 >>> parse_notifier_name(None)
134 134 ['anytrait']
135 135 """
136 136 if isinstance(name, string_types):
137 137 return [name]
138 138 elif name is None:
139 139 return ['anytrait']
140 140 elif isinstance(name, (list, tuple)):
141 141 for n in name:
142 142 assert isinstance(n, string_types), "names must be strings"
143 143 return name
144 144
145 145
146 146 class _SimpleTest:
147 147 def __init__ ( self, value ): self.value = value
148 148 def __call__ ( self, test ):
149 149 return test == self.value
150 150 def __repr__(self):
151 151 return "<SimpleTest(%r)" % self.value
152 152 def __str__(self):
153 153 return self.__repr__()
154 154
155 155
156 156 def getmembers(object, predicate=None):
157 157 """A safe version of inspect.getmembers that handles missing attributes.
158 158
159 159 This is useful when there are descriptor based attributes that for
160 160 some reason raise AttributeError even though they exist. This happens
161 161 in zope.inteface with the __provides__ attribute.
162 162 """
163 163 results = []
164 164 for key in dir(object):
165 165 try:
166 166 value = getattr(object, key)
167 167 except AttributeError:
168 168 pass
169 169 else:
170 170 if not predicate or predicate(value):
171 171 results.append((key, value))
172 172 results.sort()
173 173 return results
174 174
175 175 def _validate_link(*tuples):
176 176 """Validate arguments for traitlet link functions"""
177 177 for t in tuples:
178 178 if not len(t) == 2:
179 179 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
180 180 obj, trait_name = t
181 181 if not isinstance(obj, HasTraits):
182 182 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
183 183 if not trait_name in obj.traits():
184 184 raise TypeError("%r has no trait %r" % (obj, trait_name))
185 185
186 186 @skip_doctest
187 187 class link(object):
188 188 """Link traits from different objects together so they remain in sync.
189 189
190 190 Parameters
191 191 ----------
192 192 *args : pairs of objects/attributes
193 193
194 194 Examples
195 195 --------
196 196
197 197 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
198 198 >>> obj1.value = 5 # updates other objects as well
199 199 """
200 200 updating = False
201 201 def __init__(self, *args):
202 202 if len(args) < 2:
203 203 raise TypeError('At least two traitlets must be provided.')
204 204 _validate_link(*args)
205 205
206 206 self.objects = {}
207 207
208 208 initial = getattr(args[0][0], args[0][1])
209 209 for obj, attr in args:
210 210 setattr(obj, attr, initial)
211 211
212 212 callback = self._make_closure(obj, attr)
213 213 obj.on_trait_change(callback, attr)
214 214 self.objects[(obj, attr)] = callback
215 215
216 216 @contextlib.contextmanager
217 217 def _busy_updating(self):
218 218 self.updating = True
219 219 try:
220 220 yield
221 221 finally:
222 222 self.updating = False
223 223
224 224 def _make_closure(self, sending_obj, sending_attr):
225 225 def update(name, old, new):
226 226 self._update(sending_obj, sending_attr, new)
227 227 return update
228 228
229 229 def _update(self, sending_obj, sending_attr, new):
230 230 if self.updating:
231 231 return
232 232 with self._busy_updating():
233 233 for obj, attr in self.objects.keys():
234 234 setattr(obj, attr, new)
235 235
236 236 def unlink(self):
237 237 for key, callback in self.objects.items():
238 238 (obj, attr) = key
239 239 obj.on_trait_change(callback, attr, remove=True)
240 240
241 241 @skip_doctest
242 242 class directional_link(object):
243 243 """Link the trait of a source object with traits of target objects.
244 244
245 245 Parameters
246 246 ----------
247 247 source : pair of object, name
248 248 targets : pairs of objects/attributes
249 249
250 250 Examples
251 251 --------
252 252
253 253 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
254 254 >>> src.value = 5 # updates target objects
255 255 >>> tgt1.value = 6 # does not update other objects
256 256 """
257 257 updating = False
258 258
259 259 def __init__(self, source, *targets):
260 260 if len(targets) < 1:
261 261 raise TypeError('At least two traitlets must be provided.')
262 262 _validate_link(source, *targets)
263 263 self.source = source
264 264 self.targets = targets
265 265
266 266 # Update current value
267 267 src_attr_value = getattr(source[0], source[1])
268 268 for obj, attr in targets:
269 269 setattr(obj, attr, src_attr_value)
270 270
271 271 # Wire
272 272 self.source[0].on_trait_change(self._update, self.source[1])
273 273
274 274 @contextlib.contextmanager
275 275 def _busy_updating(self):
276 276 self.updating = True
277 277 try:
278 278 yield
279 279 finally:
280 280 self.updating = False
281 281
282 282 def _update(self, name, old, new):
283 283 if self.updating:
284 284 return
285 285 with self._busy_updating():
286 286 for obj, attr in self.targets:
287 287 setattr(obj, attr, new)
288 288
289 289 def unlink(self):
290 290 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
291 291 self.source = None
292 292 self.targets = []
293 293
294 294 dlink = directional_link
295 295
296 296 #-----------------------------------------------------------------------------
297 297 # Base TraitType for all traits
298 298 #-----------------------------------------------------------------------------
299 299
300 300
301 301 class TraitType(object):
302 302 """A base class for all trait descriptors.
303 303
304 304 Notes
305 305 -----
306 306 Our implementation of traits is based on Python's descriptor
307 307 prototol. This class is the base class for all such descriptors. The
308 308 only magic we use is a custom metaclass for the main :class:`HasTraits`
309 309 class that does the following:
310 310
311 311 1. Sets the :attr:`name` attribute of every :class:`TraitType`
312 312 instance in the class dict to the name of the attribute.
313 313 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
314 314 instance in the class dict to the *class* that declared the trait.
315 315 This is used by the :class:`This` trait to allow subclasses to
316 316 accept superclasses for :class:`This` values.
317 317 """
318 318
319 319
320 320 metadata = {}
321 321 default_value = Undefined
322 322 allow_none = False
323 323 info_text = 'any value'
324 324
325 325 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
326 326 """Create a TraitType.
327 327 """
328 328 if default_value is not NoDefaultSpecified:
329 329 self.default_value = default_value
330 330 if allow_none is not None:
331 331 self.allow_none = allow_none
332 332
333 333 if len(metadata) > 0:
334 334 if len(self.metadata) > 0:
335 335 self._metadata = self.metadata.copy()
336 336 self._metadata.update(metadata)
337 337 else:
338 338 self._metadata = metadata
339 339 else:
340 340 self._metadata = self.metadata
341 341
342 342 self.init()
343 343
344 344 def init(self):
345 345 pass
346 346
347 347 def get_default_value(self):
348 348 """Create a new instance of the default value."""
349 349 return self.default_value
350 350
351 351 def instance_init(self, obj):
352 352 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
353 353
354 354 Some stages of initialization must be delayed until the parent
355 355 :class:`HasTraits` instance has been created. This method is
356 356 called in :meth:`HasTraits.__new__` after the instance has been
357 357 created.
358 358
359 359 This method trigger the creation and validation of default values
360 360 and also things like the resolution of str given class names in
361 361 :class:`Type` and :class`Instance`.
362 362
363 363 Parameters
364 364 ----------
365 365 obj : :class:`HasTraits` instance
366 366 The parent :class:`HasTraits` instance that has just been
367 367 created.
368 368 """
369 369 self.set_default_value(obj)
370 370
371 371 def set_default_value(self, obj):
372 372 """Set the default value on a per instance basis.
373 373
374 374 This method is called by :meth:`instance_init` to create and
375 375 validate the default value. The creation and validation of
376 376 default values must be delayed until the parent :class:`HasTraits`
377 377 class has been instantiated.
378 378 """
379 379 # Check for a deferred initializer defined in the same class as the
380 380 # trait declaration or above.
381 381 mro = type(obj).mro()
382 382 meth_name = '_%s_default' % self.name
383 383 for cls in mro[:mro.index(self.this_class)+1]:
384 384 if meth_name in cls.__dict__:
385 385 break
386 386 else:
387 387 # We didn't find one. Do static initialization.
388 388 dv = self.get_default_value()
389 389 newdv = self._validate(obj, dv)
390 390 obj._trait_values[self.name] = newdv
391 391 return
392 392 # Complete the dynamic initialization.
393 393 obj._trait_dyn_inits[self.name] = meth_name
394 394
395 395 def __get__(self, obj, cls=None):
396 396 """Get the value of the trait by self.name for the instance.
397 397
398 398 Default values are instantiated when :meth:`HasTraits.__new__`
399 399 is called. Thus by the time this method gets called either the
400 400 default value or a user defined value (they called :meth:`__set__`)
401 401 is in the :class:`HasTraits` instance.
402 402 """
403 403 if obj is None:
404 404 return self
405 405 else:
406 406 try:
407 407 value = obj._trait_values[self.name]
408 408 except KeyError:
409 409 # Check for a dynamic initializer.
410 410 if self.name in obj._trait_dyn_inits:
411 411 method = getattr(obj, obj._trait_dyn_inits[self.name])
412 412 value = method()
413 413 # FIXME: Do we really validate here?
414 414 value = self._validate(obj, value)
415 415 obj._trait_values[self.name] = value
416 416 return value
417 417 else:
418 418 raise TraitError('Unexpected error in TraitType: '
419 419 'both default value and dynamic initializer are '
420 420 'absent.')
421 421 except Exception:
422 422 # HasTraits should call set_default_value to populate
423 423 # this. So this should never be reached.
424 424 raise TraitError('Unexpected error in TraitType: '
425 425 'default value not set properly')
426 426 else:
427 427 return value
428 428
429 429 def __set__(self, obj, value):
430 430 new_value = self._validate(obj, value)
431 431 try:
432 432 old_value = obj._trait_values[self.name]
433 433 except KeyError:
434 434 old_value = None
435 435
436 436 obj._trait_values[self.name] = new_value
437 437 try:
438 438 silent = bool(old_value == new_value)
439 439 except:
440 440 # if there is an error in comparing, default to notify
441 441 silent = False
442 442 if silent is not True:
443 443 # we explicitly compare silent to True just in case the equality
444 444 # comparison above returns something other than True/False
445 445 obj._notify_trait(self.name, old_value, new_value)
446 446
447 447 def _validate(self, obj, value):
448 448 if value is None and self.allow_none:
449 449 return value
450 450 if hasattr(self, 'validate'):
451 451 value = self.validate(obj, value)
452 452 if hasattr(obj, '_%s_validate' % self.name):
453 453 value = getattr(obj, '_%s_validate' % self.name)(value, self)
454 454 return value
455 455
456 456 def __or__(self, other):
457 457 if isinstance(other, Union):
458 458 return Union([self] + other.trait_types)
459 459 else:
460 460 return Union([self, other])
461 461
462 462 def info(self):
463 463 return self.info_text
464 464
465 465 def error(self, obj, value):
466 466 if obj is not None:
467 467 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
468 468 % (self.name, class_of(obj),
469 469 self.info(), repr_type(value))
470 470 else:
471 471 e = "The '%s' trait must be %s, but a value of %r was specified." \
472 472 % (self.name, self.info(), repr_type(value))
473 473 raise TraitError(e)
474 474
475 475 def get_metadata(self, key, default=None):
476 476 return getattr(self, '_metadata', {}).get(key, default)
477 477
478 478 def set_metadata(self, key, value):
479 479 getattr(self, '_metadata', {})[key] = value
480 480
481 481
482 482 #-----------------------------------------------------------------------------
483 483 # The HasTraits implementation
484 484 #-----------------------------------------------------------------------------
485 485
486 486
487 487 class MetaHasTraits(type):
488 488 """A metaclass for HasTraits.
489 489
490 490 This metaclass makes sure that any TraitType class attributes are
491 491 instantiated and sets their name attribute.
492 492 """
493 493
494 494 def __new__(mcls, name, bases, classdict):
495 495 """Create the HasTraits class.
496 496
497 497 This instantiates all TraitTypes in the class dict and sets their
498 498 :attr:`name` attribute.
499 499 """
500 500 # print "MetaHasTraitlets (mcls, name): ", mcls, name
501 501 # print "MetaHasTraitlets (bases): ", bases
502 502 # print "MetaHasTraitlets (classdict): ", classdict
503 503 for k,v in iteritems(classdict):
504 504 if isinstance(v, TraitType):
505 505 v.name = k
506 506 elif inspect.isclass(v):
507 507 if issubclass(v, TraitType):
508 508 vinst = v()
509 509 vinst.name = k
510 510 classdict[k] = vinst
511 511 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
512 512
513 513 def __init__(cls, name, bases, classdict):
514 514 """Finish initializing the HasTraits class.
515 515
516 516 This sets the :attr:`this_class` attribute of each TraitType in the
517 517 class dict to the newly created class ``cls``.
518 518 """
519 519 for k, v in iteritems(classdict):
520 520 if isinstance(v, TraitType):
521 521 v.this_class = cls
522 522 super(MetaHasTraits, cls).__init__(name, bases, classdict)
523 523
524 524 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
525 525
526 526 def __new__(cls, *args, **kw):
527 527 # This is needed because object.__new__ only accepts
528 528 # the cls argument.
529 529 new_meth = super(HasTraits, cls).__new__
530 530 if new_meth is object.__new__:
531 531 inst = new_meth(cls)
532 532 else:
533 533 inst = new_meth(cls, **kw)
534 534 inst._trait_values = {}
535 535 inst._trait_notifiers = {}
536 536 inst._trait_dyn_inits = {}
537 537 # Here we tell all the TraitType instances to set their default
538 538 # values on the instance.
539 539 for key in dir(cls):
540 540 # Some descriptors raise AttributeError like zope.interface's
541 541 # __provides__ attributes even though they exist. This causes
542 542 # AttributeErrors even though they are listed in dir(cls).
543 543 try:
544 544 value = getattr(cls, key)
545 545 except AttributeError:
546 546 pass
547 547 else:
548 548 if isinstance(value, TraitType):
549 549 value.instance_init(inst)
550 550
551 551 return inst
552 552
553 553 def __init__(self, *args, **kw):
554 554 # Allow trait values to be set using keyword arguments.
555 555 # We need to use setattr for this to trigger validation and
556 556 # notifications.
557 557 for key, value in iteritems(kw):
558 558 setattr(self, key, value)
559 559
560 560 def _notify_trait(self, name, old_value, new_value):
561 561
562 562 # First dynamic ones
563 563 callables = []
564 564 callables.extend(self._trait_notifiers.get(name,[]))
565 565 callables.extend(self._trait_notifiers.get('anytrait',[]))
566 566
567 567 # Now static ones
568 568 try:
569 569 cb = getattr(self, '_%s_changed' % name)
570 570 except:
571 571 pass
572 572 else:
573 573 callables.append(cb)
574 574
575 575 # Call them all now
576 576 for c in callables:
577 577 # Traits catches and logs errors here. I allow them to raise
578 578 if callable(c):
579 579 argspec = inspect.getargspec(c)
580 580 nargs = len(argspec[0])
581 581 # Bound methods have an additional 'self' argument
582 582 # I don't know how to treat unbound methods, but they
583 583 # can't really be used for callbacks.
584 584 if isinstance(c, types.MethodType):
585 585 offset = -1
586 586 else:
587 587 offset = 0
588 588 if nargs + offset == 0:
589 589 c()
590 590 elif nargs + offset == 1:
591 591 c(name)
592 592 elif nargs + offset == 2:
593 593 c(name, new_value)
594 594 elif nargs + offset == 3:
595 595 c(name, old_value, new_value)
596 596 else:
597 597 raise TraitError('a trait changed callback '
598 598 'must have 0-3 arguments.')
599 599 else:
600 600 raise TraitError('a trait changed callback '
601 601 'must be callable.')
602 602
603 603
604 604 def _add_notifiers(self, handler, name):
605 605 if name not in self._trait_notifiers:
606 606 nlist = []
607 607 self._trait_notifiers[name] = nlist
608 608 else:
609 609 nlist = self._trait_notifiers[name]
610 610 if handler not in nlist:
611 611 nlist.append(handler)
612 612
613 613 def _remove_notifiers(self, handler, name):
614 614 if name in self._trait_notifiers:
615 615 nlist = self._trait_notifiers[name]
616 616 try:
617 617 index = nlist.index(handler)
618 618 except ValueError:
619 619 pass
620 620 else:
621 621 del nlist[index]
622 622
623 623 def on_trait_change(self, handler, name=None, remove=False):
624 624 """Setup a handler to be called when a trait changes.
625 625
626 626 This is used to setup dynamic notifications of trait changes.
627 627
628 628 Static handlers can be created by creating methods on a HasTraits
629 629 subclass with the naming convention '_[traitname]_changed'. Thus,
630 630 to create static handler for the trait 'a', create the method
631 631 _a_changed(self, name, old, new) (fewer arguments can be used, see
632 632 below).
633 633
634 634 Parameters
635 635 ----------
636 636 handler : callable
637 637 A callable that is called when a trait changes. Its
638 638 signature can be handler(), handler(name), handler(name, new)
639 639 or handler(name, old, new).
640 640 name : list, str, None
641 641 If None, the handler will apply to all traits. If a list
642 642 of str, handler will apply to all names in the list. If a
643 643 str, the handler will apply just to that name.
644 644 remove : bool
645 645 If False (the default), then install the handler. If True
646 646 then unintall it.
647 647 """
648 648 if remove:
649 649 names = parse_notifier_name(name)
650 650 for n in names:
651 651 self._remove_notifiers(handler, n)
652 652 else:
653 653 names = parse_notifier_name(name)
654 654 for n in names:
655 655 self._add_notifiers(handler, n)
656 656
657 657 @classmethod
658 658 def class_trait_names(cls, **metadata):
659 659 """Get a list of all the names of this class' traits.
660 660
661 661 This method is just like the :meth:`trait_names` method,
662 662 but is unbound.
663 663 """
664 664 return cls.class_traits(**metadata).keys()
665 665
666 666 @classmethod
667 667 def class_traits(cls, **metadata):
668 668 """Get a `dict` of all the traits of this class. The dictionary
669 669 is keyed on the name and the values are the TraitType objects.
670 670
671 671 This method is just like the :meth:`traits` method, but is unbound.
672 672
673 673 The TraitTypes returned don't know anything about the values
674 674 that the various HasTrait's instances are holding.
675 675
676 676 The metadata kwargs allow functions to be passed in which
677 677 filter traits based on metadata values. The functions should
678 678 take a single value as an argument and return a boolean. If
679 679 any function returns False, then the trait is not included in
680 680 the output. This does not allow for any simple way of
681 681 testing that a metadata name exists and has any
682 682 value because get_metadata returns None if a metadata key
683 683 doesn't exist.
684 684 """
685 685 traits = dict([memb for memb in getmembers(cls) if
686 686 isinstance(memb[1], TraitType)])
687 687
688 688 if len(metadata) == 0:
689 689 return traits
690 690
691 691 for meta_name, meta_eval in metadata.items():
692 692 if type(meta_eval) is not FunctionType:
693 693 metadata[meta_name] = _SimpleTest(meta_eval)
694 694
695 695 result = {}
696 696 for name, trait in traits.items():
697 697 for meta_name, meta_eval in metadata.items():
698 698 if not meta_eval(trait.get_metadata(meta_name)):
699 699 break
700 700 else:
701 701 result[name] = trait
702 702
703 703 return result
704 704
705 705 def trait_names(self, **metadata):
706 706 """Get a list of all the names of this class' traits."""
707 707 return self.traits(**metadata).keys()
708 708
709 709 def traits(self, **metadata):
710 710 """Get a `dict` of all the traits of this class. The dictionary
711 711 is keyed on the name and the values are the TraitType objects.
712 712
713 713 The TraitTypes returned don't know anything about the values
714 714 that the various HasTrait's instances are holding.
715 715
716 716 The metadata kwargs allow functions to be passed in which
717 717 filter traits based on metadata values. The functions should
718 718 take a single value as an argument and return a boolean. If
719 719 any function returns False, then the trait is not included in
720 720 the output. This does not allow for any simple way of
721 721 testing that a metadata name exists and has any
722 722 value because get_metadata returns None if a metadata key
723 723 doesn't exist.
724 724 """
725 725 traits = dict([memb for memb in getmembers(self.__class__) if
726 726 isinstance(memb[1], TraitType)])
727 727
728 728 if len(metadata) == 0:
729 729 return traits
730 730
731 731 for meta_name, meta_eval in metadata.items():
732 732 if type(meta_eval) is not FunctionType:
733 733 metadata[meta_name] = _SimpleTest(meta_eval)
734 734
735 735 result = {}
736 736 for name, trait in traits.items():
737 737 for meta_name, meta_eval in metadata.items():
738 738 if not meta_eval(trait.get_metadata(meta_name)):
739 739 break
740 740 else:
741 741 result[name] = trait
742 742
743 743 return result
744 744
745 745 def trait_metadata(self, traitname, key, default=None):
746 746 """Get metadata values for trait by key."""
747 747 try:
748 748 trait = getattr(self.__class__, traitname)
749 749 except AttributeError:
750 750 raise TraitError("Class %s does not have a trait named %s" %
751 751 (self.__class__.__name__, traitname))
752 752 else:
753 753 return trait.get_metadata(key, default)
754 754
755 755 #-----------------------------------------------------------------------------
756 756 # Actual TraitTypes implementations/subclasses
757 757 #-----------------------------------------------------------------------------
758 758
759 759 #-----------------------------------------------------------------------------
760 760 # TraitTypes subclasses for handling classes and instances of classes
761 761 #-----------------------------------------------------------------------------
762 762
763 763
764 764 class ClassBasedTraitType(TraitType):
765 765 """
766 766 A trait with error reporting and string -> type resolution for Type,
767 767 Instance and This.
768 768 """
769 769
770 770 def _resolve_string(self, string):
771 771 """
772 772 Resolve a string supplied for a type into an actual object.
773 773 """
774 774 return import_item(string)
775 775
776 776 def error(self, obj, value):
777 777 kind = type(value)
778 778 if (not py3compat.PY3) and kind is InstanceType:
779 779 msg = 'class %s' % value.__class__.__name__
780 780 else:
781 781 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
782 782
783 783 if obj is not None:
784 784 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
785 785 % (self.name, class_of(obj),
786 786 self.info(), msg)
787 787 else:
788 788 e = "The '%s' trait must be %s, but a value of %r was specified." \
789 789 % (self.name, self.info(), msg)
790 790
791 791 raise TraitError(e)
792 792
793 793
794 794 class Type(ClassBasedTraitType):
795 795 """A trait whose value must be a subclass of a specified class."""
796 796
797 797 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
798 798 """Construct a Type trait
799 799
800 800 A Type trait specifies that its values must be subclasses of
801 801 a particular class.
802 802
803 803 If only ``default_value`` is given, it is used for the ``klass`` as
804 804 well.
805 805
806 806 Parameters
807 807 ----------
808 808 default_value : class, str or None
809 809 The default value must be a subclass of klass. If an str,
810 810 the str must be a fully specified class name, like 'foo.bar.Bah'.
811 811 The string is resolved into real class, when the parent
812 812 :class:`HasTraits` class is instantiated.
813 813 klass : class, str, None
814 814 Values of this trait must be a subclass of klass. The klass
815 815 may be specified in a string like: 'foo.bar.MyClass'.
816 816 The string is resolved into real class, when the parent
817 817 :class:`HasTraits` class is instantiated.
818 allow_none : boolean
818 allow_none : bool [ default True ]
819 819 Indicates whether None is allowed as an assignable value. Even if
820 820 ``False``, the default value may be ``None``.
821 821 """
822 822 if default_value is None:
823 823 if klass is None:
824 824 klass = object
825 825 elif klass is None:
826 826 klass = default_value
827 827
828 828 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
829 829 raise TraitError("A Type trait must specify a class.")
830 830
831 831 self.klass = klass
832 832
833 833 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
834 834
835 835 def validate(self, obj, value):
836 836 """Validates that the value is a valid object instance."""
837 837 if isinstance(value, py3compat.string_types):
838 838 try:
839 839 value = self._resolve_string(value)
840 840 except ImportError:
841 841 raise TraitError("The '%s' trait of %s instance must be a type, but "
842 842 "%r could not be imported" % (self.name, obj, value))
843 843 try:
844 844 if issubclass(value, self.klass):
845 845 return value
846 846 except:
847 847 pass
848 848
849 849 self.error(obj, value)
850 850
851 851 def info(self):
852 852 """ Returns a description of the trait."""
853 853 if isinstance(self.klass, py3compat.string_types):
854 854 klass = self.klass
855 855 else:
856 856 klass = self.klass.__name__
857 857 result = 'a subclass of ' + klass
858 858 if self.allow_none:
859 859 return result + ' or None'
860 860 return result
861 861
862 862 def instance_init(self, obj):
863 863 self._resolve_classes()
864 864 super(Type, self).instance_init(obj)
865 865
866 866 def _resolve_classes(self):
867 867 if isinstance(self.klass, py3compat.string_types):
868 868 self.klass = self._resolve_string(self.klass)
869 869 if isinstance(self.default_value, py3compat.string_types):
870 870 self.default_value = self._resolve_string(self.default_value)
871 871
872 872 def get_default_value(self):
873 873 return self.default_value
874 874
875 875
876 876 class DefaultValueGenerator(object):
877 877 """A class for generating new default value instances."""
878 878
879 879 def __init__(self, *args, **kw):
880 880 self.args = args
881 881 self.kw = kw
882 882
883 883 def generate(self, klass):
884 884 return klass(*self.args, **self.kw)
885 885
886 886
887 887 class Instance(ClassBasedTraitType):
888 888 """A trait whose value must be an instance of a specified class.
889 889
890 890 The value can also be an instance of a subclass of the specified class.
891 891
892 892 Subclasses can declare default classes by overriding the klass attribute
893 893 """
894 894
895 895 klass = None
896 896
897 897 def __init__(self, klass=None, args=None, kw=None,
898 898 allow_none=True, **metadata ):
899 899 """Construct an Instance trait.
900 900
901 901 This trait allows values that are instances of a particular
902 902 class or its subclasses. Our implementation is quite different
903 903 from that of enthough.traits as we don't allow instances to be used
904 904 for klass and we handle the ``args`` and ``kw`` arguments differently.
905 905
906 906 Parameters
907 907 ----------
908 908 klass : class, str
909 909 The class that forms the basis for the trait. Class names
910 910 can also be specified as strings, like 'foo.bar.Bar'.
911 911 args : tuple
912 912 Positional arguments for generating the default value.
913 913 kw : dict
914 914 Keyword arguments for generating the default value.
915 allow_none : bool
915 allow_none : bool [default True]
916 916 Indicates whether None is allowed as a value.
917 917
918 918 Notes
919 919 -----
920 920 If both ``args`` and ``kw`` are None, then the default value is None.
921 921 If ``args`` is a tuple and ``kw`` is a dict, then the default is
922 922 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
923 923 None, the None is replaced by ``()`` or ``{}``, respectively.
924 924 """
925 925 if klass is None:
926 926 klass = self.klass
927 927
928 928 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
929 929 self.klass = klass
930 930 else:
931 931 raise TraitError('The klass attribute must be a class'
932 932 ' not: %r' % klass)
933 933
934 934 # self.klass is a class, so handle default_value
935 935 if args is None and kw is None:
936 936 default_value = None
937 937 else:
938 938 if args is None:
939 939 # kw is not None
940 940 args = ()
941 941 elif kw is None:
942 942 # args is not None
943 943 kw = {}
944 944
945 945 if not isinstance(kw, dict):
946 946 raise TraitError("The 'kw' argument must be a dict or None.")
947 947 if not isinstance(args, tuple):
948 948 raise TraitError("The 'args' argument must be a tuple or None.")
949 949
950 950 default_value = DefaultValueGenerator(*args, **kw)
951 951
952 952 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
953 953
954 954 def validate(self, obj, value):
955 955 if isinstance(value, self.klass):
956 956 return value
957 957 else:
958 958 self.error(obj, value)
959 959
960 960 def info(self):
961 961 if isinstance(self.klass, py3compat.string_types):
962 962 klass = self.klass
963 963 else:
964 964 klass = self.klass.__name__
965 965 result = class_of(klass)
966 966 if self.allow_none:
967 967 return result + ' or None'
968 968
969 969 return result
970 970
971 971 def instance_init(self, obj):
972 972 self._resolve_classes()
973 973 super(Instance, self).instance_init(obj)
974 974
975 975 def _resolve_classes(self):
976 976 if isinstance(self.klass, py3compat.string_types):
977 977 self.klass = self._resolve_string(self.klass)
978 978
979 979 def get_default_value(self):
980 980 """Instantiate a default value instance.
981 981
982 982 This is called when the containing HasTraits classes'
983 983 :meth:`__new__` method is called to ensure that a unique instance
984 984 is created for each HasTraits instance.
985 985 """
986 986 dv = self.default_value
987 987 if isinstance(dv, DefaultValueGenerator):
988 988 return dv.generate(self.klass)
989 989 else:
990 990 return dv
991 991
992 992
993 993 class ForwardDeclaredMixin(object):
994 994 """
995 995 Mixin for forward-declared versions of Instance and Type.
996 996 """
997 997 def _resolve_string(self, string):
998 998 """
999 999 Find the specified class name by looking for it in the module in which
1000 1000 our this_class attribute was defined.
1001 1001 """
1002 1002 modname = self.this_class.__module__
1003 1003 return import_item('.'.join([modname, string]))
1004 1004
1005 1005
1006 1006 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1007 1007 """
1008 1008 Forward-declared version of Type.
1009 1009 """
1010 1010 pass
1011 1011
1012 1012
1013 1013 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1014 1014 """
1015 1015 Forward-declared version of Instance.
1016 1016 """
1017 1017 pass
1018 1018
1019 1019
1020 1020 class This(ClassBasedTraitType):
1021 1021 """A trait for instances of the class containing this trait.
1022 1022
1023 1023 Because how how and when class bodies are executed, the ``This``
1024 1024 trait can only have a default value of None. This, and because we
1025 1025 always validate default values, ``allow_none`` is *always* true.
1026 1026 """
1027 1027
1028 1028 info_text = 'an instance of the same type as the receiver or None'
1029 1029
1030 1030 def __init__(self, **metadata):
1031 1031 super(This, self).__init__(None, **metadata)
1032 1032
1033 1033 def validate(self, obj, value):
1034 1034 # What if value is a superclass of obj.__class__? This is
1035 1035 # complicated if it was the superclass that defined the This
1036 1036 # trait.
1037 1037 if isinstance(value, self.this_class) or (value is None):
1038 1038 return value
1039 1039 else:
1040 1040 self.error(obj, value)
1041 1041
1042 1042
1043 1043 class Union(TraitType):
1044 1044 """A trait type representing a Union type."""
1045 1045
1046 1046 def __init__(self, trait_types, **metadata):
1047 1047 """Construct a Union trait.
1048 1048
1049 1049 This trait allows values that are allowed by at least one of the
1050 1050 specified trait types. A Union traitlet cannot have metadata on
1051 1051 its own, besides the metadata of the listed types.
1052 1052
1053 1053 Parameters
1054 1054 ----------
1055 1055 trait_types: sequence
1056 1056 The list of trait types of length at least 1.
1057 1057
1058 1058 Notes
1059 1059 -----
1060 1060 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1061 1061 with the validation function of Float, then Bool, and finally Int.
1062 1062 """
1063 1063 self.trait_types = trait_types
1064 1064 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1065 1065 self.default_value = self.trait_types[0].get_default_value()
1066 1066 super(Union, self).__init__(**metadata)
1067 1067
1068 1068 def _resolve_classes(self):
1069 1069 for trait_type in self.trait_types:
1070 1070 trait_type.name = self.name
1071 1071 trait_type.this_class = self.this_class
1072 1072 if hasattr(trait_type, '_resolve_classes'):
1073 1073 trait_type._resolve_classes()
1074 1074
1075 1075 def instance_init(self, obj):
1076 1076 self._resolve_classes()
1077 1077 super(Union, self).instance_init(obj)
1078 1078
1079 1079 def validate(self, obj, value):
1080 1080 for trait_type in self.trait_types:
1081 1081 try:
1082 1082 v = trait_type._validate(obj, value)
1083 1083 self._metadata = trait_type._metadata
1084 1084 return v
1085 1085 except TraitError:
1086 1086 continue
1087 1087 self.error(obj, value)
1088 1088
1089 1089 def __or__(self, other):
1090 1090 if isinstance(other, Union):
1091 1091 return Union(self.trait_types + other.trait_types)
1092 1092 else:
1093 1093 return Union(self.trait_types + [other])
1094 1094
1095 1095 #-----------------------------------------------------------------------------
1096 1096 # Basic TraitTypes implementations/subclasses
1097 1097 #-----------------------------------------------------------------------------
1098 1098
1099 1099
1100 1100 class Any(TraitType):
1101 1101 default_value = None
1102 1102 info_text = 'any value'
1103 1103
1104 1104
1105 1105 class Int(TraitType):
1106 1106 """An int trait."""
1107 1107
1108 1108 default_value = 0
1109 1109 info_text = 'an int'
1110 1110
1111 1111 def validate(self, obj, value):
1112 1112 if isinstance(value, int):
1113 1113 return value
1114 1114 self.error(obj, value)
1115 1115
1116 1116 class CInt(Int):
1117 1117 """A casting version of the int trait."""
1118 1118
1119 1119 def validate(self, obj, value):
1120 1120 try:
1121 1121 return int(value)
1122 1122 except:
1123 1123 self.error(obj, value)
1124 1124
1125 1125 if py3compat.PY3:
1126 1126 Long, CLong = Int, CInt
1127 1127 Integer = Int
1128 1128 else:
1129 1129 class Long(TraitType):
1130 1130 """A long integer trait."""
1131 1131
1132 1132 default_value = 0
1133 1133 info_text = 'a long'
1134 1134
1135 1135 def validate(self, obj, value):
1136 1136 if isinstance(value, long):
1137 1137 return value
1138 1138 if isinstance(value, int):
1139 1139 return long(value)
1140 1140 self.error(obj, value)
1141 1141
1142 1142
1143 1143 class CLong(Long):
1144 1144 """A casting version of the long integer trait."""
1145 1145
1146 1146 def validate(self, obj, value):
1147 1147 try:
1148 1148 return long(value)
1149 1149 except:
1150 1150 self.error(obj, value)
1151 1151
1152 1152 class Integer(TraitType):
1153 1153 """An integer trait.
1154 1154
1155 1155 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1156 1156
1157 1157 default_value = 0
1158 1158 info_text = 'an integer'
1159 1159
1160 1160 def validate(self, obj, value):
1161 1161 if isinstance(value, int):
1162 1162 return value
1163 1163 if isinstance(value, long):
1164 1164 # downcast longs that fit in int:
1165 1165 # note that int(n > sys.maxint) returns a long, so
1166 1166 # we don't need a condition on this cast
1167 1167 return int(value)
1168 1168 if sys.platform == "cli":
1169 1169 from System import Int64
1170 1170 if isinstance(value, Int64):
1171 1171 return int(value)
1172 1172 self.error(obj, value)
1173 1173
1174 1174
1175 1175 class Float(TraitType):
1176 1176 """A float trait."""
1177 1177
1178 1178 default_value = 0.0
1179 1179 info_text = 'a float'
1180 1180
1181 1181 def validate(self, obj, value):
1182 1182 if isinstance(value, float):
1183 1183 return value
1184 1184 if isinstance(value, int):
1185 1185 return float(value)
1186 1186 self.error(obj, value)
1187 1187
1188 1188
1189 1189 class CFloat(Float):
1190 1190 """A casting version of the float trait."""
1191 1191
1192 1192 def validate(self, obj, value):
1193 1193 try:
1194 1194 return float(value)
1195 1195 except:
1196 1196 self.error(obj, value)
1197 1197
1198 1198 class Complex(TraitType):
1199 1199 """A trait for complex numbers."""
1200 1200
1201 1201 default_value = 0.0 + 0.0j
1202 1202 info_text = 'a complex number'
1203 1203
1204 1204 def validate(self, obj, value):
1205 1205 if isinstance(value, complex):
1206 1206 return value
1207 1207 if isinstance(value, (float, int)):
1208 1208 return complex(value)
1209 1209 self.error(obj, value)
1210 1210
1211 1211
1212 1212 class CComplex(Complex):
1213 1213 """A casting version of the complex number trait."""
1214 1214
1215 1215 def validate (self, obj, value):
1216 1216 try:
1217 1217 return complex(value)
1218 1218 except:
1219 1219 self.error(obj, value)
1220 1220
1221 1221 # We should always be explicit about whether we're using bytes or unicode, both
1222 1222 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1223 1223 # we don't have a Str type.
1224 1224 class Bytes(TraitType):
1225 1225 """A trait for byte strings."""
1226 1226
1227 1227 default_value = b''
1228 1228 info_text = 'a bytes object'
1229 1229
1230 1230 def validate(self, obj, value):
1231 1231 if isinstance(value, bytes):
1232 1232 return value
1233 1233 self.error(obj, value)
1234 1234
1235 1235
1236 1236 class CBytes(Bytes):
1237 1237 """A casting version of the byte string trait."""
1238 1238
1239 1239 def validate(self, obj, value):
1240 1240 try:
1241 1241 return bytes(value)
1242 1242 except:
1243 1243 self.error(obj, value)
1244 1244
1245 1245
1246 1246 class Unicode(TraitType):
1247 1247 """A trait for unicode strings."""
1248 1248
1249 1249 default_value = u''
1250 1250 info_text = 'a unicode string'
1251 1251
1252 1252 def validate(self, obj, value):
1253 1253 if isinstance(value, py3compat.unicode_type):
1254 1254 return value
1255 1255 if isinstance(value, bytes):
1256 1256 try:
1257 1257 return value.decode('ascii', 'strict')
1258 1258 except UnicodeDecodeError:
1259 1259 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1260 1260 raise TraitError(msg.format(value, self.name, class_of(obj)))
1261 1261 self.error(obj, value)
1262 1262
1263 1263
1264 1264 class CUnicode(Unicode):
1265 1265 """A casting version of the unicode trait."""
1266 1266
1267 1267 def validate(self, obj, value):
1268 1268 try:
1269 1269 return py3compat.unicode_type(value)
1270 1270 except:
1271 1271 self.error(obj, value)
1272 1272
1273 1273
1274 1274 class ObjectName(TraitType):
1275 1275 """A string holding a valid object name in this version of Python.
1276 1276
1277 1277 This does not check that the name exists in any scope."""
1278 1278 info_text = "a valid object identifier in Python"
1279 1279
1280 1280 if py3compat.PY3:
1281 1281 # Python 3:
1282 1282 coerce_str = staticmethod(lambda _,s: s)
1283 1283
1284 1284 else:
1285 1285 # Python 2:
1286 1286 def coerce_str(self, obj, value):
1287 1287 "In Python 2, coerce ascii-only unicode to str"
1288 1288 if isinstance(value, unicode):
1289 1289 try:
1290 1290 return str(value)
1291 1291 except UnicodeEncodeError:
1292 1292 self.error(obj, value)
1293 1293 return value
1294 1294
1295 1295 def validate(self, obj, value):
1296 1296 value = self.coerce_str(obj, value)
1297 1297
1298 1298 if isinstance(value, string_types) and py3compat.isidentifier(value):
1299 1299 return value
1300 1300 self.error(obj, value)
1301 1301
1302 1302 class DottedObjectName(ObjectName):
1303 1303 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1304 1304 def validate(self, obj, value):
1305 1305 value = self.coerce_str(obj, value)
1306 1306
1307 1307 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1308 1308 return value
1309 1309 self.error(obj, value)
1310 1310
1311 1311
1312 1312 class Bool(TraitType):
1313 1313 """A boolean (True, False) trait."""
1314 1314
1315 1315 default_value = False
1316 1316 info_text = 'a boolean'
1317 1317
1318 1318 def validate(self, obj, value):
1319 1319 if isinstance(value, bool):
1320 1320 return value
1321 1321 self.error(obj, value)
1322 1322
1323 1323
1324 1324 class CBool(Bool):
1325 1325 """A casting version of the boolean trait."""
1326 1326
1327 1327 def validate(self, obj, value):
1328 1328 try:
1329 1329 return bool(value)
1330 1330 except:
1331 1331 self.error(obj, value)
1332 1332
1333 1333
1334 1334 class Enum(TraitType):
1335 1335 """An enum that whose value must be in a given sequence."""
1336 1336
1337 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1337 def __init__(self, values, default_value=None, **metadata):
1338 1338 self.values = values
1339 super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata)
1339 super(Enum, self).__init__(default_value, **metadata)
1340 1340
1341 1341 def validate(self, obj, value):
1342 1342 if value in self.values:
1343 1343 return value
1344 1344 self.error(obj, value)
1345 1345
1346 1346 def info(self):
1347 1347 """ Returns a description of the trait."""
1348 1348 result = 'any of ' + repr(self.values)
1349 1349 if self.allow_none:
1350 1350 return result + ' or None'
1351 1351 return result
1352 1352
1353 1353 class CaselessStrEnum(Enum):
1354 1354 """An enum of strings that are caseless in validate."""
1355 1355
1356 1356 def validate(self, obj, value):
1357 1357 if not isinstance(value, py3compat.string_types):
1358 1358 self.error(obj, value)
1359 1359
1360 1360 for v in self.values:
1361 1361 if v.lower() == value.lower():
1362 1362 return v
1363 1363 self.error(obj, value)
1364 1364
1365 1365 class Container(Instance):
1366 1366 """An instance of a container (list, set, etc.)
1367 1367
1368 1368 To be subclassed by overriding klass.
1369 1369 """
1370 1370 klass = None
1371 1371 _cast_types = ()
1372 1372 _valid_defaults = SequenceTypes
1373 1373 _trait = None
1374 1374
1375 def __init__(self, trait=None, default_value=None, allow_none=True,
1375 def __init__(self, trait=None, default_value=None, allow_none=False,
1376 1376 **metadata):
1377 1377 """Create a container trait type from a list, set, or tuple.
1378 1378
1379 1379 The default value is created by doing ``List(default_value)``,
1380 1380 which creates a copy of the ``default_value``.
1381 1381
1382 1382 ``trait`` can be specified, which restricts the type of elements
1383 1383 in the container to that TraitType.
1384 1384
1385 1385 If only one arg is given and it is not a Trait, it is taken as
1386 1386 ``default_value``:
1387 1387
1388 1388 ``c = List([1,2,3])``
1389 1389
1390 1390 Parameters
1391 1391 ----------
1392 1392
1393 1393 trait : TraitType [ optional ]
1394 1394 the type for restricting the contents of the Container. If unspecified,
1395 1395 types are not checked.
1396 1396
1397 1397 default_value : SequenceType [ optional ]
1398 1398 The default value for the Trait. Must be list/tuple/set, and
1399 1399 will be cast to the container type.
1400 1400
1401 allow_none : Bool [ default True ]
1401 allow_none : bool [ default False ]
1402 1402 Whether to allow the value to be None
1403 1403
1404 1404 **metadata : any
1405 1405 further keys for extensions to the Trait (e.g. config)
1406 1406
1407 1407 """
1408 1408 # allow List([values]):
1409 1409 if default_value is None and not is_trait(trait):
1410 1410 default_value = trait
1411 1411 trait = None
1412 1412
1413 1413 if default_value is None:
1414 1414 args = ()
1415 1415 elif isinstance(default_value, self._valid_defaults):
1416 1416 args = (default_value,)
1417 1417 else:
1418 1418 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1419 1419
1420 1420 if is_trait(trait):
1421 1421 self._trait = trait() if isinstance(trait, type) else trait
1422 1422 self._trait.name = 'element'
1423 1423 elif trait is not None:
1424 1424 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1425 1425
1426 1426 super(Container,self).__init__(klass=self.klass, args=args,
1427 1427 allow_none=allow_none, **metadata)
1428 1428
1429 1429 def element_error(self, obj, element, validator):
1430 1430 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1431 1431 % (self.name, class_of(obj), validator.info(), repr_type(element))
1432 1432 raise TraitError(e)
1433 1433
1434 1434 def validate(self, obj, value):
1435 1435 if isinstance(value, self._cast_types):
1436 1436 value = self.klass(value)
1437 1437 value = super(Container, self).validate(obj, value)
1438 1438 if value is None:
1439 1439 return value
1440 1440
1441 1441 value = self.validate_elements(obj, value)
1442 1442
1443 1443 return value
1444 1444
1445 1445 def validate_elements(self, obj, value):
1446 1446 validated = []
1447 1447 if self._trait is None or isinstance(self._trait, Any):
1448 1448 return value
1449 1449 for v in value:
1450 1450 try:
1451 1451 v = self._trait._validate(obj, v)
1452 1452 except TraitError:
1453 1453 self.element_error(obj, v, self._trait)
1454 1454 else:
1455 1455 validated.append(v)
1456 1456 return self.klass(validated)
1457 1457
1458 1458 def instance_init(self, obj):
1459 1459 if isinstance(self._trait, TraitType):
1460 1460 self._trait.this_class = self.this_class
1461 1461 if hasattr(self._trait, '_resolve_classes'):
1462 1462 self._trait._resolve_classes()
1463 1463 super(Container, self).instance_init(obj)
1464 1464
1465 1465
1466 1466 class List(Container):
1467 1467 """An instance of a Python list."""
1468 1468 klass = list
1469 1469 _cast_types = (tuple,)
1470 1470
1471 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize,
1472 allow_none=True, **metadata):
1471 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1473 1472 """Create a List trait type from a list, set, or tuple.
1474 1473
1475 1474 The default value is created by doing ``List(default_value)``,
1476 1475 which creates a copy of the ``default_value``.
1477 1476
1478 1477 ``trait`` can be specified, which restricts the type of elements
1479 1478 in the container to that TraitType.
1480 1479
1481 1480 If only one arg is given and it is not a Trait, it is taken as
1482 1481 ``default_value``:
1483 1482
1484 1483 ``c = List([1,2,3])``
1485 1484
1486 1485 Parameters
1487 1486 ----------
1488 1487
1489 1488 trait : TraitType [ optional ]
1490 1489 the type for restricting the contents of the Container. If unspecified,
1491 1490 types are not checked.
1492 1491
1493 1492 default_value : SequenceType [ optional ]
1494 1493 The default value for the Trait. Must be list/tuple/set, and
1495 1494 will be cast to the container type.
1496 1495
1497 1496 minlen : Int [ default 0 ]
1498 1497 The minimum length of the input list
1499 1498
1500 1499 maxlen : Int [ default sys.maxsize ]
1501 1500 The maximum length of the input list
1502 1501
1503 allow_none : Bool [ default True ]
1502 allow_none : bool [ default False ]
1504 1503 Whether to allow the value to be None
1505 1504
1506 1505 **metadata : any
1507 1506 further keys for extensions to the Trait (e.g. config)
1508 1507
1509 1508 """
1510 1509 self._minlen = minlen
1511 1510 self._maxlen = maxlen
1512 1511 super(List, self).__init__(trait=trait, default_value=default_value,
1513 allow_none=allow_none, **metadata)
1512 **metadata)
1514 1513
1515 1514 def length_error(self, obj, value):
1516 1515 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1517 1516 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1518 1517 raise TraitError(e)
1519 1518
1520 1519 def validate_elements(self, obj, value):
1521 1520 length = len(value)
1522 1521 if length < self._minlen or length > self._maxlen:
1523 1522 self.length_error(obj, value)
1524 1523
1525 1524 return super(List, self).validate_elements(obj, value)
1526 1525
1527 1526 def validate(self, obj, value):
1528 1527 value = super(List, self).validate(obj, value)
1529 1528
1530 1529 value = self.validate_elements(obj, value)
1531 1530
1532 1531 return value
1533 1532
1534 1533
1535 1534
1536 1535 class Set(List):
1537 1536 """An instance of a Python set."""
1538 1537 klass = set
1539 1538 _cast_types = (tuple, list)
1540 1539
1541 1540 class Tuple(Container):
1542 1541 """An instance of a Python tuple."""
1543 1542 klass = tuple
1544 1543 _cast_types = (list,)
1545 1544
1546 1545 def __init__(self, *traits, **metadata):
1547 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1546 """Tuple(*traits, default_value=None, **medatata)
1548 1547
1549 1548 Create a tuple from a list, set, or tuple.
1550 1549
1551 1550 Create a fixed-type tuple with Traits:
1552 1551
1553 1552 ``t = Tuple(Int, Str, CStr)``
1554 1553
1555 1554 would be length 3, with Int,Str,CStr for each element.
1556 1555
1557 1556 If only one arg is given and it is not a Trait, it is taken as
1558 1557 default_value:
1559 1558
1560 1559 ``t = Tuple((1,2,3))``
1561 1560
1562 1561 Otherwise, ``default_value`` *must* be specified by keyword.
1563 1562
1564 1563 Parameters
1565 1564 ----------
1566 1565
1567 1566 *traits : TraitTypes [ optional ]
1568 1567 the tsype for restricting the contents of the Tuple. If unspecified,
1569 1568 types are not checked. If specified, then each positional argument
1570 1569 corresponds to an element of the tuple. Tuples defined with traits
1571 1570 are of fixed length.
1572 1571
1573 1572 default_value : SequenceType [ optional ]
1574 1573 The default value for the Tuple. Must be list/tuple/set, and
1575 1574 will be cast to a tuple. If `traits` are specified, the
1576 1575 `default_value` must conform to the shape and type they specify.
1577 1576
1578 allow_none : Bool [ default True ]
1577 allow_none : bool [ default False ]
1579 1578 Whether to allow the value to be None
1580 1579
1581 1580 **metadata : any
1582 1581 further keys for extensions to the Trait (e.g. config)
1583 1582
1584 1583 """
1585 1584 default_value = metadata.pop('default_value', None)
1586 1585 allow_none = metadata.pop('allow_none', True)
1587 1586
1588 1587 # allow Tuple((values,)):
1589 1588 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1590 1589 default_value = traits[0]
1591 1590 traits = ()
1592 1591
1593 1592 if default_value is None:
1594 1593 args = ()
1595 1594 elif isinstance(default_value, self._valid_defaults):
1596 1595 args = (default_value,)
1597 1596 else:
1598 1597 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1599 1598
1600 1599 self._traits = []
1601 1600 for trait in traits:
1602 1601 t = trait() if isinstance(trait, type) else trait
1603 1602 t.name = 'element'
1604 1603 self._traits.append(t)
1605 1604
1606 1605 if self._traits and default_value is None:
1607 1606 # don't allow default to be an empty container if length is specified
1608 1607 args = None
1609 super(Container,self).__init__(klass=self.klass, args=args,
1610 allow_none=allow_none, **metadata)
1608 super(Container,self).__init__(klass=self.klass, args=args, **metadata)
1611 1609
1612 1610 def validate_elements(self, obj, value):
1613 1611 if not self._traits:
1614 1612 # nothing to validate
1615 1613 return value
1616 1614 if len(value) != len(self._traits):
1617 1615 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1618 1616 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1619 1617 raise TraitError(e)
1620 1618
1621 1619 validated = []
1622 1620 for t,v in zip(self._traits, value):
1623 1621 try:
1624 1622 v = t._validate(obj, v)
1625 1623 except TraitError:
1626 1624 self.element_error(obj, v, t)
1627 1625 else:
1628 1626 validated.append(v)
1629 1627 return tuple(validated)
1630 1628
1631 1629 def instance_init(self, obj):
1632 1630 for trait in self._traits:
1633 1631 if isinstance(trait, TraitType):
1634 1632 trait.this_class = self.this_class
1635 1633 if hasattr(trait, '_resolve_classes'):
1636 1634 trait._resolve_classes()
1637 1635 super(Container, self).instance_init(obj)
1638 1636
1639 1637
1640 1638 class Dict(Instance):
1641 1639 """An instance of a Python dict."""
1642 1640
1643 def __init__(self, default_value={}, allow_none=True, **metadata):
1641 def __init__(self, default_value={}, allow_none=False, **metadata):
1644 1642 """Create a dict trait type from a dict.
1645 1643
1646 1644 The default value is created by doing ``dict(default_value)``,
1647 1645 which creates a copy of the ``default_value``.
1648 1646 """
1649 1647 if default_value is None:
1650 1648 args = None
1651 1649 elif isinstance(default_value, dict):
1652 1650 args = (default_value,)
1653 1651 elif isinstance(default_value, SequenceTypes):
1654 1652 args = (default_value,)
1655 1653 else:
1656 1654 raise TypeError('default value of Dict was %s' % default_value)
1657 1655
1658 1656 super(Dict,self).__init__(klass=dict, args=args,
1659 1657 allow_none=allow_none, **metadata)
1660 1658
1661 1659
1662 1660 class EventfulDict(Instance):
1663 1661 """An instance of an EventfulDict."""
1664 1662
1665 def __init__(self, default_value={}, allow_none=True, **metadata):
1663 def __init__(self, default_value={}, allow_none=False, **metadata):
1666 1664 """Create a EventfulDict trait type from a dict.
1667 1665
1668 1666 The default value is created by doing
1669 1667 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1670 1668 ``default_value``.
1671 1669 """
1672 1670 if default_value is None:
1673 1671 args = None
1674 1672 elif isinstance(default_value, dict):
1675 1673 args = (default_value,)
1676 1674 elif isinstance(default_value, SequenceTypes):
1677 1675 args = (default_value,)
1678 1676 else:
1679 1677 raise TypeError('default value of EventfulDict was %s' % default_value)
1680 1678
1681 1679 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1682 1680 allow_none=allow_none, **metadata)
1683 1681
1684 1682
1685 1683 class EventfulList(Instance):
1686 1684 """An instance of an EventfulList."""
1687 1685
1688 def __init__(self, default_value=None, allow_none=True, **metadata):
1686 def __init__(self, default_value=None, allow_none=False, **metadata):
1689 1687 """Create a EventfulList trait type from a dict.
1690 1688
1691 1689 The default value is created by doing
1692 1690 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1693 1691 ``default_value``.
1694 1692 """
1695 1693 if default_value is None:
1696 1694 args = ((),)
1697 1695 else:
1698 1696 args = (default_value,)
1699 1697
1700 1698 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1701 1699 allow_none=allow_none, **metadata)
1702 1700
1703 1701
1704 1702 class TCPAddress(TraitType):
1705 1703 """A trait for an (ip, port) tuple.
1706 1704
1707 1705 This allows for both IPv4 IP addresses as well as hostnames.
1708 1706 """
1709 1707
1710 1708 default_value = ('127.0.0.1', 0)
1711 1709 info_text = 'an (ip, port) tuple'
1712 1710
1713 1711 def validate(self, obj, value):
1714 1712 if isinstance(value, tuple):
1715 1713 if len(value) == 2:
1716 1714 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1717 1715 port = value[1]
1718 1716 if port >= 0 and port <= 65535:
1719 1717 return value
1720 1718 self.error(obj, value)
1721 1719
1722 1720 class CRegExp(TraitType):
1723 1721 """A casting compiled regular expression trait.
1724 1722
1725 1723 Accepts both strings and compiled regular expressions. The resulting
1726 1724 attribute will be a compiled regular expression."""
1727 1725
1728 1726 info_text = 'a regular expression'
1729 1727
1730 1728 def validate(self, obj, value):
1731 1729 try:
1732 1730 return re.compile(value)
1733 1731 except:
1734 1732 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now