##// END OF EJS Templates
Various Python 3 fixes in IPython.utils
Thomas Kluyver -
Show More
@@ -1,161 +1,164 b''
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import re
15 15 import sys
16 16 import types
17 17 from datetime import datetime
18 18
19 from IPython.utils import py3compat
20 next_attr_name = '__next__' if py3compat.PY3 else 'next'
21
19 22 #-----------------------------------------------------------------------------
20 23 # Globals and constants
21 24 #-----------------------------------------------------------------------------
22 25
23 26 # timestamp formats
24 27 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
25 28 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
26 29
27 30 #-----------------------------------------------------------------------------
28 31 # Classes and functions
29 32 #-----------------------------------------------------------------------------
30 33
31 34 def rekey(dikt):
32 35 """Rekey a dict that has been forced to use str keys where there should be
33 36 ints by json."""
34 37 for k in dikt.iterkeys():
35 38 if isinstance(k, basestring):
36 39 ik=fk=None
37 40 try:
38 41 ik = int(k)
39 42 except ValueError:
40 43 try:
41 44 fk = float(k)
42 45 except ValueError:
43 46 continue
44 47 if ik is not None:
45 48 nk = ik
46 49 else:
47 50 nk = fk
48 51 if nk in dikt:
49 52 raise KeyError("already have key %r"%nk)
50 53 dikt[nk] = dikt.pop(k)
51 54 return dikt
52 55
53 56
54 57 def extract_dates(obj):
55 58 """extract ISO8601 dates from unpacked JSON"""
56 59 if isinstance(obj, dict):
57 60 obj = dict(obj) # don't clobber
58 61 for k,v in obj.iteritems():
59 62 obj[k] = extract_dates(v)
60 63 elif isinstance(obj, (list, tuple)):
61 64 obj = [ extract_dates(o) for o in obj ]
62 65 elif isinstance(obj, basestring):
63 66 if ISO8601_PAT.match(obj):
64 67 obj = datetime.strptime(obj, ISO8601)
65 68 return obj
66 69
67 70 def squash_dates(obj):
68 71 """squash datetime objects into ISO8601 strings"""
69 72 if isinstance(obj, dict):
70 73 obj = dict(obj) # don't clobber
71 74 for k,v in obj.iteritems():
72 75 obj[k] = squash_dates(v)
73 76 elif isinstance(obj, (list, tuple)):
74 77 obj = [ squash_dates(o) for o in obj ]
75 78 elif isinstance(obj, datetime):
76 79 obj = obj.strftime(ISO8601)
77 80 return obj
78 81
79 82 def date_default(obj):
80 83 """default function for packing datetime objects in JSON."""
81 84 if isinstance(obj, datetime):
82 85 return obj.strftime(ISO8601)
83 86 else:
84 87 raise TypeError("%r is not JSON serializable"%obj)
85 88
86 89
87 90
88 91 def json_clean(obj):
89 92 """Clean an object to ensure it's safe to encode in JSON.
90 93
91 94 Atomic, immutable objects are returned unmodified. Sets and tuples are
92 95 converted to lists, lists are copied and dicts are also copied.
93 96
94 97 Note: dicts whose keys could cause collisions upon encoding (such as a dict
95 98 with both the number 1 and the string '1' as keys) will cause a ValueError
96 99 to be raised.
97 100
98 101 Parameters
99 102 ----------
100 103 obj : any python object
101 104
102 105 Returns
103 106 -------
104 107 out : object
105 108
106 109 A version of the input which will not cause an encoding error when
107 110 encoded as JSON. Note that this function does not *encode* its inputs,
108 111 it simply sanitizes it so that there will be no encoding errors later.
109 112
110 113 Examples
111 114 --------
112 115 >>> json_clean(4)
113 116 4
114 117 >>> json_clean(range(10))
115 118 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
116 119 >>> json_clean(dict(x=1, y=2))
117 120 {'y': 2, 'x': 1}
118 121 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
119 122 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
120 123 >>> json_clean(True)
121 124 True
122 125 """
123 126 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
124 127 # listed explicitly because bools pass as int instances
125 128 atomic_ok = (unicode, int, float, types.NoneType)
126 129
127 130 # containers that we need to convert into lists
128 131 container_to_list = (tuple, set, types.GeneratorType)
129 132
130 133 if isinstance(obj, atomic_ok):
131 134 return obj
132 135
133 136 if isinstance(obj, bytes):
134 137 return obj.decode(sys.getdefaultencoding(), 'replace')
135 138
136 139 if isinstance(obj, container_to_list) or (
137 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
140 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
138 141 obj = list(obj)
139 142
140 143 if isinstance(obj, list):
141 144 return [json_clean(x) for x in obj]
142 145
143 146 if isinstance(obj, dict):
144 147 # First, validate that the dict won't lose data in conversion due to
145 148 # key collisions after stringification. This can happen with keys like
146 149 # True and 'true' or 1 and '1', which collide in JSON.
147 150 nkeys = len(obj)
148 151 nkeys_collapsed = len(set(map(str, obj)))
149 152 if nkeys != nkeys_collapsed:
150 153 raise ValueError('dict can not be safely converted to JSON: '
151 154 'key collision would lead to dropped values')
152 155 # If all OK, proceed by making the new dict that will be json-safe
153 156 out = {}
154 157 for k,v in obj.iteritems():
155 158 out[str(k)] = json_clean(v)
156 159 return out
157 160
158 161 # If we get here, we don't know how to handle the object, so we just get
159 162 # its repr and return that. This will catch lambdas, open sockets, class
160 163 # objects, and any other complicated contraption that json can't encode
161 164 return repr(obj)
@@ -1,147 +1,147 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with external processes.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2009 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 # Stdlib
19 19 import os
20 20 import sys
21 21 import shlex
22 22
23 23 # Our own
24 24 if sys.platform == 'win32':
25 25 from ._process_win32 import _find_cmd, system, getoutput, AvoidUNCPath
26 26 else:
27 27 from ._process_posix import _find_cmd, system, getoutput
28 28
29 29 from ._process_common import getoutputerror
30 30 from IPython.utils import py3compat
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Code
34 34 #-----------------------------------------------------------------------------
35 35
36 36
37 37 class FindCmdError(Exception):
38 38 pass
39 39
40 40
41 41 def find_cmd(cmd):
42 42 """Find absolute path to executable cmd in a cross platform manner.
43 43
44 44 This function tries to determine the full path to a command line program
45 45 using `which` on Unix/Linux/OS X and `win32api` on Windows. Most of the
46 46 time it will use the version that is first on the users `PATH`. If
47 47 cmd is `python` return `sys.executable`.
48 48
49 49 Warning, don't use this to find IPython command line programs as there
50 50 is a risk you will find the wrong one. Instead find those using the
51 51 following code and looking for the application itself::
52 52
53 53 from IPython.utils.path import get_ipython_module_path
54 54 from IPython.utils.process import pycmd2argv
55 55 argv = pycmd2argv(get_ipython_module_path('IPython.frontend.terminal.ipapp'))
56 56
57 57 Parameters
58 58 ----------
59 59 cmd : str
60 60 The command line program to look for.
61 61 """
62 62 if cmd == 'python':
63 63 return os.path.abspath(sys.executable)
64 64 try:
65 65 path = _find_cmd(cmd).rstrip()
66 66 except OSError:
67 67 raise FindCmdError('command could not be found: %s' % cmd)
68 68 # which returns empty if not found
69 if path == '':
69 if path == b'':
70 70 raise FindCmdError('command could not be found: %s' % cmd)
71 71 return os.path.abspath(path)
72 72
73 73
74 74 def pycmd2argv(cmd):
75 75 r"""Take the path of a python command and return a list (argv-style).
76 76
77 77 This only works on Python based command line programs and will find the
78 78 location of the ``python`` executable using ``sys.executable`` to make
79 79 sure the right version is used.
80 80
81 81 For a given path ``cmd``, this returns [cmd] if cmd's extension is .exe,
82 82 .com or .bat, and [, cmd] otherwise.
83 83
84 84 Parameters
85 85 ----------
86 86 cmd : string
87 87 The path of the command.
88 88
89 89 Returns
90 90 -------
91 91 argv-style list.
92 92 """
93 93 ext = os.path.splitext(cmd)[1]
94 94 if ext in ['.exe', '.com', '.bat']:
95 95 return [cmd]
96 96 else:
97 97 if sys.platform == 'win32':
98 98 # The -u option here turns on unbuffered output, which is required
99 99 # on Win32 to prevent wierd conflict and problems with Twisted.
100 100 # Also, use sys.executable to make sure we are picking up the
101 101 # right python exe.
102 102 return [sys.executable, '-u', cmd]
103 103 else:
104 104 return [sys.executable, cmd]
105 105
106 106
107 107 def arg_split(s, posix=False):
108 108 """Split a command line's arguments in a shell-like manner.
109 109
110 110 This is a modified version of the standard library's shlex.split()
111 111 function, but with a default of posix=False for splitting, so that quotes
112 112 in inputs are respected."""
113 113
114 114 # Unfortunately, python's shlex module is buggy with unicode input:
115 115 # http://bugs.python.org/issue1170
116 116 # At least encoding the input when it's unicode seems to help, but there
117 117 # may be more problems lurking. Apparently this is fixed in python3.
118 118 is_unicode = False
119 119 if (not py3compat.PY3) and isinstance(s, unicode):
120 120 is_unicode = True
121 121 s = s.encode('utf-8')
122 122 lex = shlex.shlex(s, posix=posix)
123 123 lex.whitespace_split = True
124 124 tokens = list(lex)
125 125 if is_unicode:
126 126 # Convert the tokens back to unicode.
127 127 tokens = [x.decode('utf-8') for x in tokens]
128 128 return tokens
129 129
130 130
131 131 def abbrev_cwd():
132 132 """ Return abbreviated version of cwd, e.g. d:mydir """
133 133 cwd = os.getcwdu().replace('\\','/')
134 134 drivepart = ''
135 135 tail = cwd
136 136 if sys.platform == 'win32':
137 137 if len(cwd) < 4:
138 138 return cwd
139 139 drivepart,tail = os.path.splitdrive(cwd)
140 140
141 141
142 142 parts = tail.split('/')
143 143 if len(parts) > 2:
144 144 tail = '/'.join(parts[-2:])
145 145
146 146 return (drivepart + (
147 147 cwd == '/' and '/' or tail))
@@ -1,443 +1,450 b''
1 1 # encoding: utf-8
2 2 """Tests for IPython.utils.path.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 from __future__ import with_statement
16 16
17 17 import os
18 18 import shutil
19 19 import sys
20 20 import tempfile
21 21 import StringIO
22 22
23 23 from os.path import join, abspath, split
24 24
25 25 import nose.tools as nt
26 26
27 27 from nose import with_setup
28 28
29 29 import IPython
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
32 32 from IPython.testing.tools import make_tempfile
33 33 from IPython.utils import path, io
34 from IPython.utils import py3compat
34 35
35 36 # Platform-dependent imports
36 37 try:
37 38 import _winreg as wreg
38 39 except ImportError:
39 40 #Fake _winreg module on none windows platforms
40 import new
41 sys.modules["_winreg"] = new.module("_winreg")
41 import types
42 wr_name = "winreg" if py3compat.PY3 else "_winreg"
43 sys.modules[wr_name] = types.ModuleType(wr_name)
42 44 import _winreg as wreg
43 45 #Add entries that needs to be stubbed by the testing code
44 46 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
47
48 try:
49 reload
50 except NameError: # Python 3
51 from imp import reload
45 52
46 53 #-----------------------------------------------------------------------------
47 54 # Globals
48 55 #-----------------------------------------------------------------------------
49 56 env = os.environ
50 57 TEST_FILE_PATH = split(abspath(__file__))[0]
51 58 TMP_TEST_DIR = tempfile.mkdtemp()
52 59 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
53 60 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
54 61 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
55 62 #
56 63 # Setup/teardown functions/decorators
57 64 #
58 65
59 66 def setup():
60 67 """Setup testenvironment for the module:
61 68
62 69 - Adds dummy home dir tree
63 70 """
64 71 # Do not mask exceptions here. In particular, catching WindowsError is a
65 72 # problem because that exception is only defined on Windows...
66 73 os.makedirs(IP_TEST_DIR)
67 74 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
68 75
69 76
70 77 def teardown():
71 78 """Teardown testenvironment for the module:
72 79
73 80 - Remove dummy home dir tree
74 81 """
75 82 # Note: we remove the parent test dir, which is the root of all test
76 83 # subdirs we may have created. Use shutil instead of os.removedirs, so
77 84 # that non-empty directories are all recursively removed.
78 85 shutil.rmtree(TMP_TEST_DIR)
79 86
80 87
81 88 def setup_environment():
82 89 """Setup testenvironment for some functions that are tested
83 90 in this module. In particular this functions stores attributes
84 91 and other things that we need to stub in some test functions.
85 92 This needs to be done on a function level and not module level because
86 93 each testfunction needs a pristine environment.
87 94 """
88 95 global oldstuff, platformstuff
89 96 oldstuff = (env.copy(), os.name, path.get_home_dir, IPython.__file__, os.getcwd())
90 97
91 98 if os.name == 'nt':
92 99 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
93 100
94 101
95 102 def teardown_environment():
96 103 """Restore things that were remebered by the setup_environment function
97 104 """
98 105 (oldenv, os.name, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
99 106 os.chdir(old_wd)
100 107 reload(path)
101 108
102 109 for key in env.keys():
103 110 if key not in oldenv:
104 111 del env[key]
105 112 env.update(oldenv)
106 113 if hasattr(sys, 'frozen'):
107 114 del sys.frozen
108 115 if os.name == 'nt':
109 116 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
110 117
111 118 # Build decorator that uses the setup_environment/setup_environment
112 119 with_environment = with_setup(setup_environment, teardown_environment)
113 120
114 121
115 122 @skip_if_not_win32
116 123 @with_environment
117 124 def test_get_home_dir_1():
118 125 """Testcase for py2exe logic, un-compressed lib
119 126 """
120 127 sys.frozen = True
121 128
122 129 #fake filename for IPython.__init__
123 130 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
124 131
125 132 home_dir = path.get_home_dir()
126 133 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
127 134
128 135
129 136 @skip_if_not_win32
130 137 @with_environment
131 138 def test_get_home_dir_2():
132 139 """Testcase for py2exe logic, compressed lib
133 140 """
134 141 sys.frozen = True
135 142 #fake filename for IPython.__init__
136 143 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
137 144
138 145 home_dir = path.get_home_dir()
139 146 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
140 147
141 148
142 149 @with_environment
143 150 @skip_win32
144 151 def test_get_home_dir_3():
145 152 """Testcase $HOME is set, then use its value as home directory."""
146 153 env["HOME"] = HOME_TEST_DIR
147 154 home_dir = path.get_home_dir()
148 155 nt.assert_equal(home_dir, env["HOME"])
149 156
150 157
151 158 @with_environment
152 159 @skip_win32
153 160 def test_get_home_dir_4():
154 161 """Testcase $HOME is not set, os=='posix'.
155 162 This should fail with HomeDirError"""
156 163
157 164 os.name = 'posix'
158 165 if 'HOME' in env: del env['HOME']
159 166 nt.assert_raises(path.HomeDirError, path.get_home_dir)
160 167
161 168
162 169 @skip_if_not_win32
163 170 @with_environment
164 171 def test_get_home_dir_5():
165 172 """Using HOMEDRIVE + HOMEPATH, os=='nt'.
166 173
167 174 HOMESHARE is missing.
168 175 """
169 176
170 177 os.name = 'nt'
171 178 env.pop('HOMESHARE', None)
172 179 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
173 180 home_dir = path.get_home_dir()
174 181 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
175 182
176 183
177 184 @skip_if_not_win32
178 185 @with_environment
179 186 def test_get_home_dir_6():
180 187 """Using USERPROFILE, os=='nt'.
181 188
182 189 HOMESHARE, HOMEDRIVE, HOMEPATH are missing.
183 190 """
184 191
185 192 os.name = 'nt'
186 193 env.pop('HOMESHARE', None)
187 194 env.pop('HOMEDRIVE', None)
188 195 env.pop('HOMEPATH', None)
189 196 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
190 197 home_dir = path.get_home_dir()
191 198 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
192 199
193 200
194 201 @skip_if_not_win32
195 202 @with_environment
196 203 def test_get_home_dir_7():
197 204 """Using HOMESHARE, os=='nt'."""
198 205
199 206 os.name = 'nt'
200 207 env["HOMESHARE"] = abspath(HOME_TEST_DIR)
201 208 home_dir = path.get_home_dir()
202 209 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
203 210
204 211
205 212 # Should we stub wreg fully so we can run the test on all platforms?
206 213 @skip_if_not_win32
207 214 @with_environment
208 215 def test_get_home_dir_8():
209 216 """Using registry hack for 'My Documents', os=='nt'
210 217
211 218 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
212 219 """
213 220 os.name = 'nt'
214 221 # Remove from stub environment all keys that may be set
215 222 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
216 223 env.pop(key, None)
217 224
218 225 #Stub windows registry functions
219 226 def OpenKey(x, y):
220 227 class key:
221 228 def Close(self):
222 229 pass
223 230 return key()
224 231 def QueryValueEx(x, y):
225 232 return [abspath(HOME_TEST_DIR)]
226 233
227 234 wreg.OpenKey = OpenKey
228 235 wreg.QueryValueEx = QueryValueEx
229 236
230 237 home_dir = path.get_home_dir()
231 238 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
232 239
233 240
234 241 @with_environment
235 242 def test_get_ipython_dir_1():
236 243 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
237 244 env_ipdir = os.path.join("someplace", ".ipython")
238 245 path._writable_dir = lambda path: True
239 246 env['IPYTHON_DIR'] = env_ipdir
240 247 ipdir = path.get_ipython_dir()
241 248 nt.assert_equal(ipdir, env_ipdir)
242 249
243 250
244 251 @with_environment
245 252 def test_get_ipython_dir_2():
246 253 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
247 254 path.get_home_dir = lambda : "someplace"
248 255 path.get_xdg_dir = lambda : None
249 256 path._writable_dir = lambda path: True
250 257 os.name = "posix"
251 258 env.pop('IPYTHON_DIR', None)
252 259 env.pop('IPYTHONDIR', None)
253 260 env.pop('XDG_CONFIG_HOME', None)
254 261 ipdir = path.get_ipython_dir()
255 262 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
256 263
257 264 @with_environment
258 265 def test_get_ipython_dir_3():
259 266 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
260 267 path.get_home_dir = lambda : "someplace"
261 268 path._writable_dir = lambda path: True
262 269 os.name = "posix"
263 270 env.pop('IPYTHON_DIR', None)
264 271 env.pop('IPYTHONDIR', None)
265 272 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
266 273 ipdir = path.get_ipython_dir()
267 274 nt.assert_equal(ipdir, os.path.join(XDG_TEST_DIR, "ipython"))
268 275
269 276 @with_environment
270 277 def test_get_ipython_dir_4():
271 278 """test_get_ipython_dir_4, use XDG if both exist."""
272 279 path.get_home_dir = lambda : HOME_TEST_DIR
273 280 os.name = "posix"
274 281 env.pop('IPYTHON_DIR', None)
275 282 env.pop('IPYTHONDIR', None)
276 283 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
277 284 xdg_ipdir = os.path.join(XDG_TEST_DIR, "ipython")
278 285 ipdir = path.get_ipython_dir()
279 286 nt.assert_equal(ipdir, xdg_ipdir)
280 287
281 288 @with_environment
282 289 def test_get_ipython_dir_5():
283 290 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
284 291 path.get_home_dir = lambda : HOME_TEST_DIR
285 292 os.name = "posix"
286 293 env.pop('IPYTHON_DIR', None)
287 294 env.pop('IPYTHONDIR', None)
288 295 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
289 296 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
290 297 ipdir = path.get_ipython_dir()
291 298 nt.assert_equal(ipdir, IP_TEST_DIR)
292 299
293 300 @with_environment
294 301 def test_get_ipython_dir_6():
295 302 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
296 303 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
297 304 os.mkdir(xdg)
298 305 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
299 306 path.get_home_dir = lambda : HOME_TEST_DIR
300 307 path.get_xdg_dir = lambda : xdg
301 308 os.name = "posix"
302 309 env.pop('IPYTHON_DIR', None)
303 310 env.pop('IPYTHONDIR', None)
304 311 env.pop('XDG_CONFIG_HOME', None)
305 312 xdg_ipdir = os.path.join(xdg, "ipython")
306 313 ipdir = path.get_ipython_dir()
307 314 nt.assert_equal(ipdir, xdg_ipdir)
308 315
309 316 @with_environment
310 317 def test_get_ipython_dir_7():
311 318 """test_get_ipython_dir_7, test home directory expansion on IPYTHON_DIR"""
312 319 path._writable_dir = lambda path: True
313 320 home_dir = os.path.expanduser('~')
314 321 env['IPYTHON_DIR'] = os.path.join('~', 'somewhere')
315 322 ipdir = path.get_ipython_dir()
316 323 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
317 324
318 325
319 326 @with_environment
320 327 def test_get_xdg_dir_1():
321 328 """test_get_xdg_dir_1, check xdg_dir"""
322 329 reload(path)
323 330 path._writable_dir = lambda path: True
324 331 path.get_home_dir = lambda : 'somewhere'
325 332 os.name = "posix"
326 333 env.pop('IPYTHON_DIR', None)
327 334 env.pop('IPYTHONDIR', None)
328 335 env.pop('XDG_CONFIG_HOME', None)
329 336
330 337 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
331 338
332 339
333 340 @with_environment
334 341 def test_get_xdg_dir_1():
335 342 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
336 343 reload(path)
337 344 path.get_home_dir = lambda : HOME_TEST_DIR
338 345 os.name = "posix"
339 346 env.pop('IPYTHON_DIR', None)
340 347 env.pop('IPYTHONDIR', None)
341 348 env.pop('XDG_CONFIG_HOME', None)
342 349 nt.assert_equal(path.get_xdg_dir(), None)
343 350
344 351 @with_environment
345 352 def test_get_xdg_dir_2():
346 353 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
347 354 reload(path)
348 355 path.get_home_dir = lambda : HOME_TEST_DIR
349 356 os.name = "posix"
350 357 env.pop('IPYTHON_DIR', None)
351 358 env.pop('IPYTHONDIR', None)
352 359 env.pop('XDG_CONFIG_HOME', None)
353 360 cfgdir=os.path.join(path.get_home_dir(), '.config')
354 361 os.makedirs(cfgdir)
355 362
356 363 nt.assert_equal(path.get_xdg_dir(), cfgdir)
357 364
358 365 def test_filefind():
359 366 """Various tests for filefind"""
360 367 f = tempfile.NamedTemporaryFile()
361 368 # print 'fname:',f.name
362 369 alt_dirs = path.get_ipython_dir()
363 370 t = path.filefind(f.name, alt_dirs)
364 371 # print 'found:',t
365 372
366 373
367 374 def test_get_ipython_package_dir():
368 375 ipdir = path.get_ipython_package_dir()
369 376 nt.assert_true(os.path.isdir(ipdir))
370 377
371 378
372 379 def test_get_ipython_module_path():
373 380 ipapp_path = path.get_ipython_module_path('IPython.frontend.terminal.ipapp')
374 381 nt.assert_true(os.path.isfile(ipapp_path))
375 382
376 383
377 384 @dec.skip_if_not_win32
378 385 def test_get_long_path_name_win32():
379 386 p = path.get_long_path_name('c:\\docume~1')
380 387 nt.assert_equals(p,u'c:\\Documents and Settings')
381 388
382 389
383 390 @dec.skip_win32
384 391 def test_get_long_path_name():
385 392 p = path.get_long_path_name('/usr/local')
386 393 nt.assert_equals(p,'/usr/local')
387 394
388 395 @dec.skip_win32 # can't create not-user-writable dir on win
389 396 @with_environment
390 397 def test_not_writable_ipdir():
391 398 tmpdir = tempfile.mkdtemp()
392 399 os.name = "posix"
393 400 env.pop('IPYTHON_DIR', None)
394 401 env.pop('IPYTHONDIR', None)
395 402 env.pop('XDG_CONFIG_HOME', None)
396 403 env['HOME'] = tmpdir
397 404 ipdir = os.path.join(tmpdir, '.ipython')
398 405 os.mkdir(ipdir)
399 406 os.chmod(ipdir, 600)
400 407 stderr = io.stderr
401 408 pipe = StringIO.StringIO()
402 409 io.stderr = pipe
403 410 ipdir = path.get_ipython_dir()
404 411 io.stderr.flush()
405 412 io.stderr = stderr
406 413 nt.assert_true('WARNING' in pipe.getvalue())
407 414 env.pop('IPYTHON_DIR', None)
408 415
409 416 def test_unquote_filename():
410 417 for win32 in (True, False):
411 418 nt.assert_equals(path.unquote_filename('foo.py', win32=win32), 'foo.py')
412 419 nt.assert_equals(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
413 420 nt.assert_equals(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
414 421 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
415 422 nt.assert_equals(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
416 423 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
417 424 nt.assert_equals(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
418 425 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
419 426 nt.assert_equals(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
420 427 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
421 428
422 429 @with_environment
423 430 def test_get_py_filename():
424 431 os.chdir(TMP_TEST_DIR)
425 432 for win32 in (True, False):
426 433 with make_tempfile('foo.py'):
427 434 nt.assert_equals(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
428 435 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo.py')
429 436 with make_tempfile('foo'):
430 437 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo')
431 438 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
432 439 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
433 440 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
434 441 true_fn = 'foo with spaces.py'
435 442 with make_tempfile(true_fn):
436 443 nt.assert_equals(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
437 444 nt.assert_equals(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
438 445 if win32:
439 446 nt.assert_equals(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
440 447 nt.assert_equals(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
441 448 else:
442 449 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
443 450 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
@@ -1,98 +1,98 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for platutils.py
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2009 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 import sys
18 18 from unittest import TestCase
19 19
20 20 import nose.tools as nt
21 21
22 22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
23 23 system, getoutput, getoutputerror)
24 24 from IPython.testing import decorators as dec
25 25 from IPython.testing import tools as tt
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Tests
29 29 #-----------------------------------------------------------------------------
30 30
31 31 def test_find_cmd_python():
32 32 """Make sure we find sys.exectable for python."""
33 33 nt.assert_equals(find_cmd('python'), sys.executable)
34 34
35 35
36 36 @dec.skip_win32
37 37 def test_find_cmd_ls():
38 38 """Make sure we can find the full path to ls."""
39 39 path = find_cmd('ls')
40 nt.assert_true(path.endswith('ls'))
40 nt.assert_true(path.endswith(b'ls'))
41 41
42 42
43 43 def has_pywin32():
44 44 try:
45 45 import win32api
46 46 except ImportError:
47 47 return False
48 48 return True
49 49
50 50
51 51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
52 52 def test_find_cmd_pythonw():
53 53 """Try to find pythonw on Windows."""
54 54 path = find_cmd('pythonw')
55 55 nt.assert_true(path.endswith('pythonw.exe'))
56 56
57 57
58 58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
59 59 "This test runs on posix or in win32 with win32api installed")
60 60 def test_find_cmd_fail():
61 61 """Make sure that FindCmdError is raised if we can't find the cmd."""
62 62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
63 63
64 64
65 65 def test_arg_split():
66 66 """Ensure that argument lines are correctly split like in a shell."""
67 67 tests = [['hi', ['hi']],
68 68 [u'hi', [u'hi']],
69 69 ['hello there', ['hello', 'there']],
70 70 [u'h\N{LATIN SMALL LETTER A WITH CARON}llo', [u'h\N{LATIN SMALL LETTER A WITH CARON}llo']],
71 71 ['something "with quotes"', ['something', '"with quotes"']],
72 72 ]
73 73 for argstr, argv in tests:
74 74 nt.assert_equal(arg_split(argstr), argv)
75 75
76 76
77 77 class SubProcessTestCase(TestCase, tt.TempFileMixin):
78 78 def setUp(self):
79 79 """Make a valid python temp file."""
80 80 lines = ["from __future__ import print_function",
81 81 "import sys",
82 82 "print('on stdout', end='', file=sys.stdout)",
83 83 "print('on stderr', end='', file=sys.stderr)",
84 84 "sys.stdout.flush()",
85 85 "sys.stderr.flush()"]
86 86 self.mktmp('\n'.join(lines))
87 87
88 88 def test_system(self):
89 89 system('python "%s"' % self.fname)
90 90
91 91 def test_getoutput(self):
92 92 out = getoutput('python "%s"' % self.fname)
93 93 self.assertEquals(out, 'on stdout')
94 94
95 95 def test_getoutput(self):
96 96 out, err = getoutputerror('python "%s"' % self.fname)
97 97 self.assertEquals(out, 'on stdout')
98 98 self.assertEquals(err, 'on stderr')
@@ -1,715 +1,715 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2009 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 import __main__
18 18
19 19 import os
20 20 import re
21 21 import shutil
22 22 import textwrap
23 23 from string import Formatter
24 24
25 25 from IPython.external.path import path
26 26 from IPython.utils import py3compat
27 27 from IPython.utils.io import nlprint
28 28 from IPython.utils.data import flatten
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Code
32 32 #-----------------------------------------------------------------------------
33 33
34 34
35 35 def unquote_ends(istr):
36 36 """Remove a single pair of quotes from the endpoints of a string."""
37 37
38 38 if not istr:
39 39 return istr
40 40 if (istr[0]=="'" and istr[-1]=="'") or \
41 41 (istr[0]=='"' and istr[-1]=='"'):
42 42 return istr[1:-1]
43 43 else:
44 44 return istr
45 45
46 46
47 47 class LSString(str):
48 48 """String derivative with a special access attributes.
49 49
50 50 These are normal strings, but with the special attributes:
51 51
52 52 .l (or .list) : value as list (split on newlines).
53 53 .n (or .nlstr): original value (the string itself).
54 54 .s (or .spstr): value as whitespace-separated string.
55 55 .p (or .paths): list of path objects
56 56
57 57 Any values which require transformations are computed only once and
58 58 cached.
59 59
60 60 Such strings are very useful to efficiently interact with the shell, which
61 61 typically only understands whitespace-separated options for commands."""
62 62
63 63 def get_list(self):
64 64 try:
65 65 return self.__list
66 66 except AttributeError:
67 67 self.__list = self.split('\n')
68 68 return self.__list
69 69
70 70 l = list = property(get_list)
71 71
72 72 def get_spstr(self):
73 73 try:
74 74 return self.__spstr
75 75 except AttributeError:
76 76 self.__spstr = self.replace('\n',' ')
77 77 return self.__spstr
78 78
79 79 s = spstr = property(get_spstr)
80 80
81 81 def get_nlstr(self):
82 82 return self
83 83
84 84 n = nlstr = property(get_nlstr)
85 85
86 86 def get_paths(self):
87 87 try:
88 88 return self.__paths
89 89 except AttributeError:
90 90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 91 return self.__paths
92 92
93 93 p = paths = property(get_paths)
94 94
95 95 # FIXME: We need to reimplement type specific displayhook and then add this
96 96 # back as a custom printer. This should also be moved outside utils into the
97 97 # core.
98 98
99 99 # def print_lsstring(arg):
100 100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 101 # print "LSString (.p, .n, .l, .s available). Value:"
102 102 # print arg
103 103 #
104 104 #
105 105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106 106
107 107
108 108 class SList(list):
109 109 """List derivative with a special access attributes.
110 110
111 111 These are normal lists, but with the special attributes:
112 112
113 113 .l (or .list) : value as list (the list itself).
114 114 .n (or .nlstr): value as a string, joined on newlines.
115 115 .s (or .spstr): value as a string, joined on spaces.
116 116 .p (or .paths): list of path objects
117 117
118 118 Any values which require transformations are computed only once and
119 119 cached."""
120 120
121 121 def get_list(self):
122 122 return self
123 123
124 124 l = list = property(get_list)
125 125
126 126 def get_spstr(self):
127 127 try:
128 128 return self.__spstr
129 129 except AttributeError:
130 130 self.__spstr = ' '.join(self)
131 131 return self.__spstr
132 132
133 133 s = spstr = property(get_spstr)
134 134
135 135 def get_nlstr(self):
136 136 try:
137 137 return self.__nlstr
138 138 except AttributeError:
139 139 self.__nlstr = '\n'.join(self)
140 140 return self.__nlstr
141 141
142 142 n = nlstr = property(get_nlstr)
143 143
144 144 def get_paths(self):
145 145 try:
146 146 return self.__paths
147 147 except AttributeError:
148 148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 149 return self.__paths
150 150
151 151 p = paths = property(get_paths)
152 152
153 153 def grep(self, pattern, prune = False, field = None):
154 154 """ Return all strings matching 'pattern' (a regex or callable)
155 155
156 156 This is case-insensitive. If prune is true, return all items
157 157 NOT matching the pattern.
158 158
159 159 If field is specified, the match must occur in the specified
160 160 whitespace-separated field.
161 161
162 162 Examples::
163 163
164 164 a.grep( lambda x: x.startswith('C') )
165 165 a.grep('Cha.*log', prune=1)
166 166 a.grep('chm', field=-1)
167 167 """
168 168
169 169 def match_target(s):
170 170 if field is None:
171 171 return s
172 172 parts = s.split()
173 173 try:
174 174 tgt = parts[field]
175 175 return tgt
176 176 except IndexError:
177 177 return ""
178 178
179 179 if isinstance(pattern, basestring):
180 180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 181 else:
182 182 pred = pattern
183 183 if not prune:
184 184 return SList([el for el in self if pred(match_target(el))])
185 185 else:
186 186 return SList([el for el in self if not pred(match_target(el))])
187 187
188 188 def fields(self, *fields):
189 189 """ Collect whitespace-separated fields from string list
190 190
191 191 Allows quick awk-like usage of string lists.
192 192
193 193 Example data (in var a, created by 'a = !ls -l')::
194 194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196 196
197 197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 199 (note the joining by space).
200 200 a.fields(-1) is ['ChangeLog', 'IPython']
201 201
202 202 IndexErrors are ignored.
203 203
204 204 Without args, fields() just split()'s the strings.
205 205 """
206 206 if len(fields) == 0:
207 207 return [el.split() for el in self]
208 208
209 209 res = SList()
210 210 for el in [f.split() for f in self]:
211 211 lineparts = []
212 212
213 213 for fd in fields:
214 214 try:
215 215 lineparts.append(el[fd])
216 216 except IndexError:
217 217 pass
218 218 if lineparts:
219 219 res.append(" ".join(lineparts))
220 220
221 221 return res
222 222
223 223 def sort(self,field= None, nums = False):
224 224 """ sort by specified fields (see fields())
225 225
226 226 Example::
227 227 a.sort(1, nums = True)
228 228
229 229 Sorts a by second field, in numerical order (so that 21 > 3)
230 230
231 231 """
232 232
233 233 #decorate, sort, undecorate
234 234 if field is not None:
235 235 dsu = [[SList([line]).fields(field), line] for line in self]
236 236 else:
237 237 dsu = [[line, line] for line in self]
238 238 if nums:
239 239 for i in range(len(dsu)):
240 240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 241 try:
242 242 n = int(numstr)
243 243 except ValueError:
244 244 n = 0;
245 245 dsu[i][0] = n
246 246
247 247
248 248 dsu.sort()
249 249 return SList([t[1] for t in dsu])
250 250
251 251
252 252 # FIXME: We need to reimplement type specific displayhook and then add this
253 253 # back as a custom printer. This should also be moved outside utils into the
254 254 # core.
255 255
256 256 # def print_slist(arg):
257 257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 260 # arg.hideonce = False
261 261 # return
262 262 #
263 263 # nlprint(arg)
264 264 #
265 265 # print_slist = result_display.when_type(SList)(print_slist)
266 266
267 267
268 268 def esc_quotes(strng):
269 269 """Return the input string with single and double quotes escaped out"""
270 270
271 271 return strng.replace('"','\\"').replace("'","\\'")
272 272
273 273
274 274 def make_quoted_expr(s):
275 275 """Return string s in appropriate quotes, using raw string if possible.
276 276
277 277 XXX - example removed because it caused encoding errors in documentation
278 278 generation. We need a new example that doesn't contain invalid chars.
279 279
280 280 Note the use of raw string and padding at the end to allow trailing
281 281 backslash.
282 282 """
283 283
284 284 tail = ''
285 285 tailpadding = ''
286 286 raw = ''
287 287 ucode = '' if py3compat.PY3 else 'u'
288 288 if "\\" in s:
289 289 raw = 'r'
290 290 if s.endswith('\\'):
291 291 tail = '[:-1]'
292 292 tailpadding = '_'
293 293 if '"' not in s:
294 294 quote = '"'
295 295 elif "'" not in s:
296 296 quote = "'"
297 297 elif '"""' not in s and not s.endswith('"'):
298 298 quote = '"""'
299 299 elif "'''" not in s and not s.endswith("'"):
300 300 quote = "'''"
301 301 else:
302 302 # give up, backslash-escaped string will do
303 303 return '"%s"' % esc_quotes(s)
304 304 res = ucode + raw + quote + s + tailpadding + quote + tail
305 305 return res
306 306
307 307
308 308 def qw(words,flat=0,sep=None,maxsplit=-1):
309 309 """Similar to Perl's qw() operator, but with some more options.
310 310
311 311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
312 312
313 313 words can also be a list itself, and with flat=1, the output will be
314 314 recursively flattened.
315 315
316 316 Examples:
317 317
318 318 >>> qw('1 2')
319 319 ['1', '2']
320 320
321 321 >>> qw(['a b','1 2',['m n','p q']])
322 322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
323 323
324 324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
325 325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
326 326 """
327 327
328 328 if isinstance(words, basestring):
329 329 return [word.strip() for word in words.split(sep,maxsplit)
330 330 if word and not word.isspace() ]
331 331 if flat:
332 332 return flatten(map(qw,words,[1]*len(words)))
333 333 return map(qw,words)
334 334
335 335
336 336 def qwflat(words,sep=None,maxsplit=-1):
337 337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
338 338 return qw(words,1,sep,maxsplit)
339 339
340 340
341 341 def qw_lol(indata):
342 342 """qw_lol('a b') -> [['a','b']],
343 343 otherwise it's just a call to qw().
344 344
345 345 We need this to make sure the modules_some keys *always* end up as a
346 346 list of lists."""
347 347
348 348 if isinstance(indata, basestring):
349 349 return [qw(indata)]
350 350 else:
351 351 return qw(indata)
352 352
353 353
354 354 def grep(pat,list,case=1):
355 355 """Simple minded grep-like function.
356 356 grep(pat,list) returns occurrences of pat in list, None on failure.
357 357
358 358 It only does simple string matching, with no support for regexps. Use the
359 359 option case=0 for case-insensitive matching."""
360 360
361 361 # This is pretty crude. At least it should implement copying only references
362 362 # to the original data in case it's big. Now it copies the data for output.
363 363 out=[]
364 364 if case:
365 365 for term in list:
366 366 if term.find(pat)>-1: out.append(term)
367 367 else:
368 368 lpat=pat.lower()
369 369 for term in list:
370 370 if term.lower().find(lpat)>-1: out.append(term)
371 371
372 372 if len(out): return out
373 373 else: return None
374 374
375 375
376 376 def dgrep(pat,*opts):
377 377 """Return grep() on dir()+dir(__builtins__).
378 378
379 379 A very common use of grep() when working interactively."""
380 380
381 381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
382 382
383 383
384 384 def idgrep(pat):
385 385 """Case-insensitive dgrep()"""
386 386
387 387 return dgrep(pat,0)
388 388
389 389
390 390 def igrep(pat,list):
391 391 """Synonym for case-insensitive grep."""
392 392
393 393 return grep(pat,list,case=0)
394 394
395 395
396 396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
397 397 """Indent a string a given number of spaces or tabstops.
398 398
399 399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
400 400
401 401 Parameters
402 402 ----------
403 403
404 404 instr : basestring
405 405 The string to be indented.
406 406 nspaces : int (default: 4)
407 407 The number of spaces to be indented.
408 408 ntabs : int (default: 0)
409 409 The number of tabs to be indented.
410 410 flatten : bool (default: False)
411 411 Whether to scrub existing indentation. If True, all lines will be
412 412 aligned to the same indentation. If False, existing indentation will
413 413 be strictly increased.
414 414
415 415 Returns
416 416 -------
417 417
418 418 str|unicode : string indented by ntabs and nspaces.
419 419
420 420 """
421 421 if instr is None:
422 422 return
423 423 ind = '\t'*ntabs+' '*nspaces
424 424 if flatten:
425 425 pat = re.compile(r'^\s*', re.MULTILINE)
426 426 else:
427 427 pat = re.compile(r'^', re.MULTILINE)
428 428 outstr = re.sub(pat, ind, instr)
429 429 if outstr.endswith(os.linesep+ind):
430 430 return outstr[:-len(ind)]
431 431 else:
432 432 return outstr
433 433
434 434 def native_line_ends(filename,backup=1):
435 435 """Convert (in-place) a file to line-ends native to the current OS.
436 436
437 437 If the optional backup argument is given as false, no backup of the
438 438 original file is left. """
439 439
440 440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
441 441
442 442 bak_filename = filename + backup_suffixes[os.name]
443 443
444 444 original = open(filename).read()
445 445 shutil.copy2(filename,bak_filename)
446 446 try:
447 447 new = open(filename,'wb')
448 448 new.write(os.linesep.join(original.splitlines()))
449 449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
450 450 new.close()
451 451 except:
452 452 os.rename(bak_filename,filename)
453 453 if not backup:
454 454 try:
455 455 os.remove(bak_filename)
456 456 except:
457 457 pass
458 458
459 459
460 460 def list_strings(arg):
461 461 """Always return a list of strings, given a string or list of strings
462 462 as input.
463 463
464 464 :Examples:
465 465
466 466 In [7]: list_strings('A single string')
467 467 Out[7]: ['A single string']
468 468
469 469 In [8]: list_strings(['A single string in a list'])
470 470 Out[8]: ['A single string in a list']
471 471
472 472 In [9]: list_strings(['A','list','of','strings'])
473 473 Out[9]: ['A', 'list', 'of', 'strings']
474 474 """
475 475
476 476 if isinstance(arg,basestring): return [arg]
477 477 else: return arg
478 478
479 479
480 480 def marquee(txt='',width=78,mark='*'):
481 481 """Return the input string centered in a 'marquee'.
482 482
483 483 :Examples:
484 484
485 485 In [16]: marquee('A test',40)
486 486 Out[16]: '**************** A test ****************'
487 487
488 488 In [17]: marquee('A test',40,'-')
489 489 Out[17]: '---------------- A test ----------------'
490 490
491 491 In [18]: marquee('A test',40,' ')
492 492 Out[18]: ' A test '
493 493
494 494 """
495 495 if not txt:
496 496 return (mark*width)[:width]
497 497 nmark = (width-len(txt)-2)//len(mark)//2
498 498 if nmark < 0: nmark =0
499 499 marks = mark*nmark
500 500 return '%s %s %s' % (marks,txt,marks)
501 501
502 502
503 503 ini_spaces_re = re.compile(r'^(\s+)')
504 504
505 505 def num_ini_spaces(strng):
506 506 """Return the number of initial spaces in a string"""
507 507
508 508 ini_spaces = ini_spaces_re.match(strng)
509 509 if ini_spaces:
510 510 return ini_spaces.end()
511 511 else:
512 512 return 0
513 513
514 514
515 515 def format_screen(strng):
516 516 """Format a string for screen printing.
517 517
518 518 This removes some latex-type format codes."""
519 519 # Paragraph continue
520 520 par_re = re.compile(r'\\$',re.MULTILINE)
521 521 strng = par_re.sub('',strng)
522 522 return strng
523 523
524 524 def dedent(text):
525 525 """Equivalent of textwrap.dedent that ignores unindented first line.
526 526
527 527 This means it will still dedent strings like:
528 528 '''foo
529 529 is a bar
530 530 '''
531 531
532 532 For use in wrap_paragraphs.
533 533 """
534 534
535 535 if text.startswith('\n'):
536 536 # text starts with blank line, don't ignore the first line
537 537 return textwrap.dedent(text)
538 538
539 539 # split first line
540 540 splits = text.split('\n',1)
541 541 if len(splits) == 1:
542 542 # only one line
543 543 return textwrap.dedent(text)
544 544
545 545 first, rest = splits
546 546 # dedent everything but the first line
547 547 rest = textwrap.dedent(rest)
548 548 return '\n'.join([first, rest])
549 549
550 550 def wrap_paragraphs(text, ncols=80):
551 551 """Wrap multiple paragraphs to fit a specified width.
552 552
553 553 This is equivalent to textwrap.wrap, but with support for multiple
554 554 paragraphs, as separated by empty lines.
555 555
556 556 Returns
557 557 -------
558 558
559 559 list of complete paragraphs, wrapped to fill `ncols` columns.
560 560 """
561 561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
562 562 text = dedent(text).strip()
563 563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
564 564 out_ps = []
565 565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
566 566 for p in paragraphs:
567 567 # presume indentation that survives dedent is meaningful formatting,
568 568 # so don't fill unless text is flush.
569 569 if indent_re.search(p) is None:
570 570 # wrap paragraph
571 571 p = textwrap.fill(p, ncols)
572 572 out_ps.append(p)
573 573 return out_ps
574 574
575 575
576 576
577 577 class EvalFormatter(Formatter):
578 578 """A String Formatter that allows evaluation of simple expressions.
579 579
580 580 Any time a format key is not found in the kwargs,
581 581 it will be tried as an expression in the kwargs namespace.
582 582
583 583 This is to be used in templating cases, such as the parallel batch
584 584 script templates, where simple arithmetic on arguments is useful.
585 585
586 586 Examples
587 587 --------
588 588
589 589 In [1]: f = EvalFormatter()
590 In [2]: f.format('{n/4}', n=8)
590 In [2]: f.format('{n//4}', n=8)
591 591 Out[2]: '2'
592 592
593 In [3]: f.format('{range(3)}')
593 In [3]: f.format('{list(range(3))}')
594 594 Out[3]: '[0, 1, 2]'
595 595
596 596 In [4]: f.format('{3*2}')
597 597 Out[4]: '6'
598 598 """
599 599
600 600 # should we allow slicing by disabling the format_spec feature?
601 601 allow_slicing = True
602 602
603 603 # copied from Formatter._vformat with minor changes to allow eval
604 604 # and replace the format_spec code with slicing
605 605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
606 606 if recursion_depth < 0:
607 607 raise ValueError('Max string recursion exceeded')
608 608 result = []
609 609 for literal_text, field_name, format_spec, conversion in \
610 610 self.parse(format_string):
611 611
612 612 # output the literal text
613 613 if literal_text:
614 614 result.append(literal_text)
615 615
616 616 # if there's a field, output it
617 617 if field_name is not None:
618 618 # this is some markup, find the object and do
619 619 # the formatting
620 620
621 621 if self.allow_slicing and format_spec:
622 622 # override format spec, to allow slicing:
623 623 field_name = ':'.join([field_name, format_spec])
624 624 format_spec = ''
625 625
626 626 # eval the contents of the field for the object
627 627 # to be formatted
628 628 obj = eval(field_name, kwargs)
629 629
630 630 # do any conversion on the resulting object
631 631 obj = self.convert_field(obj, conversion)
632 632
633 633 # expand the format spec, if needed
634 634 format_spec = self._vformat(format_spec, args, kwargs,
635 635 used_args, recursion_depth-1)
636 636
637 637 # format the object and append to the result
638 638 result.append(self.format_field(obj, format_spec))
639 639
640 640 return ''.join(result)
641 641
642 642
643 643 def columnize(items, separator=' ', displaywidth=80):
644 644 """ Transform a list of strings into a single string with columns.
645 645
646 646 Parameters
647 647 ----------
648 648 items : sequence of strings
649 649 The strings to process.
650 650
651 651 separator : str, optional [default is two spaces]
652 652 The string that separates columns.
653 653
654 654 displaywidth : int, optional [default is 80]
655 655 Width of the display in number of characters.
656 656
657 657 Returns
658 658 -------
659 659 The formatted string.
660 660 """
661 661 # Note: this code is adapted from columnize 0.3.2.
662 662 # See http://code.google.com/p/pycolumnize/
663 663
664 664 # Some degenerate cases.
665 665 size = len(items)
666 666 if size == 0:
667 667 return '\n'
668 668 elif size == 1:
669 669 return '%s\n' % items[0]
670 670
671 671 # Special case: if any item is longer than the maximum width, there's no
672 672 # point in triggering the logic below...
673 673 item_len = map(len, items) # save these, we can reuse them below
674 674 longest = max(item_len)
675 675 if longest >= displaywidth:
676 676 return '\n'.join(items+[''])
677 677
678 678 # Try every row count from 1 upwards
679 679 array_index = lambda nrows, row, col: nrows*col + row
680 680 for nrows in range(1, size):
681 681 ncols = (size + nrows - 1) // nrows
682 682 colwidths = []
683 683 totwidth = -len(separator)
684 684 for col in range(ncols):
685 685 # Get max column width for this column
686 686 colwidth = 0
687 687 for row in range(nrows):
688 688 i = array_index(nrows, row, col)
689 689 if i >= size: break
690 690 x, len_x = items[i], item_len[i]
691 691 colwidth = max(colwidth, len_x)
692 692 colwidths.append(colwidth)
693 693 totwidth += colwidth + len(separator)
694 694 if totwidth > displaywidth:
695 695 break
696 696 if totwidth <= displaywidth:
697 697 break
698 698
699 699 # The smallest number of rows computed and the max widths for each
700 700 # column has been obtained. Now we just have to format each of the rows.
701 701 string = ''
702 702 for row in range(nrows):
703 703 texts = []
704 704 for col in range(ncols):
705 705 i = row + nrows*col
706 706 if i >= size:
707 707 texts.append('')
708 708 else:
709 709 texts.append(items[i])
710 710 while texts and not texts[-1]:
711 711 del texts[-1]
712 712 for col in range(len(texts)):
713 713 texts[col] = texts[col].ljust(colwidths[col])
714 714 string += '%s\n' % separator.join(texts)
715 715 return string
@@ -1,1392 +1,1394 b''
1 1 # encoding: utf-8
2 2 """
3 3 A lightweight Traits like module.
4 4
5 5 This is designed to provide a lightweight, simple, pure Python version of
6 6 many of the capabilities of enthought.traits. This includes:
7 7
8 8 * Validation
9 9 * Type specification with defaults
10 10 * Static and dynamic notification
11 11 * Basic predefined types
12 12 * An API that is similar to enthought.traits
13 13
14 14 We don't support:
15 15
16 16 * Delegation
17 17 * Automatic GUI generation
18 18 * A full set of trait types. Most importantly, we don't provide container
19 19 traits (list, dict, tuple) that can trigger notifications if their
20 20 contents change.
21 21 * API compatibility with enthought.traits
22 22
23 23 There are also some important difference in our design:
24 24
25 25 * enthought.traits does not validate default values. We do.
26 26
27 27 We choose to create this module because we need these capabilities, but
28 28 we need them to be pure Python so they work in all Python implementations,
29 29 including Jython and IronPython.
30 30
31 31 Authors:
32 32
33 33 * Brian Granger
34 34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
35 35 and is licensed under the BSD license. Also, many of the ideas also come
36 36 from enthought.traits even though our implementation is very different.
37 37 """
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Copyright (C) 2008-2009 The IPython Development Team
41 41 #
42 42 # Distributed under the terms of the BSD License. The full license is in
43 43 # the file COPYING, distributed as part of this software.
44 44 #-----------------------------------------------------------------------------
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Imports
48 48 #-----------------------------------------------------------------------------
49 49
50 50
51 51 import inspect
52 52 import re
53 53 import sys
54 54 import types
55 55 from types import FunctionType
56 56 try:
57 57 from types import ClassType, InstanceType
58 58 ClassTypes = (ClassType, type)
59 59 except:
60 60 ClassTypes = (type,)
61 61
62 62 from .importstring import import_item
63 63 from IPython.utils import py3compat
64 64
65 65 SequenceTypes = (list, tuple, set, frozenset)
66 66
67 67 #-----------------------------------------------------------------------------
68 68 # Basic classes
69 69 #-----------------------------------------------------------------------------
70 70
71 71
72 72 class NoDefaultSpecified ( object ): pass
73 73 NoDefaultSpecified = NoDefaultSpecified()
74 74
75 75
76 76 class Undefined ( object ): pass
77 77 Undefined = Undefined()
78 78
79 79 class TraitError(Exception):
80 80 pass
81 81
82 82 #-----------------------------------------------------------------------------
83 83 # Utilities
84 84 #-----------------------------------------------------------------------------
85 85
86 86
87 87 def class_of ( object ):
88 88 """ Returns a string containing the class name of an object with the
89 89 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
90 90 'a PlotValue').
91 91 """
92 92 if isinstance( object, basestring ):
93 93 return add_article( object )
94 94
95 95 return add_article( object.__class__.__name__ )
96 96
97 97
98 98 def add_article ( name ):
99 99 """ Returns a string containing the correct indefinite article ('a' or 'an')
100 100 prefixed to the specified string.
101 101 """
102 102 if name[:1].lower() in 'aeiou':
103 103 return 'an ' + name
104 104
105 105 return 'a ' + name
106 106
107 107
108 108 def repr_type(obj):
109 109 """ Return a string representation of a value and its type for readable
110 110 error messages.
111 111 """
112 112 the_type = type(obj)
113 113 if (not py3compat.PY3) and the_type is InstanceType:
114 114 # Old-style class.
115 115 the_type = obj.__class__
116 116 msg = '%r %r' % (obj, the_type)
117 117 return msg
118 118
119 119
120 120 def parse_notifier_name(name):
121 121 """Convert the name argument to a list of names.
122 122
123 123 Examples
124 124 --------
125 125
126 126 >>> parse_notifier_name('a')
127 127 ['a']
128 128 >>> parse_notifier_name(['a','b'])
129 129 ['a', 'b']
130 130 >>> parse_notifier_name(None)
131 131 ['anytrait']
132 132 """
133 133 if isinstance(name, str):
134 134 return [name]
135 135 elif name is None:
136 136 return ['anytrait']
137 137 elif isinstance(name, (list, tuple)):
138 138 for n in name:
139 139 assert isinstance(n, str), "names must be strings"
140 140 return name
141 141
142 142
143 143 class _SimpleTest:
144 144 def __init__ ( self, value ): self.value = value
145 145 def __call__ ( self, test ):
146 146 return test == self.value
147 147 def __repr__(self):
148 148 return "<SimpleTest(%r)" % self.value
149 149 def __str__(self):
150 150 return self.__repr__()
151 151
152 152
153 153 def getmembers(object, predicate=None):
154 154 """A safe version of inspect.getmembers that handles missing attributes.
155 155
156 156 This is useful when there are descriptor based attributes that for
157 157 some reason raise AttributeError even though they exist. This happens
158 158 in zope.inteface with the __provides__ attribute.
159 159 """
160 160 results = []
161 161 for key in dir(object):
162 162 try:
163 163 value = getattr(object, key)
164 164 except AttributeError:
165 165 pass
166 166 else:
167 167 if not predicate or predicate(value):
168 168 results.append((key, value))
169 169 results.sort()
170 170 return results
171 171
172 172
173 173 #-----------------------------------------------------------------------------
174 174 # Base TraitType for all traits
175 175 #-----------------------------------------------------------------------------
176 176
177 177
178 178 class TraitType(object):
179 179 """A base class for all trait descriptors.
180 180
181 181 Notes
182 182 -----
183 183 Our implementation of traits is based on Python's descriptor
184 184 prototol. This class is the base class for all such descriptors. The
185 185 only magic we use is a custom metaclass for the main :class:`HasTraits`
186 186 class that does the following:
187 187
188 188 1. Sets the :attr:`name` attribute of every :class:`TraitType`
189 189 instance in the class dict to the name of the attribute.
190 190 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
191 191 instance in the class dict to the *class* that declared the trait.
192 192 This is used by the :class:`This` trait to allow subclasses to
193 193 accept superclasses for :class:`This` values.
194 194 """
195 195
196 196
197 197 metadata = {}
198 198 default_value = Undefined
199 199 info_text = 'any value'
200 200
201 201 def __init__(self, default_value=NoDefaultSpecified, **metadata):
202 202 """Create a TraitType.
203 203 """
204 204 if default_value is not NoDefaultSpecified:
205 205 self.default_value = default_value
206 206
207 207 if len(metadata) > 0:
208 208 if len(self.metadata) > 0:
209 209 self._metadata = self.metadata.copy()
210 210 self._metadata.update(metadata)
211 211 else:
212 212 self._metadata = metadata
213 213 else:
214 214 self._metadata = self.metadata
215 215
216 216 self.init()
217 217
218 218 def init(self):
219 219 pass
220 220
221 221 def get_default_value(self):
222 222 """Create a new instance of the default value."""
223 223 return self.default_value
224 224
225 225 def instance_init(self, obj):
226 226 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
227 227
228 228 Some stages of initialization must be delayed until the parent
229 229 :class:`HasTraits` instance has been created. This method is
230 230 called in :meth:`HasTraits.__new__` after the instance has been
231 231 created.
232 232
233 233 This method trigger the creation and validation of default values
234 234 and also things like the resolution of str given class names in
235 235 :class:`Type` and :class`Instance`.
236 236
237 237 Parameters
238 238 ----------
239 239 obj : :class:`HasTraits` instance
240 240 The parent :class:`HasTraits` instance that has just been
241 241 created.
242 242 """
243 243 self.set_default_value(obj)
244 244
245 245 def set_default_value(self, obj):
246 246 """Set the default value on a per instance basis.
247 247
248 248 This method is called by :meth:`instance_init` to create and
249 249 validate the default value. The creation and validation of
250 250 default values must be delayed until the parent :class:`HasTraits`
251 251 class has been instantiated.
252 252 """
253 253 # Check for a deferred initializer defined in the same class as the
254 254 # trait declaration or above.
255 255 mro = type(obj).mro()
256 256 meth_name = '_%s_default' % self.name
257 257 for cls in mro[:mro.index(self.this_class)+1]:
258 258 if meth_name in cls.__dict__:
259 259 break
260 260 else:
261 261 # We didn't find one. Do static initialization.
262 262 dv = self.get_default_value()
263 263 newdv = self._validate(obj, dv)
264 264 obj._trait_values[self.name] = newdv
265 265 return
266 266 # Complete the dynamic initialization.
267 267 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
268 268
269 269 def __get__(self, obj, cls=None):
270 270 """Get the value of the trait by self.name for the instance.
271 271
272 272 Default values are instantiated when :meth:`HasTraits.__new__`
273 273 is called. Thus by the time this method gets called either the
274 274 default value or a user defined value (they called :meth:`__set__`)
275 275 is in the :class:`HasTraits` instance.
276 276 """
277 277 if obj is None:
278 278 return self
279 279 else:
280 280 try:
281 281 value = obj._trait_values[self.name]
282 282 except KeyError:
283 283 # Check for a dynamic initializer.
284 284 if self.name in obj._trait_dyn_inits:
285 285 value = obj._trait_dyn_inits[self.name](obj)
286 286 # FIXME: Do we really validate here?
287 287 value = self._validate(obj, value)
288 288 obj._trait_values[self.name] = value
289 289 return value
290 290 else:
291 291 raise TraitError('Unexpected error in TraitType: '
292 292 'both default value and dynamic initializer are '
293 293 'absent.')
294 294 except Exception:
295 295 # HasTraits should call set_default_value to populate
296 296 # this. So this should never be reached.
297 297 raise TraitError('Unexpected error in TraitType: '
298 298 'default value not set properly')
299 299 else:
300 300 return value
301 301
302 302 def __set__(self, obj, value):
303 303 new_value = self._validate(obj, value)
304 304 old_value = self.__get__(obj)
305 305 if old_value != new_value:
306 306 obj._trait_values[self.name] = new_value
307 307 obj._notify_trait(self.name, old_value, new_value)
308 308
309 309 def _validate(self, obj, value):
310 310 if hasattr(self, 'validate'):
311 311 return self.validate(obj, value)
312 312 elif hasattr(self, 'is_valid_for'):
313 313 valid = self.is_valid_for(value)
314 314 if valid:
315 315 return value
316 316 else:
317 317 raise TraitError('invalid value for type: %r' % value)
318 318 elif hasattr(self, 'value_for'):
319 319 return self.value_for(value)
320 320 else:
321 321 return value
322 322
323 323 def info(self):
324 324 return self.info_text
325 325
326 326 def error(self, obj, value):
327 327 if obj is not None:
328 328 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
329 329 % (self.name, class_of(obj),
330 330 self.info(), repr_type(value))
331 331 else:
332 332 e = "The '%s' trait must be %s, but a value of %r was specified." \
333 333 % (self.name, self.info(), repr_type(value))
334 334 raise TraitError(e)
335 335
336 336 def get_metadata(self, key):
337 337 return getattr(self, '_metadata', {}).get(key, None)
338 338
339 339 def set_metadata(self, key, value):
340 340 getattr(self, '_metadata', {})[key] = value
341 341
342 342
343 343 #-----------------------------------------------------------------------------
344 344 # The HasTraits implementation
345 345 #-----------------------------------------------------------------------------
346 346
347 347
348 348 class MetaHasTraits(type):
349 349 """A metaclass for HasTraits.
350 350
351 351 This metaclass makes sure that any TraitType class attributes are
352 352 instantiated and sets their name attribute.
353 353 """
354 354
355 355 def __new__(mcls, name, bases, classdict):
356 356 """Create the HasTraits class.
357 357
358 358 This instantiates all TraitTypes in the class dict and sets their
359 359 :attr:`name` attribute.
360 360 """
361 361 # print "MetaHasTraitlets (mcls, name): ", mcls, name
362 362 # print "MetaHasTraitlets (bases): ", bases
363 363 # print "MetaHasTraitlets (classdict): ", classdict
364 364 for k,v in classdict.iteritems():
365 365 if isinstance(v, TraitType):
366 366 v.name = k
367 367 elif inspect.isclass(v):
368 368 if issubclass(v, TraitType):
369 369 vinst = v()
370 370 vinst.name = k
371 371 classdict[k] = vinst
372 372 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
373 373
374 374 def __init__(cls, name, bases, classdict):
375 375 """Finish initializing the HasTraits class.
376 376
377 377 This sets the :attr:`this_class` attribute of each TraitType in the
378 378 class dict to the newly created class ``cls``.
379 379 """
380 380 for k, v in classdict.iteritems():
381 381 if isinstance(v, TraitType):
382 382 v.this_class = cls
383 383 super(MetaHasTraits, cls).__init__(name, bases, classdict)
384 384
385 385 class HasTraits(object):
386 386
387 387 __metaclass__ = MetaHasTraits
388 388
389 389 def __new__(cls, **kw):
390 390 # This is needed because in Python 2.6 object.__new__ only accepts
391 391 # the cls argument.
392 392 new_meth = super(HasTraits, cls).__new__
393 393 if new_meth is object.__new__:
394 394 inst = new_meth(cls)
395 395 else:
396 396 inst = new_meth(cls, **kw)
397 397 inst._trait_values = {}
398 398 inst._trait_notifiers = {}
399 399 inst._trait_dyn_inits = {}
400 400 # Here we tell all the TraitType instances to set their default
401 401 # values on the instance.
402 402 for key in dir(cls):
403 403 # Some descriptors raise AttributeError like zope.interface's
404 404 # __provides__ attributes even though they exist. This causes
405 405 # AttributeErrors even though they are listed in dir(cls).
406 406 try:
407 407 value = getattr(cls, key)
408 408 except AttributeError:
409 409 pass
410 410 else:
411 411 if isinstance(value, TraitType):
412 412 value.instance_init(inst)
413 413
414 414 return inst
415 415
416 416 def __init__(self, **kw):
417 417 # Allow trait values to be set using keyword arguments.
418 418 # We need to use setattr for this to trigger validation and
419 419 # notifications.
420 420 for key, value in kw.iteritems():
421 421 setattr(self, key, value)
422 422
423 423 def _notify_trait(self, name, old_value, new_value):
424 424
425 425 # First dynamic ones
426 426 callables = self._trait_notifiers.get(name,[])
427 427 more_callables = self._trait_notifiers.get('anytrait',[])
428 428 callables.extend(more_callables)
429 429
430 430 # Now static ones
431 431 try:
432 432 cb = getattr(self, '_%s_changed' % name)
433 433 except:
434 434 pass
435 435 else:
436 436 callables.append(cb)
437 437
438 438 # Call them all now
439 439 for c in callables:
440 440 # Traits catches and logs errors here. I allow them to raise
441 441 if callable(c):
442 442 argspec = inspect.getargspec(c)
443 443 nargs = len(argspec[0])
444 444 # Bound methods have an additional 'self' argument
445 445 # I don't know how to treat unbound methods, but they
446 446 # can't really be used for callbacks.
447 447 if isinstance(c, types.MethodType):
448 448 offset = -1
449 449 else:
450 450 offset = 0
451 451 if nargs + offset == 0:
452 452 c()
453 453 elif nargs + offset == 1:
454 454 c(name)
455 455 elif nargs + offset == 2:
456 456 c(name, new_value)
457 457 elif nargs + offset == 3:
458 458 c(name, old_value, new_value)
459 459 else:
460 460 raise TraitError('a trait changed callback '
461 461 'must have 0-3 arguments.')
462 462 else:
463 463 raise TraitError('a trait changed callback '
464 464 'must be callable.')
465 465
466 466
467 467 def _add_notifiers(self, handler, name):
468 468 if not self._trait_notifiers.has_key(name):
469 469 nlist = []
470 470 self._trait_notifiers[name] = nlist
471 471 else:
472 472 nlist = self._trait_notifiers[name]
473 473 if handler not in nlist:
474 474 nlist.append(handler)
475 475
476 476 def _remove_notifiers(self, handler, name):
477 477 if self._trait_notifiers.has_key(name):
478 478 nlist = self._trait_notifiers[name]
479 479 try:
480 480 index = nlist.index(handler)
481 481 except ValueError:
482 482 pass
483 483 else:
484 484 del nlist[index]
485 485
486 486 def on_trait_change(self, handler, name=None, remove=False):
487 487 """Setup a handler to be called when a trait changes.
488 488
489 489 This is used to setup dynamic notifications of trait changes.
490 490
491 491 Static handlers can be created by creating methods on a HasTraits
492 492 subclass with the naming convention '_[traitname]_changed'. Thus,
493 493 to create static handler for the trait 'a', create the method
494 494 _a_changed(self, name, old, new) (fewer arguments can be used, see
495 495 below).
496 496
497 497 Parameters
498 498 ----------
499 499 handler : callable
500 500 A callable that is called when a trait changes. Its
501 501 signature can be handler(), handler(name), handler(name, new)
502 502 or handler(name, old, new).
503 503 name : list, str, None
504 504 If None, the handler will apply to all traits. If a list
505 505 of str, handler will apply to all names in the list. If a
506 506 str, the handler will apply just to that name.
507 507 remove : bool
508 508 If False (the default), then install the handler. If True
509 509 then unintall it.
510 510 """
511 511 if remove:
512 512 names = parse_notifier_name(name)
513 513 for n in names:
514 514 self._remove_notifiers(handler, n)
515 515 else:
516 516 names = parse_notifier_name(name)
517 517 for n in names:
518 518 self._add_notifiers(handler, n)
519 519
520 520 @classmethod
521 521 def class_trait_names(cls, **metadata):
522 522 """Get a list of all the names of this classes traits.
523 523
524 524 This method is just like the :meth:`trait_names` method, but is unbound.
525 525 """
526 526 return cls.class_traits(**metadata).keys()
527 527
528 528 @classmethod
529 529 def class_traits(cls, **metadata):
530 530 """Get a list of all the traits of this class.
531 531
532 532 This method is just like the :meth:`traits` method, but is unbound.
533 533
534 534 The TraitTypes returned don't know anything about the values
535 535 that the various HasTrait's instances are holding.
536 536
537 537 This follows the same algorithm as traits does and does not allow
538 538 for any simple way of specifying merely that a metadata name
539 539 exists, but has any value. This is because get_metadata returns
540 540 None if a metadata key doesn't exist.
541 541 """
542 542 traits = dict([memb for memb in getmembers(cls) if \
543 543 isinstance(memb[1], TraitType)])
544 544
545 545 if len(metadata) == 0:
546 546 return traits
547 547
548 548 for meta_name, meta_eval in metadata.items():
549 549 if type(meta_eval) is not FunctionType:
550 550 metadata[meta_name] = _SimpleTest(meta_eval)
551 551
552 552 result = {}
553 553 for name, trait in traits.items():
554 554 for meta_name, meta_eval in metadata.items():
555 555 if not meta_eval(trait.get_metadata(meta_name)):
556 556 break
557 557 else:
558 558 result[name] = trait
559 559
560 560 return result
561 561
562 562 def trait_names(self, **metadata):
563 563 """Get a list of all the names of this classes traits."""
564 564 return self.traits(**metadata).keys()
565 565
566 566 def traits(self, **metadata):
567 567 """Get a list of all the traits of this class.
568 568
569 569 The TraitTypes returned don't know anything about the values
570 570 that the various HasTrait's instances are holding.
571 571
572 572 This follows the same algorithm as traits does and does not allow
573 573 for any simple way of specifying merely that a metadata name
574 574 exists, but has any value. This is because get_metadata returns
575 575 None if a metadata key doesn't exist.
576 576 """
577 577 traits = dict([memb for memb in getmembers(self.__class__) if \
578 578 isinstance(memb[1], TraitType)])
579 579
580 580 if len(metadata) == 0:
581 581 return traits
582 582
583 583 for meta_name, meta_eval in metadata.items():
584 584 if type(meta_eval) is not FunctionType:
585 585 metadata[meta_name] = _SimpleTest(meta_eval)
586 586
587 587 result = {}
588 588 for name, trait in traits.items():
589 589 for meta_name, meta_eval in metadata.items():
590 590 if not meta_eval(trait.get_metadata(meta_name)):
591 591 break
592 592 else:
593 593 result[name] = trait
594 594
595 595 return result
596 596
597 597 def trait_metadata(self, traitname, key):
598 598 """Get metadata values for trait by key."""
599 599 try:
600 600 trait = getattr(self.__class__, traitname)
601 601 except AttributeError:
602 602 raise TraitError("Class %s does not have a trait named %s" %
603 603 (self.__class__.__name__, traitname))
604 604 else:
605 605 return trait.get_metadata(key)
606 606
607 607 #-----------------------------------------------------------------------------
608 608 # Actual TraitTypes implementations/subclasses
609 609 #-----------------------------------------------------------------------------
610 610
611 611 #-----------------------------------------------------------------------------
612 612 # TraitTypes subclasses for handling classes and instances of classes
613 613 #-----------------------------------------------------------------------------
614 614
615 615
616 616 class ClassBasedTraitType(TraitType):
617 617 """A trait with error reporting for Type, Instance and This."""
618 618
619 619 def error(self, obj, value):
620 620 kind = type(value)
621 621 if (not py3compat.PY3) and kind is InstanceType:
622 622 msg = 'class %s' % value.__class__.__name__
623 623 else:
624 624 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
625 625
626 626 if obj is not None:
627 627 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
628 628 % (self.name, class_of(obj),
629 629 self.info(), msg)
630 630 else:
631 631 e = "The '%s' trait must be %s, but a value of %r was specified." \
632 632 % (self.name, self.info(), msg)
633 633
634 634 raise TraitError(e)
635 635
636 636
637 637 class Type(ClassBasedTraitType):
638 638 """A trait whose value must be a subclass of a specified class."""
639 639
640 640 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
641 641 """Construct a Type trait
642 642
643 643 A Type trait specifies that its values must be subclasses of
644 644 a particular class.
645 645
646 646 If only ``default_value`` is given, it is used for the ``klass`` as
647 647 well.
648 648
649 649 Parameters
650 650 ----------
651 651 default_value : class, str or None
652 652 The default value must be a subclass of klass. If an str,
653 653 the str must be a fully specified class name, like 'foo.bar.Bah'.
654 654 The string is resolved into real class, when the parent
655 655 :class:`HasTraits` class is instantiated.
656 656 klass : class, str, None
657 657 Values of this trait must be a subclass of klass. The klass
658 658 may be specified in a string like: 'foo.bar.MyClass'.
659 659 The string is resolved into real class, when the parent
660 660 :class:`HasTraits` class is instantiated.
661 661 allow_none : boolean
662 662 Indicates whether None is allowed as an assignable value. Even if
663 663 ``False``, the default value may be ``None``.
664 664 """
665 665 if default_value is None:
666 666 if klass is None:
667 667 klass = object
668 668 elif klass is None:
669 669 klass = default_value
670 670
671 671 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
672 672 raise TraitError("A Type trait must specify a class.")
673 673
674 674 self.klass = klass
675 675 self._allow_none = allow_none
676 676
677 677 super(Type, self).__init__(default_value, **metadata)
678 678
679 679 def validate(self, obj, value):
680 680 """Validates that the value is a valid object instance."""
681 681 try:
682 682 if issubclass(value, self.klass):
683 683 return value
684 684 except:
685 685 if (value is None) and (self._allow_none):
686 686 return value
687 687
688 688 self.error(obj, value)
689 689
690 690 def info(self):
691 691 """ Returns a description of the trait."""
692 692 if isinstance(self.klass, basestring):
693 693 klass = self.klass
694 694 else:
695 695 klass = self.klass.__name__
696 696 result = 'a subclass of ' + klass
697 697 if self._allow_none:
698 698 return result + ' or None'
699 699 return result
700 700
701 701 def instance_init(self, obj):
702 702 self._resolve_classes()
703 703 super(Type, self).instance_init(obj)
704 704
705 705 def _resolve_classes(self):
706 706 if isinstance(self.klass, basestring):
707 707 self.klass = import_item(self.klass)
708 708 if isinstance(self.default_value, basestring):
709 709 self.default_value = import_item(self.default_value)
710 710
711 711 def get_default_value(self):
712 712 return self.default_value
713 713
714 714
715 715 class DefaultValueGenerator(object):
716 716 """A class for generating new default value instances."""
717 717
718 718 def __init__(self, *args, **kw):
719 719 self.args = args
720 720 self.kw = kw
721 721
722 722 def generate(self, klass):
723 723 return klass(*self.args, **self.kw)
724 724
725 725
726 726 class Instance(ClassBasedTraitType):
727 727 """A trait whose value must be an instance of a specified class.
728 728
729 729 The value can also be an instance of a subclass of the specified class.
730 730 """
731 731
732 732 def __init__(self, klass=None, args=None, kw=None,
733 733 allow_none=True, **metadata ):
734 734 """Construct an Instance trait.
735 735
736 736 This trait allows values that are instances of a particular
737 737 class or its sublclasses. Our implementation is quite different
738 738 from that of enthough.traits as we don't allow instances to be used
739 739 for klass and we handle the ``args`` and ``kw`` arguments differently.
740 740
741 741 Parameters
742 742 ----------
743 743 klass : class, str
744 744 The class that forms the basis for the trait. Class names
745 745 can also be specified as strings, like 'foo.bar.Bar'.
746 746 args : tuple
747 747 Positional arguments for generating the default value.
748 748 kw : dict
749 749 Keyword arguments for generating the default value.
750 750 allow_none : bool
751 751 Indicates whether None is allowed as a value.
752 752
753 753 Default Value
754 754 -------------
755 755 If both ``args`` and ``kw`` are None, then the default value is None.
756 756 If ``args`` is a tuple and ``kw`` is a dict, then the default is
757 757 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
758 758 not (but not both), None is replace by ``()`` or ``{}``.
759 759 """
760 760
761 761 self._allow_none = allow_none
762 762
763 763 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
764 764 raise TraitError('The klass argument must be a class'
765 765 ' you gave: %r' % klass)
766 766 self.klass = klass
767 767
768 768 # self.klass is a class, so handle default_value
769 769 if args is None and kw is None:
770 770 default_value = None
771 771 else:
772 772 if args is None:
773 773 # kw is not None
774 774 args = ()
775 775 elif kw is None:
776 776 # args is not None
777 777 kw = {}
778 778
779 779 if not isinstance(kw, dict):
780 780 raise TraitError("The 'kw' argument must be a dict or None.")
781 781 if not isinstance(args, tuple):
782 782 raise TraitError("The 'args' argument must be a tuple or None.")
783 783
784 784 default_value = DefaultValueGenerator(*args, **kw)
785 785
786 786 super(Instance, self).__init__(default_value, **metadata)
787 787
788 788 def validate(self, obj, value):
789 789 if value is None:
790 790 if self._allow_none:
791 791 return value
792 792 self.error(obj, value)
793 793
794 794 if isinstance(value, self.klass):
795 795 return value
796 796 else:
797 797 self.error(obj, value)
798 798
799 799 def info(self):
800 800 if isinstance(self.klass, basestring):
801 801 klass = self.klass
802 802 else:
803 803 klass = self.klass.__name__
804 804 result = class_of(klass)
805 805 if self._allow_none:
806 806 return result + ' or None'
807 807
808 808 return result
809 809
810 810 def instance_init(self, obj):
811 811 self._resolve_classes()
812 812 super(Instance, self).instance_init(obj)
813 813
814 814 def _resolve_classes(self):
815 815 if isinstance(self.klass, basestring):
816 816 self.klass = import_item(self.klass)
817 817
818 818 def get_default_value(self):
819 819 """Instantiate a default value instance.
820 820
821 821 This is called when the containing HasTraits classes'
822 822 :meth:`__new__` method is called to ensure that a unique instance
823 823 is created for each HasTraits instance.
824 824 """
825 825 dv = self.default_value
826 826 if isinstance(dv, DefaultValueGenerator):
827 827 return dv.generate(self.klass)
828 828 else:
829 829 return dv
830 830
831 831
832 832 class This(ClassBasedTraitType):
833 833 """A trait for instances of the class containing this trait.
834 834
835 835 Because how how and when class bodies are executed, the ``This``
836 836 trait can only have a default value of None. This, and because we
837 837 always validate default values, ``allow_none`` is *always* true.
838 838 """
839 839
840 840 info_text = 'an instance of the same type as the receiver or None'
841 841
842 842 def __init__(self, **metadata):
843 843 super(This, self).__init__(None, **metadata)
844 844
845 845 def validate(self, obj, value):
846 846 # What if value is a superclass of obj.__class__? This is
847 847 # complicated if it was the superclass that defined the This
848 848 # trait.
849 849 if isinstance(value, self.this_class) or (value is None):
850 850 return value
851 851 else:
852 852 self.error(obj, value)
853 853
854 854
855 855 #-----------------------------------------------------------------------------
856 856 # Basic TraitTypes implementations/subclasses
857 857 #-----------------------------------------------------------------------------
858 858
859 859
860 860 class Any(TraitType):
861 861 default_value = None
862 862 info_text = 'any value'
863 863
864 864
865 865 class Int(TraitType):
866 866 """A integer trait."""
867 867
868 868 default_value = 0
869 869 info_text = 'an integer'
870 870
871 871 def validate(self, obj, value):
872 872 if isinstance(value, int):
873 873 return value
874 874 self.error(obj, value)
875 875
876 876 class CInt(Int):
877 877 """A casting version of the int trait."""
878 878
879 879 def validate(self, obj, value):
880 880 try:
881 881 return int(value)
882 882 except:
883 883 self.error(obj, value)
884 884
885 if not py3compat.PY3:
885 if py3compat.PY3:
886 Long, CLong = Int, CInt
887 else:
886 888 class Long(TraitType):
887 889 """A long integer trait."""
888 890
889 891 default_value = 0L
890 892 info_text = 'a long'
891 893
892 894 def validate(self, obj, value):
893 895 if isinstance(value, long):
894 896 return value
895 897 if isinstance(value, int):
896 898 return long(value)
897 899 self.error(obj, value)
898 900
899 901
900 902 class CLong(Long):
901 903 """A casting version of the long integer trait."""
902 904
903 905 def validate(self, obj, value):
904 906 try:
905 907 return long(value)
906 908 except:
907 909 self.error(obj, value)
908 910
909 911
910 912 class Float(TraitType):
911 913 """A float trait."""
912 914
913 915 default_value = 0.0
914 916 info_text = 'a float'
915 917
916 918 def validate(self, obj, value):
917 919 if isinstance(value, float):
918 920 return value
919 921 if isinstance(value, int):
920 922 return float(value)
921 923 self.error(obj, value)
922 924
923 925
924 926 class CFloat(Float):
925 927 """A casting version of the float trait."""
926 928
927 929 def validate(self, obj, value):
928 930 try:
929 931 return float(value)
930 932 except:
931 933 self.error(obj, value)
932 934
933 935 class Complex(TraitType):
934 936 """A trait for complex numbers."""
935 937
936 938 default_value = 0.0 + 0.0j
937 939 info_text = 'a complex number'
938 940
939 941 def validate(self, obj, value):
940 942 if isinstance(value, complex):
941 943 return value
942 944 if isinstance(value, (float, int)):
943 945 return complex(value)
944 946 self.error(obj, value)
945 947
946 948
947 949 class CComplex(Complex):
948 950 """A casting version of the complex number trait."""
949 951
950 952 def validate (self, obj, value):
951 953 try:
952 954 return complex(value)
953 955 except:
954 956 self.error(obj, value)
955 957
956 958 # We should always be explicit about whether we're using bytes or unicode, both
957 959 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
958 960 # we don't have a Str type.
959 961 class Bytes(TraitType):
960 962 """A trait for byte strings."""
961 963
962 964 default_value = ''
963 965 info_text = 'a string'
964 966
965 967 def validate(self, obj, value):
966 968 if isinstance(value, bytes):
967 969 return value
968 970 self.error(obj, value)
969 971
970 972
971 973 class CBytes(Bytes):
972 974 """A casting version of the byte string trait."""
973 975
974 976 def validate(self, obj, value):
975 977 try:
976 978 return bytes(value)
977 979 except:
978 980 self.error(obj, value)
979 981
980 982
981 983 class Unicode(TraitType):
982 984 """A trait for unicode strings."""
983 985
984 986 default_value = u''
985 987 info_text = 'a unicode string'
986 988
987 989 def validate(self, obj, value):
988 990 if isinstance(value, unicode):
989 991 return value
990 992 if isinstance(value, bytes):
991 993 return unicode(value)
992 994 self.error(obj, value)
993 995
994 996
995 997 class CUnicode(Unicode):
996 998 """A casting version of the unicode trait."""
997 999
998 1000 def validate(self, obj, value):
999 1001 try:
1000 1002 return unicode(value)
1001 1003 except:
1002 1004 self.error(obj, value)
1003 1005
1004 1006
1005 1007 class ObjectName(TraitType):
1006 1008 """A string holding a valid object name in this version of Python.
1007 1009
1008 1010 This does not check that the name exists in any scope."""
1009 1011 info_text = "a valid object identifier in Python"
1010 1012
1011 1013 if py3compat.PY3:
1012 1014 # Python 3:
1013 1015 coerce_str = staticmethod(lambda _,s: s)
1014 1016
1015 1017 else:
1016 1018 # Python 2:
1017 1019 def coerce_str(self, obj, value):
1018 1020 "In Python 2, coerce ascii-only unicode to str"
1019 1021 if isinstance(value, unicode):
1020 1022 try:
1021 1023 return str(value)
1022 1024 except UnicodeEncodeError:
1023 1025 self.error(obj, value)
1024 1026 return value
1025 1027
1026 1028 def validate(self, obj, value):
1027 1029 value = self.coerce_str(obj, value)
1028 1030
1029 1031 if isinstance(value, str) and py3compat.isidentifier(value):
1030 1032 return value
1031 1033 self.error(obj, value)
1032 1034
1033 1035 class DottedObjectName(ObjectName):
1034 1036 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1035 1037 def validate(self, obj, value):
1036 1038 value = self.coerce_str(obj, value)
1037 1039
1038 1040 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1039 1041 return value
1040 1042 self.error(obj, value)
1041 1043
1042 1044
1043 1045 class Bool(TraitType):
1044 1046 """A boolean (True, False) trait."""
1045 1047
1046 1048 default_value = False
1047 1049 info_text = 'a boolean'
1048 1050
1049 1051 def validate(self, obj, value):
1050 1052 if isinstance(value, bool):
1051 1053 return value
1052 1054 self.error(obj, value)
1053 1055
1054 1056
1055 1057 class CBool(Bool):
1056 1058 """A casting version of the boolean trait."""
1057 1059
1058 1060 def validate(self, obj, value):
1059 1061 try:
1060 1062 return bool(value)
1061 1063 except:
1062 1064 self.error(obj, value)
1063 1065
1064 1066
1065 1067 class Enum(TraitType):
1066 1068 """An enum that whose value must be in a given sequence."""
1067 1069
1068 1070 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1069 1071 self.values = values
1070 1072 self._allow_none = allow_none
1071 1073 super(Enum, self).__init__(default_value, **metadata)
1072 1074
1073 1075 def validate(self, obj, value):
1074 1076 if value is None:
1075 1077 if self._allow_none:
1076 1078 return value
1077 1079
1078 1080 if value in self.values:
1079 1081 return value
1080 1082 self.error(obj, value)
1081 1083
1082 1084 def info(self):
1083 1085 """ Returns a description of the trait."""
1084 1086 result = 'any of ' + repr(self.values)
1085 1087 if self._allow_none:
1086 1088 return result + ' or None'
1087 1089 return result
1088 1090
1089 1091 class CaselessStrEnum(Enum):
1090 1092 """An enum of strings that are caseless in validate."""
1091 1093
1092 1094 def validate(self, obj, value):
1093 1095 if value is None:
1094 1096 if self._allow_none:
1095 1097 return value
1096 1098
1097 1099 if not isinstance(value, basestring):
1098 1100 self.error(obj, value)
1099 1101
1100 1102 for v in self.values:
1101 1103 if v.lower() == value.lower():
1102 1104 return v
1103 1105 self.error(obj, value)
1104 1106
1105 1107 class Container(Instance):
1106 1108 """An instance of a container (list, set, etc.)
1107 1109
1108 1110 To be subclassed by overriding klass.
1109 1111 """
1110 1112 klass = None
1111 1113 _valid_defaults = SequenceTypes
1112 1114 _trait = None
1113 1115
1114 1116 def __init__(self, trait=None, default_value=None, allow_none=True,
1115 1117 **metadata):
1116 1118 """Create a container trait type from a list, set, or tuple.
1117 1119
1118 1120 The default value is created by doing ``List(default_value)``,
1119 1121 which creates a copy of the ``default_value``.
1120 1122
1121 1123 ``trait`` can be specified, which restricts the type of elements
1122 1124 in the container to that TraitType.
1123 1125
1124 1126 If only one arg is given and it is not a Trait, it is taken as
1125 1127 ``default_value``:
1126 1128
1127 1129 ``c = List([1,2,3])``
1128 1130
1129 1131 Parameters
1130 1132 ----------
1131 1133
1132 1134 trait : TraitType [ optional ]
1133 1135 the type for restricting the contents of the Container. If unspecified,
1134 1136 types are not checked.
1135 1137
1136 1138 default_value : SequenceType [ optional ]
1137 1139 The default value for the Trait. Must be list/tuple/set, and
1138 1140 will be cast to the container type.
1139 1141
1140 1142 allow_none : Bool [ default True ]
1141 1143 Whether to allow the value to be None
1142 1144
1143 1145 **metadata : any
1144 1146 further keys for extensions to the Trait (e.g. config)
1145 1147
1146 1148 """
1147 1149 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1148 1150
1149 1151 # allow List([values]):
1150 1152 if default_value is None and not istrait(trait):
1151 1153 default_value = trait
1152 1154 trait = None
1153 1155
1154 1156 if default_value is None:
1155 1157 args = ()
1156 1158 elif isinstance(default_value, self._valid_defaults):
1157 1159 args = (default_value,)
1158 1160 else:
1159 1161 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1160 1162
1161 1163 if istrait(trait):
1162 1164 self._trait = trait()
1163 1165 self._trait.name = 'element'
1164 1166 elif trait is not None:
1165 1167 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1166 1168
1167 1169 super(Container,self).__init__(klass=self.klass, args=args,
1168 1170 allow_none=allow_none, **metadata)
1169 1171
1170 1172 def element_error(self, obj, element, validator):
1171 1173 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1172 1174 % (self.name, class_of(obj), validator.info(), repr_type(element))
1173 1175 raise TraitError(e)
1174 1176
1175 1177 def validate(self, obj, value):
1176 1178 value = super(Container, self).validate(obj, value)
1177 1179 if value is None:
1178 1180 return value
1179 1181
1180 1182 value = self.validate_elements(obj, value)
1181 1183
1182 1184 return value
1183 1185
1184 1186 def validate_elements(self, obj, value):
1185 1187 validated = []
1186 1188 if self._trait is None or isinstance(self._trait, Any):
1187 1189 return value
1188 1190 for v in value:
1189 1191 try:
1190 1192 v = self._trait.validate(obj, v)
1191 1193 except TraitError:
1192 1194 self.element_error(obj, v, self._trait)
1193 1195 else:
1194 1196 validated.append(v)
1195 1197 return self.klass(validated)
1196 1198
1197 1199
1198 1200 class List(Container):
1199 1201 """An instance of a Python list."""
1200 1202 klass = list
1201 1203
1202 1204 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1203 1205 allow_none=True, **metadata):
1204 1206 """Create a List trait type from a list, set, or tuple.
1205 1207
1206 1208 The default value is created by doing ``List(default_value)``,
1207 1209 which creates a copy of the ``default_value``.
1208 1210
1209 1211 ``trait`` can be specified, which restricts the type of elements
1210 1212 in the container to that TraitType.
1211 1213
1212 1214 If only one arg is given and it is not a Trait, it is taken as
1213 1215 ``default_value``:
1214 1216
1215 1217 ``c = List([1,2,3])``
1216 1218
1217 1219 Parameters
1218 1220 ----------
1219 1221
1220 1222 trait : TraitType [ optional ]
1221 1223 the type for restricting the contents of the Container. If unspecified,
1222 1224 types are not checked.
1223 1225
1224 1226 default_value : SequenceType [ optional ]
1225 1227 The default value for the Trait. Must be list/tuple/set, and
1226 1228 will be cast to the container type.
1227 1229
1228 1230 minlen : Int [ default 0 ]
1229 1231 The minimum length of the input list
1230 1232
1231 1233 maxlen : Int [ default sys.maxint ]
1232 1234 The maximum length of the input list
1233 1235
1234 1236 allow_none : Bool [ default True ]
1235 1237 Whether to allow the value to be None
1236 1238
1237 1239 **metadata : any
1238 1240 further keys for extensions to the Trait (e.g. config)
1239 1241
1240 1242 """
1241 1243 self._minlen = minlen
1242 1244 self._maxlen = maxlen
1243 1245 super(List, self).__init__(trait=trait, default_value=default_value,
1244 1246 allow_none=allow_none, **metadata)
1245 1247
1246 1248 def length_error(self, obj, value):
1247 1249 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1248 1250 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1249 1251 raise TraitError(e)
1250 1252
1251 1253 def validate_elements(self, obj, value):
1252 1254 length = len(value)
1253 1255 if length < self._minlen or length > self._maxlen:
1254 1256 self.length_error(obj, value)
1255 1257
1256 1258 return super(List, self).validate_elements(obj, value)
1257 1259
1258 1260
1259 1261 class Set(Container):
1260 1262 """An instance of a Python set."""
1261 1263 klass = set
1262 1264
1263 1265 class Tuple(Container):
1264 1266 """An instance of a Python tuple."""
1265 1267 klass = tuple
1266 1268
1267 1269 def __init__(self, *traits, **metadata):
1268 1270 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1269 1271
1270 1272 Create a tuple from a list, set, or tuple.
1271 1273
1272 1274 Create a fixed-type tuple with Traits:
1273 1275
1274 1276 ``t = Tuple(Int, Str, CStr)``
1275 1277
1276 1278 would be length 3, with Int,Str,CStr for each element.
1277 1279
1278 1280 If only one arg is given and it is not a Trait, it is taken as
1279 1281 default_value:
1280 1282
1281 1283 ``t = Tuple((1,2,3))``
1282 1284
1283 1285 Otherwise, ``default_value`` *must* be specified by keyword.
1284 1286
1285 1287 Parameters
1286 1288 ----------
1287 1289
1288 1290 *traits : TraitTypes [ optional ]
1289 1291 the tsype for restricting the contents of the Tuple. If unspecified,
1290 1292 types are not checked. If specified, then each positional argument
1291 1293 corresponds to an element of the tuple. Tuples defined with traits
1292 1294 are of fixed length.
1293 1295
1294 1296 default_value : SequenceType [ optional ]
1295 1297 The default value for the Tuple. Must be list/tuple/set, and
1296 1298 will be cast to a tuple. If `traits` are specified, the
1297 1299 `default_value` must conform to the shape and type they specify.
1298 1300
1299 1301 allow_none : Bool [ default True ]
1300 1302 Whether to allow the value to be None
1301 1303
1302 1304 **metadata : any
1303 1305 further keys for extensions to the Trait (e.g. config)
1304 1306
1305 1307 """
1306 1308 default_value = metadata.pop('default_value', None)
1307 1309 allow_none = metadata.pop('allow_none', True)
1308 1310
1309 1311 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1310 1312
1311 1313 # allow Tuple((values,)):
1312 1314 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1313 1315 default_value = traits[0]
1314 1316 traits = ()
1315 1317
1316 1318 if default_value is None:
1317 1319 args = ()
1318 1320 elif isinstance(default_value, self._valid_defaults):
1319 1321 args = (default_value,)
1320 1322 else:
1321 1323 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1322 1324
1323 1325 self._traits = []
1324 1326 for trait in traits:
1325 1327 t = trait()
1326 1328 t.name = 'element'
1327 1329 self._traits.append(t)
1328 1330
1329 1331 if self._traits and default_value is None:
1330 1332 # don't allow default to be an empty container if length is specified
1331 1333 args = None
1332 1334 super(Container,self).__init__(klass=self.klass, args=args,
1333 1335 allow_none=allow_none, **metadata)
1334 1336
1335 1337 def validate_elements(self, obj, value):
1336 1338 if not self._traits:
1337 1339 # nothing to validate
1338 1340 return value
1339 1341 if len(value) != len(self._traits):
1340 1342 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1341 1343 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1342 1344 raise TraitError(e)
1343 1345
1344 1346 validated = []
1345 1347 for t,v in zip(self._traits, value):
1346 1348 try:
1347 1349 v = t.validate(obj, v)
1348 1350 except TraitError:
1349 1351 self.element_error(obj, v, t)
1350 1352 else:
1351 1353 validated.append(v)
1352 1354 return tuple(validated)
1353 1355
1354 1356
1355 1357 class Dict(Instance):
1356 1358 """An instance of a Python dict."""
1357 1359
1358 1360 def __init__(self, default_value=None, allow_none=True, **metadata):
1359 1361 """Create a dict trait type from a dict.
1360 1362
1361 1363 The default value is created by doing ``dict(default_value)``,
1362 1364 which creates a copy of the ``default_value``.
1363 1365 """
1364 1366 if default_value is None:
1365 1367 args = ((),)
1366 1368 elif isinstance(default_value, dict):
1367 1369 args = (default_value,)
1368 1370 elif isinstance(default_value, SequenceTypes):
1369 1371 args = (default_value,)
1370 1372 else:
1371 1373 raise TypeError('default value of Dict was %s' % default_value)
1372 1374
1373 1375 super(Dict,self).__init__(klass=dict, args=args,
1374 1376 allow_none=allow_none, **metadata)
1375 1377
1376 1378 class TCPAddress(TraitType):
1377 1379 """A trait for an (ip, port) tuple.
1378 1380
1379 1381 This allows for both IPv4 IP addresses as well as hostnames.
1380 1382 """
1381 1383
1382 1384 default_value = ('127.0.0.1', 0)
1383 1385 info_text = 'an (ip, port) tuple'
1384 1386
1385 1387 def validate(self, obj, value):
1386 1388 if isinstance(value, tuple):
1387 1389 if len(value) == 2:
1388 1390 if isinstance(value[0], basestring) and isinstance(value[1], int):
1389 1391 port = value[1]
1390 1392 if port >= 0 and port <= 65535:
1391 1393 return value
1392 1394 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now