##// END OF EJS Templates
Various Python 3 fixes in IPython.utils
Thomas Kluyver -
Show More
@@ -1,161 +1,164
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010 The IPython Development Team
4 # Copyright (C) 2010 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import re
14 import re
15 import sys
15 import sys
16 import types
16 import types
17 from datetime import datetime
17 from datetime import datetime
18
18
19 from IPython.utils import py3compat
20 next_attr_name = '__next__' if py3compat.PY3 else 'next'
21
19 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
20 # Globals and constants
23 # Globals and constants
21 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
22
25
23 # timestamp formats
26 # timestamp formats
24 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
27 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
25 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
28 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
26
29
27 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
28 # Classes and functions
31 # Classes and functions
29 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
30
33
31 def rekey(dikt):
34 def rekey(dikt):
32 """Rekey a dict that has been forced to use str keys where there should be
35 """Rekey a dict that has been forced to use str keys where there should be
33 ints by json."""
36 ints by json."""
34 for k in dikt.iterkeys():
37 for k in dikt.iterkeys():
35 if isinstance(k, basestring):
38 if isinstance(k, basestring):
36 ik=fk=None
39 ik=fk=None
37 try:
40 try:
38 ik = int(k)
41 ik = int(k)
39 except ValueError:
42 except ValueError:
40 try:
43 try:
41 fk = float(k)
44 fk = float(k)
42 except ValueError:
45 except ValueError:
43 continue
46 continue
44 if ik is not None:
47 if ik is not None:
45 nk = ik
48 nk = ik
46 else:
49 else:
47 nk = fk
50 nk = fk
48 if nk in dikt:
51 if nk in dikt:
49 raise KeyError("already have key %r"%nk)
52 raise KeyError("already have key %r"%nk)
50 dikt[nk] = dikt.pop(k)
53 dikt[nk] = dikt.pop(k)
51 return dikt
54 return dikt
52
55
53
56
54 def extract_dates(obj):
57 def extract_dates(obj):
55 """extract ISO8601 dates from unpacked JSON"""
58 """extract ISO8601 dates from unpacked JSON"""
56 if isinstance(obj, dict):
59 if isinstance(obj, dict):
57 obj = dict(obj) # don't clobber
60 obj = dict(obj) # don't clobber
58 for k,v in obj.iteritems():
61 for k,v in obj.iteritems():
59 obj[k] = extract_dates(v)
62 obj[k] = extract_dates(v)
60 elif isinstance(obj, (list, tuple)):
63 elif isinstance(obj, (list, tuple)):
61 obj = [ extract_dates(o) for o in obj ]
64 obj = [ extract_dates(o) for o in obj ]
62 elif isinstance(obj, basestring):
65 elif isinstance(obj, basestring):
63 if ISO8601_PAT.match(obj):
66 if ISO8601_PAT.match(obj):
64 obj = datetime.strptime(obj, ISO8601)
67 obj = datetime.strptime(obj, ISO8601)
65 return obj
68 return obj
66
69
67 def squash_dates(obj):
70 def squash_dates(obj):
68 """squash datetime objects into ISO8601 strings"""
71 """squash datetime objects into ISO8601 strings"""
69 if isinstance(obj, dict):
72 if isinstance(obj, dict):
70 obj = dict(obj) # don't clobber
73 obj = dict(obj) # don't clobber
71 for k,v in obj.iteritems():
74 for k,v in obj.iteritems():
72 obj[k] = squash_dates(v)
75 obj[k] = squash_dates(v)
73 elif isinstance(obj, (list, tuple)):
76 elif isinstance(obj, (list, tuple)):
74 obj = [ squash_dates(o) for o in obj ]
77 obj = [ squash_dates(o) for o in obj ]
75 elif isinstance(obj, datetime):
78 elif isinstance(obj, datetime):
76 obj = obj.strftime(ISO8601)
79 obj = obj.strftime(ISO8601)
77 return obj
80 return obj
78
81
79 def date_default(obj):
82 def date_default(obj):
80 """default function for packing datetime objects in JSON."""
83 """default function for packing datetime objects in JSON."""
81 if isinstance(obj, datetime):
84 if isinstance(obj, datetime):
82 return obj.strftime(ISO8601)
85 return obj.strftime(ISO8601)
83 else:
86 else:
84 raise TypeError("%r is not JSON serializable"%obj)
87 raise TypeError("%r is not JSON serializable"%obj)
85
88
86
89
87
90
88 def json_clean(obj):
91 def json_clean(obj):
89 """Clean an object to ensure it's safe to encode in JSON.
92 """Clean an object to ensure it's safe to encode in JSON.
90
93
91 Atomic, immutable objects are returned unmodified. Sets and tuples are
94 Atomic, immutable objects are returned unmodified. Sets and tuples are
92 converted to lists, lists are copied and dicts are also copied.
95 converted to lists, lists are copied and dicts are also copied.
93
96
94 Note: dicts whose keys could cause collisions upon encoding (such as a dict
97 Note: dicts whose keys could cause collisions upon encoding (such as a dict
95 with both the number 1 and the string '1' as keys) will cause a ValueError
98 with both the number 1 and the string '1' as keys) will cause a ValueError
96 to be raised.
99 to be raised.
97
100
98 Parameters
101 Parameters
99 ----------
102 ----------
100 obj : any python object
103 obj : any python object
101
104
102 Returns
105 Returns
103 -------
106 -------
104 out : object
107 out : object
105
108
106 A version of the input which will not cause an encoding error when
109 A version of the input which will not cause an encoding error when
107 encoded as JSON. Note that this function does not *encode* its inputs,
110 encoded as JSON. Note that this function does not *encode* its inputs,
108 it simply sanitizes it so that there will be no encoding errors later.
111 it simply sanitizes it so that there will be no encoding errors later.
109
112
110 Examples
113 Examples
111 --------
114 --------
112 >>> json_clean(4)
115 >>> json_clean(4)
113 4
116 4
114 >>> json_clean(range(10))
117 >>> json_clean(range(10))
115 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
118 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
116 >>> json_clean(dict(x=1, y=2))
119 >>> json_clean(dict(x=1, y=2))
117 {'y': 2, 'x': 1}
120 {'y': 2, 'x': 1}
118 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
121 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
119 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
122 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
120 >>> json_clean(True)
123 >>> json_clean(True)
121 True
124 True
122 """
125 """
123 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
126 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
124 # listed explicitly because bools pass as int instances
127 # listed explicitly because bools pass as int instances
125 atomic_ok = (unicode, int, float, types.NoneType)
128 atomic_ok = (unicode, int, float, types.NoneType)
126
129
127 # containers that we need to convert into lists
130 # containers that we need to convert into lists
128 container_to_list = (tuple, set, types.GeneratorType)
131 container_to_list = (tuple, set, types.GeneratorType)
129
132
130 if isinstance(obj, atomic_ok):
133 if isinstance(obj, atomic_ok):
131 return obj
134 return obj
132
135
133 if isinstance(obj, bytes):
136 if isinstance(obj, bytes):
134 return obj.decode(sys.getdefaultencoding(), 'replace')
137 return obj.decode(sys.getdefaultencoding(), 'replace')
135
138
136 if isinstance(obj, container_to_list) or (
139 if isinstance(obj, container_to_list) or (
137 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
140 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
138 obj = list(obj)
141 obj = list(obj)
139
142
140 if isinstance(obj, list):
143 if isinstance(obj, list):
141 return [json_clean(x) for x in obj]
144 return [json_clean(x) for x in obj]
142
145
143 if isinstance(obj, dict):
146 if isinstance(obj, dict):
144 # First, validate that the dict won't lose data in conversion due to
147 # First, validate that the dict won't lose data in conversion due to
145 # key collisions after stringification. This can happen with keys like
148 # key collisions after stringification. This can happen with keys like
146 # True and 'true' or 1 and '1', which collide in JSON.
149 # True and 'true' or 1 and '1', which collide in JSON.
147 nkeys = len(obj)
150 nkeys = len(obj)
148 nkeys_collapsed = len(set(map(str, obj)))
151 nkeys_collapsed = len(set(map(str, obj)))
149 if nkeys != nkeys_collapsed:
152 if nkeys != nkeys_collapsed:
150 raise ValueError('dict can not be safely converted to JSON: '
153 raise ValueError('dict can not be safely converted to JSON: '
151 'key collision would lead to dropped values')
154 'key collision would lead to dropped values')
152 # If all OK, proceed by making the new dict that will be json-safe
155 # If all OK, proceed by making the new dict that will be json-safe
153 out = {}
156 out = {}
154 for k,v in obj.iteritems():
157 for k,v in obj.iteritems():
155 out[str(k)] = json_clean(v)
158 out[str(k)] = json_clean(v)
156 return out
159 return out
157
160
158 # If we get here, we don't know how to handle the object, so we just get
161 # If we get here, we don't know how to handle the object, so we just get
159 # its repr and return that. This will catch lambdas, open sockets, class
162 # its repr and return that. This will catch lambdas, open sockets, class
160 # objects, and any other complicated contraption that json can't encode
163 # objects, and any other complicated contraption that json can't encode
161 return repr(obj)
164 return repr(obj)
@@ -1,147 +1,147
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with external processes.
3 Utilities for working with external processes.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 # Stdlib
18 # Stdlib
19 import os
19 import os
20 import sys
20 import sys
21 import shlex
21 import shlex
22
22
23 # Our own
23 # Our own
24 if sys.platform == 'win32':
24 if sys.platform == 'win32':
25 from ._process_win32 import _find_cmd, system, getoutput, AvoidUNCPath
25 from ._process_win32 import _find_cmd, system, getoutput, AvoidUNCPath
26 else:
26 else:
27 from ._process_posix import _find_cmd, system, getoutput
27 from ._process_posix import _find_cmd, system, getoutput
28
28
29 from ._process_common import getoutputerror
29 from ._process_common import getoutputerror
30 from IPython.utils import py3compat
30 from IPython.utils import py3compat
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Code
33 # Code
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36
36
37 class FindCmdError(Exception):
37 class FindCmdError(Exception):
38 pass
38 pass
39
39
40
40
41 def find_cmd(cmd):
41 def find_cmd(cmd):
42 """Find absolute path to executable cmd in a cross platform manner.
42 """Find absolute path to executable cmd in a cross platform manner.
43
43
44 This function tries to determine the full path to a command line program
44 This function tries to determine the full path to a command line program
45 using `which` on Unix/Linux/OS X and `win32api` on Windows. Most of the
45 using `which` on Unix/Linux/OS X and `win32api` on Windows. Most of the
46 time it will use the version that is first on the users `PATH`. If
46 time it will use the version that is first on the users `PATH`. If
47 cmd is `python` return `sys.executable`.
47 cmd is `python` return `sys.executable`.
48
48
49 Warning, don't use this to find IPython command line programs as there
49 Warning, don't use this to find IPython command line programs as there
50 is a risk you will find the wrong one. Instead find those using the
50 is a risk you will find the wrong one. Instead find those using the
51 following code and looking for the application itself::
51 following code and looking for the application itself::
52
52
53 from IPython.utils.path import get_ipython_module_path
53 from IPython.utils.path import get_ipython_module_path
54 from IPython.utils.process import pycmd2argv
54 from IPython.utils.process import pycmd2argv
55 argv = pycmd2argv(get_ipython_module_path('IPython.frontend.terminal.ipapp'))
55 argv = pycmd2argv(get_ipython_module_path('IPython.frontend.terminal.ipapp'))
56
56
57 Parameters
57 Parameters
58 ----------
58 ----------
59 cmd : str
59 cmd : str
60 The command line program to look for.
60 The command line program to look for.
61 """
61 """
62 if cmd == 'python':
62 if cmd == 'python':
63 return os.path.abspath(sys.executable)
63 return os.path.abspath(sys.executable)
64 try:
64 try:
65 path = _find_cmd(cmd).rstrip()
65 path = _find_cmd(cmd).rstrip()
66 except OSError:
66 except OSError:
67 raise FindCmdError('command could not be found: %s' % cmd)
67 raise FindCmdError('command could not be found: %s' % cmd)
68 # which returns empty if not found
68 # which returns empty if not found
69 if path == '':
69 if path == b'':
70 raise FindCmdError('command could not be found: %s' % cmd)
70 raise FindCmdError('command could not be found: %s' % cmd)
71 return os.path.abspath(path)
71 return os.path.abspath(path)
72
72
73
73
74 def pycmd2argv(cmd):
74 def pycmd2argv(cmd):
75 r"""Take the path of a python command and return a list (argv-style).
75 r"""Take the path of a python command and return a list (argv-style).
76
76
77 This only works on Python based command line programs and will find the
77 This only works on Python based command line programs and will find the
78 location of the ``python`` executable using ``sys.executable`` to make
78 location of the ``python`` executable using ``sys.executable`` to make
79 sure the right version is used.
79 sure the right version is used.
80
80
81 For a given path ``cmd``, this returns [cmd] if cmd's extension is .exe,
81 For a given path ``cmd``, this returns [cmd] if cmd's extension is .exe,
82 .com or .bat, and [, cmd] otherwise.
82 .com or .bat, and [, cmd] otherwise.
83
83
84 Parameters
84 Parameters
85 ----------
85 ----------
86 cmd : string
86 cmd : string
87 The path of the command.
87 The path of the command.
88
88
89 Returns
89 Returns
90 -------
90 -------
91 argv-style list.
91 argv-style list.
92 """
92 """
93 ext = os.path.splitext(cmd)[1]
93 ext = os.path.splitext(cmd)[1]
94 if ext in ['.exe', '.com', '.bat']:
94 if ext in ['.exe', '.com', '.bat']:
95 return [cmd]
95 return [cmd]
96 else:
96 else:
97 if sys.platform == 'win32':
97 if sys.platform == 'win32':
98 # The -u option here turns on unbuffered output, which is required
98 # The -u option here turns on unbuffered output, which is required
99 # on Win32 to prevent wierd conflict and problems with Twisted.
99 # on Win32 to prevent wierd conflict and problems with Twisted.
100 # Also, use sys.executable to make sure we are picking up the
100 # Also, use sys.executable to make sure we are picking up the
101 # right python exe.
101 # right python exe.
102 return [sys.executable, '-u', cmd]
102 return [sys.executable, '-u', cmd]
103 else:
103 else:
104 return [sys.executable, cmd]
104 return [sys.executable, cmd]
105
105
106
106
107 def arg_split(s, posix=False):
107 def arg_split(s, posix=False):
108 """Split a command line's arguments in a shell-like manner.
108 """Split a command line's arguments in a shell-like manner.
109
109
110 This is a modified version of the standard library's shlex.split()
110 This is a modified version of the standard library's shlex.split()
111 function, but with a default of posix=False for splitting, so that quotes
111 function, but with a default of posix=False for splitting, so that quotes
112 in inputs are respected."""
112 in inputs are respected."""
113
113
114 # Unfortunately, python's shlex module is buggy with unicode input:
114 # Unfortunately, python's shlex module is buggy with unicode input:
115 # http://bugs.python.org/issue1170
115 # http://bugs.python.org/issue1170
116 # At least encoding the input when it's unicode seems to help, but there
116 # At least encoding the input when it's unicode seems to help, but there
117 # may be more problems lurking. Apparently this is fixed in python3.
117 # may be more problems lurking. Apparently this is fixed in python3.
118 is_unicode = False
118 is_unicode = False
119 if (not py3compat.PY3) and isinstance(s, unicode):
119 if (not py3compat.PY3) and isinstance(s, unicode):
120 is_unicode = True
120 is_unicode = True
121 s = s.encode('utf-8')
121 s = s.encode('utf-8')
122 lex = shlex.shlex(s, posix=posix)
122 lex = shlex.shlex(s, posix=posix)
123 lex.whitespace_split = True
123 lex.whitespace_split = True
124 tokens = list(lex)
124 tokens = list(lex)
125 if is_unicode:
125 if is_unicode:
126 # Convert the tokens back to unicode.
126 # Convert the tokens back to unicode.
127 tokens = [x.decode('utf-8') for x in tokens]
127 tokens = [x.decode('utf-8') for x in tokens]
128 return tokens
128 return tokens
129
129
130
130
131 def abbrev_cwd():
131 def abbrev_cwd():
132 """ Return abbreviated version of cwd, e.g. d:mydir """
132 """ Return abbreviated version of cwd, e.g. d:mydir """
133 cwd = os.getcwdu().replace('\\','/')
133 cwd = os.getcwdu().replace('\\','/')
134 drivepart = ''
134 drivepart = ''
135 tail = cwd
135 tail = cwd
136 if sys.platform == 'win32':
136 if sys.platform == 'win32':
137 if len(cwd) < 4:
137 if len(cwd) < 4:
138 return cwd
138 return cwd
139 drivepart,tail = os.path.splitdrive(cwd)
139 drivepart,tail = os.path.splitdrive(cwd)
140
140
141
141
142 parts = tail.split('/')
142 parts = tail.split('/')
143 if len(parts) > 2:
143 if len(parts) > 2:
144 tail = '/'.join(parts[-2:])
144 tail = '/'.join(parts[-2:])
145
145
146 return (drivepart + (
146 return (drivepart + (
147 cwd == '/' and '/' or tail))
147 cwd == '/' and '/' or tail))
@@ -1,443 +1,450
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.path.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008 The IPython Development Team
5 # Copyright (C) 2008 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import with_statement
15 from __future__ import with_statement
16
16
17 import os
17 import os
18 import shutil
18 import shutil
19 import sys
19 import sys
20 import tempfile
20 import tempfile
21 import StringIO
21 import StringIO
22
22
23 from os.path import join, abspath, split
23 from os.path import join, abspath, split
24
24
25 import nose.tools as nt
25 import nose.tools as nt
26
26
27 from nose import with_setup
27 from nose import with_setup
28
28
29 import IPython
29 import IPython
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
32 from IPython.testing.tools import make_tempfile
32 from IPython.testing.tools import make_tempfile
33 from IPython.utils import path, io
33 from IPython.utils import path, io
34 from IPython.utils import py3compat
34
35
35 # Platform-dependent imports
36 # Platform-dependent imports
36 try:
37 try:
37 import _winreg as wreg
38 import _winreg as wreg
38 except ImportError:
39 except ImportError:
39 #Fake _winreg module on none windows platforms
40 #Fake _winreg module on none windows platforms
40 import new
41 import types
41 sys.modules["_winreg"] = new.module("_winreg")
42 wr_name = "winreg" if py3compat.PY3 else "_winreg"
43 sys.modules[wr_name] = types.ModuleType(wr_name)
42 import _winreg as wreg
44 import _winreg as wreg
43 #Add entries that needs to be stubbed by the testing code
45 #Add entries that needs to be stubbed by the testing code
44 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
46 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
45
47
48 try:
49 reload
50 except NameError: # Python 3
51 from imp import reload
52
46 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
47 # Globals
54 # Globals
48 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
49 env = os.environ
56 env = os.environ
50 TEST_FILE_PATH = split(abspath(__file__))[0]
57 TEST_FILE_PATH = split(abspath(__file__))[0]
51 TMP_TEST_DIR = tempfile.mkdtemp()
58 TMP_TEST_DIR = tempfile.mkdtemp()
52 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
59 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
53 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
60 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
54 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
61 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
55 #
62 #
56 # Setup/teardown functions/decorators
63 # Setup/teardown functions/decorators
57 #
64 #
58
65
59 def setup():
66 def setup():
60 """Setup testenvironment for the module:
67 """Setup testenvironment for the module:
61
68
62 - Adds dummy home dir tree
69 - Adds dummy home dir tree
63 """
70 """
64 # Do not mask exceptions here. In particular, catching WindowsError is a
71 # Do not mask exceptions here. In particular, catching WindowsError is a
65 # problem because that exception is only defined on Windows...
72 # problem because that exception is only defined on Windows...
66 os.makedirs(IP_TEST_DIR)
73 os.makedirs(IP_TEST_DIR)
67 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
74 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
68
75
69
76
70 def teardown():
77 def teardown():
71 """Teardown testenvironment for the module:
78 """Teardown testenvironment for the module:
72
79
73 - Remove dummy home dir tree
80 - Remove dummy home dir tree
74 """
81 """
75 # Note: we remove the parent test dir, which is the root of all test
82 # Note: we remove the parent test dir, which is the root of all test
76 # subdirs we may have created. Use shutil instead of os.removedirs, so
83 # subdirs we may have created. Use shutil instead of os.removedirs, so
77 # that non-empty directories are all recursively removed.
84 # that non-empty directories are all recursively removed.
78 shutil.rmtree(TMP_TEST_DIR)
85 shutil.rmtree(TMP_TEST_DIR)
79
86
80
87
81 def setup_environment():
88 def setup_environment():
82 """Setup testenvironment for some functions that are tested
89 """Setup testenvironment for some functions that are tested
83 in this module. In particular this functions stores attributes
90 in this module. In particular this functions stores attributes
84 and other things that we need to stub in some test functions.
91 and other things that we need to stub in some test functions.
85 This needs to be done on a function level and not module level because
92 This needs to be done on a function level and not module level because
86 each testfunction needs a pristine environment.
93 each testfunction needs a pristine environment.
87 """
94 """
88 global oldstuff, platformstuff
95 global oldstuff, platformstuff
89 oldstuff = (env.copy(), os.name, path.get_home_dir, IPython.__file__, os.getcwd())
96 oldstuff = (env.copy(), os.name, path.get_home_dir, IPython.__file__, os.getcwd())
90
97
91 if os.name == 'nt':
98 if os.name == 'nt':
92 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
99 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
93
100
94
101
95 def teardown_environment():
102 def teardown_environment():
96 """Restore things that were remebered by the setup_environment function
103 """Restore things that were remebered by the setup_environment function
97 """
104 """
98 (oldenv, os.name, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
105 (oldenv, os.name, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
99 os.chdir(old_wd)
106 os.chdir(old_wd)
100 reload(path)
107 reload(path)
101
108
102 for key in env.keys():
109 for key in env.keys():
103 if key not in oldenv:
110 if key not in oldenv:
104 del env[key]
111 del env[key]
105 env.update(oldenv)
112 env.update(oldenv)
106 if hasattr(sys, 'frozen'):
113 if hasattr(sys, 'frozen'):
107 del sys.frozen
114 del sys.frozen
108 if os.name == 'nt':
115 if os.name == 'nt':
109 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
116 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
110
117
111 # Build decorator that uses the setup_environment/setup_environment
118 # Build decorator that uses the setup_environment/setup_environment
112 with_environment = with_setup(setup_environment, teardown_environment)
119 with_environment = with_setup(setup_environment, teardown_environment)
113
120
114
121
115 @skip_if_not_win32
122 @skip_if_not_win32
116 @with_environment
123 @with_environment
117 def test_get_home_dir_1():
124 def test_get_home_dir_1():
118 """Testcase for py2exe logic, un-compressed lib
125 """Testcase for py2exe logic, un-compressed lib
119 """
126 """
120 sys.frozen = True
127 sys.frozen = True
121
128
122 #fake filename for IPython.__init__
129 #fake filename for IPython.__init__
123 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
130 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
124
131
125 home_dir = path.get_home_dir()
132 home_dir = path.get_home_dir()
126 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
133 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
127
134
128
135
129 @skip_if_not_win32
136 @skip_if_not_win32
130 @with_environment
137 @with_environment
131 def test_get_home_dir_2():
138 def test_get_home_dir_2():
132 """Testcase for py2exe logic, compressed lib
139 """Testcase for py2exe logic, compressed lib
133 """
140 """
134 sys.frozen = True
141 sys.frozen = True
135 #fake filename for IPython.__init__
142 #fake filename for IPython.__init__
136 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
143 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
137
144
138 home_dir = path.get_home_dir()
145 home_dir = path.get_home_dir()
139 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
146 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
140
147
141
148
142 @with_environment
149 @with_environment
143 @skip_win32
150 @skip_win32
144 def test_get_home_dir_3():
151 def test_get_home_dir_3():
145 """Testcase $HOME is set, then use its value as home directory."""
152 """Testcase $HOME is set, then use its value as home directory."""
146 env["HOME"] = HOME_TEST_DIR
153 env["HOME"] = HOME_TEST_DIR
147 home_dir = path.get_home_dir()
154 home_dir = path.get_home_dir()
148 nt.assert_equal(home_dir, env["HOME"])
155 nt.assert_equal(home_dir, env["HOME"])
149
156
150
157
151 @with_environment
158 @with_environment
152 @skip_win32
159 @skip_win32
153 def test_get_home_dir_4():
160 def test_get_home_dir_4():
154 """Testcase $HOME is not set, os=='posix'.
161 """Testcase $HOME is not set, os=='posix'.
155 This should fail with HomeDirError"""
162 This should fail with HomeDirError"""
156
163
157 os.name = 'posix'
164 os.name = 'posix'
158 if 'HOME' in env: del env['HOME']
165 if 'HOME' in env: del env['HOME']
159 nt.assert_raises(path.HomeDirError, path.get_home_dir)
166 nt.assert_raises(path.HomeDirError, path.get_home_dir)
160
167
161
168
162 @skip_if_not_win32
169 @skip_if_not_win32
163 @with_environment
170 @with_environment
164 def test_get_home_dir_5():
171 def test_get_home_dir_5():
165 """Using HOMEDRIVE + HOMEPATH, os=='nt'.
172 """Using HOMEDRIVE + HOMEPATH, os=='nt'.
166
173
167 HOMESHARE is missing.
174 HOMESHARE is missing.
168 """
175 """
169
176
170 os.name = 'nt'
177 os.name = 'nt'
171 env.pop('HOMESHARE', None)
178 env.pop('HOMESHARE', None)
172 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
179 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
173 home_dir = path.get_home_dir()
180 home_dir = path.get_home_dir()
174 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
181 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
175
182
176
183
177 @skip_if_not_win32
184 @skip_if_not_win32
178 @with_environment
185 @with_environment
179 def test_get_home_dir_6():
186 def test_get_home_dir_6():
180 """Using USERPROFILE, os=='nt'.
187 """Using USERPROFILE, os=='nt'.
181
188
182 HOMESHARE, HOMEDRIVE, HOMEPATH are missing.
189 HOMESHARE, HOMEDRIVE, HOMEPATH are missing.
183 """
190 """
184
191
185 os.name = 'nt'
192 os.name = 'nt'
186 env.pop('HOMESHARE', None)
193 env.pop('HOMESHARE', None)
187 env.pop('HOMEDRIVE', None)
194 env.pop('HOMEDRIVE', None)
188 env.pop('HOMEPATH', None)
195 env.pop('HOMEPATH', None)
189 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
196 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
190 home_dir = path.get_home_dir()
197 home_dir = path.get_home_dir()
191 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
198 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
192
199
193
200
194 @skip_if_not_win32
201 @skip_if_not_win32
195 @with_environment
202 @with_environment
196 def test_get_home_dir_7():
203 def test_get_home_dir_7():
197 """Using HOMESHARE, os=='nt'."""
204 """Using HOMESHARE, os=='nt'."""
198
205
199 os.name = 'nt'
206 os.name = 'nt'
200 env["HOMESHARE"] = abspath(HOME_TEST_DIR)
207 env["HOMESHARE"] = abspath(HOME_TEST_DIR)
201 home_dir = path.get_home_dir()
208 home_dir = path.get_home_dir()
202 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
209 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
203
210
204
211
205 # Should we stub wreg fully so we can run the test on all platforms?
212 # Should we stub wreg fully so we can run the test on all platforms?
206 @skip_if_not_win32
213 @skip_if_not_win32
207 @with_environment
214 @with_environment
208 def test_get_home_dir_8():
215 def test_get_home_dir_8():
209 """Using registry hack for 'My Documents', os=='nt'
216 """Using registry hack for 'My Documents', os=='nt'
210
217
211 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
218 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
212 """
219 """
213 os.name = 'nt'
220 os.name = 'nt'
214 # Remove from stub environment all keys that may be set
221 # Remove from stub environment all keys that may be set
215 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
222 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
216 env.pop(key, None)
223 env.pop(key, None)
217
224
218 #Stub windows registry functions
225 #Stub windows registry functions
219 def OpenKey(x, y):
226 def OpenKey(x, y):
220 class key:
227 class key:
221 def Close(self):
228 def Close(self):
222 pass
229 pass
223 return key()
230 return key()
224 def QueryValueEx(x, y):
231 def QueryValueEx(x, y):
225 return [abspath(HOME_TEST_DIR)]
232 return [abspath(HOME_TEST_DIR)]
226
233
227 wreg.OpenKey = OpenKey
234 wreg.OpenKey = OpenKey
228 wreg.QueryValueEx = QueryValueEx
235 wreg.QueryValueEx = QueryValueEx
229
236
230 home_dir = path.get_home_dir()
237 home_dir = path.get_home_dir()
231 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
238 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
232
239
233
240
234 @with_environment
241 @with_environment
235 def test_get_ipython_dir_1():
242 def test_get_ipython_dir_1():
236 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
243 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
237 env_ipdir = os.path.join("someplace", ".ipython")
244 env_ipdir = os.path.join("someplace", ".ipython")
238 path._writable_dir = lambda path: True
245 path._writable_dir = lambda path: True
239 env['IPYTHON_DIR'] = env_ipdir
246 env['IPYTHON_DIR'] = env_ipdir
240 ipdir = path.get_ipython_dir()
247 ipdir = path.get_ipython_dir()
241 nt.assert_equal(ipdir, env_ipdir)
248 nt.assert_equal(ipdir, env_ipdir)
242
249
243
250
244 @with_environment
251 @with_environment
245 def test_get_ipython_dir_2():
252 def test_get_ipython_dir_2():
246 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
253 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
247 path.get_home_dir = lambda : "someplace"
254 path.get_home_dir = lambda : "someplace"
248 path.get_xdg_dir = lambda : None
255 path.get_xdg_dir = lambda : None
249 path._writable_dir = lambda path: True
256 path._writable_dir = lambda path: True
250 os.name = "posix"
257 os.name = "posix"
251 env.pop('IPYTHON_DIR', None)
258 env.pop('IPYTHON_DIR', None)
252 env.pop('IPYTHONDIR', None)
259 env.pop('IPYTHONDIR', None)
253 env.pop('XDG_CONFIG_HOME', None)
260 env.pop('XDG_CONFIG_HOME', None)
254 ipdir = path.get_ipython_dir()
261 ipdir = path.get_ipython_dir()
255 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
262 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
256
263
257 @with_environment
264 @with_environment
258 def test_get_ipython_dir_3():
265 def test_get_ipython_dir_3():
259 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
266 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
260 path.get_home_dir = lambda : "someplace"
267 path.get_home_dir = lambda : "someplace"
261 path._writable_dir = lambda path: True
268 path._writable_dir = lambda path: True
262 os.name = "posix"
269 os.name = "posix"
263 env.pop('IPYTHON_DIR', None)
270 env.pop('IPYTHON_DIR', None)
264 env.pop('IPYTHONDIR', None)
271 env.pop('IPYTHONDIR', None)
265 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
272 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
266 ipdir = path.get_ipython_dir()
273 ipdir = path.get_ipython_dir()
267 nt.assert_equal(ipdir, os.path.join(XDG_TEST_DIR, "ipython"))
274 nt.assert_equal(ipdir, os.path.join(XDG_TEST_DIR, "ipython"))
268
275
269 @with_environment
276 @with_environment
270 def test_get_ipython_dir_4():
277 def test_get_ipython_dir_4():
271 """test_get_ipython_dir_4, use XDG if both exist."""
278 """test_get_ipython_dir_4, use XDG if both exist."""
272 path.get_home_dir = lambda : HOME_TEST_DIR
279 path.get_home_dir = lambda : HOME_TEST_DIR
273 os.name = "posix"
280 os.name = "posix"
274 env.pop('IPYTHON_DIR', None)
281 env.pop('IPYTHON_DIR', None)
275 env.pop('IPYTHONDIR', None)
282 env.pop('IPYTHONDIR', None)
276 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
283 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
277 xdg_ipdir = os.path.join(XDG_TEST_DIR, "ipython")
284 xdg_ipdir = os.path.join(XDG_TEST_DIR, "ipython")
278 ipdir = path.get_ipython_dir()
285 ipdir = path.get_ipython_dir()
279 nt.assert_equal(ipdir, xdg_ipdir)
286 nt.assert_equal(ipdir, xdg_ipdir)
280
287
281 @with_environment
288 @with_environment
282 def test_get_ipython_dir_5():
289 def test_get_ipython_dir_5():
283 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
290 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
284 path.get_home_dir = lambda : HOME_TEST_DIR
291 path.get_home_dir = lambda : HOME_TEST_DIR
285 os.name = "posix"
292 os.name = "posix"
286 env.pop('IPYTHON_DIR', None)
293 env.pop('IPYTHON_DIR', None)
287 env.pop('IPYTHONDIR', None)
294 env.pop('IPYTHONDIR', None)
288 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
295 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
289 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
296 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
290 ipdir = path.get_ipython_dir()
297 ipdir = path.get_ipython_dir()
291 nt.assert_equal(ipdir, IP_TEST_DIR)
298 nt.assert_equal(ipdir, IP_TEST_DIR)
292
299
293 @with_environment
300 @with_environment
294 def test_get_ipython_dir_6():
301 def test_get_ipython_dir_6():
295 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
302 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
296 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
303 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
297 os.mkdir(xdg)
304 os.mkdir(xdg)
298 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
305 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
299 path.get_home_dir = lambda : HOME_TEST_DIR
306 path.get_home_dir = lambda : HOME_TEST_DIR
300 path.get_xdg_dir = lambda : xdg
307 path.get_xdg_dir = lambda : xdg
301 os.name = "posix"
308 os.name = "posix"
302 env.pop('IPYTHON_DIR', None)
309 env.pop('IPYTHON_DIR', None)
303 env.pop('IPYTHONDIR', None)
310 env.pop('IPYTHONDIR', None)
304 env.pop('XDG_CONFIG_HOME', None)
311 env.pop('XDG_CONFIG_HOME', None)
305 xdg_ipdir = os.path.join(xdg, "ipython")
312 xdg_ipdir = os.path.join(xdg, "ipython")
306 ipdir = path.get_ipython_dir()
313 ipdir = path.get_ipython_dir()
307 nt.assert_equal(ipdir, xdg_ipdir)
314 nt.assert_equal(ipdir, xdg_ipdir)
308
315
309 @with_environment
316 @with_environment
310 def test_get_ipython_dir_7():
317 def test_get_ipython_dir_7():
311 """test_get_ipython_dir_7, test home directory expansion on IPYTHON_DIR"""
318 """test_get_ipython_dir_7, test home directory expansion on IPYTHON_DIR"""
312 path._writable_dir = lambda path: True
319 path._writable_dir = lambda path: True
313 home_dir = os.path.expanduser('~')
320 home_dir = os.path.expanduser('~')
314 env['IPYTHON_DIR'] = os.path.join('~', 'somewhere')
321 env['IPYTHON_DIR'] = os.path.join('~', 'somewhere')
315 ipdir = path.get_ipython_dir()
322 ipdir = path.get_ipython_dir()
316 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
323 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
317
324
318
325
319 @with_environment
326 @with_environment
320 def test_get_xdg_dir_1():
327 def test_get_xdg_dir_1():
321 """test_get_xdg_dir_1, check xdg_dir"""
328 """test_get_xdg_dir_1, check xdg_dir"""
322 reload(path)
329 reload(path)
323 path._writable_dir = lambda path: True
330 path._writable_dir = lambda path: True
324 path.get_home_dir = lambda : 'somewhere'
331 path.get_home_dir = lambda : 'somewhere'
325 os.name = "posix"
332 os.name = "posix"
326 env.pop('IPYTHON_DIR', None)
333 env.pop('IPYTHON_DIR', None)
327 env.pop('IPYTHONDIR', None)
334 env.pop('IPYTHONDIR', None)
328 env.pop('XDG_CONFIG_HOME', None)
335 env.pop('XDG_CONFIG_HOME', None)
329
336
330 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
337 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
331
338
332
339
333 @with_environment
340 @with_environment
334 def test_get_xdg_dir_1():
341 def test_get_xdg_dir_1():
335 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
342 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
336 reload(path)
343 reload(path)
337 path.get_home_dir = lambda : HOME_TEST_DIR
344 path.get_home_dir = lambda : HOME_TEST_DIR
338 os.name = "posix"
345 os.name = "posix"
339 env.pop('IPYTHON_DIR', None)
346 env.pop('IPYTHON_DIR', None)
340 env.pop('IPYTHONDIR', None)
347 env.pop('IPYTHONDIR', None)
341 env.pop('XDG_CONFIG_HOME', None)
348 env.pop('XDG_CONFIG_HOME', None)
342 nt.assert_equal(path.get_xdg_dir(), None)
349 nt.assert_equal(path.get_xdg_dir(), None)
343
350
344 @with_environment
351 @with_environment
345 def test_get_xdg_dir_2():
352 def test_get_xdg_dir_2():
346 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
353 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
347 reload(path)
354 reload(path)
348 path.get_home_dir = lambda : HOME_TEST_DIR
355 path.get_home_dir = lambda : HOME_TEST_DIR
349 os.name = "posix"
356 os.name = "posix"
350 env.pop('IPYTHON_DIR', None)
357 env.pop('IPYTHON_DIR', None)
351 env.pop('IPYTHONDIR', None)
358 env.pop('IPYTHONDIR', None)
352 env.pop('XDG_CONFIG_HOME', None)
359 env.pop('XDG_CONFIG_HOME', None)
353 cfgdir=os.path.join(path.get_home_dir(), '.config')
360 cfgdir=os.path.join(path.get_home_dir(), '.config')
354 os.makedirs(cfgdir)
361 os.makedirs(cfgdir)
355
362
356 nt.assert_equal(path.get_xdg_dir(), cfgdir)
363 nt.assert_equal(path.get_xdg_dir(), cfgdir)
357
364
358 def test_filefind():
365 def test_filefind():
359 """Various tests for filefind"""
366 """Various tests for filefind"""
360 f = tempfile.NamedTemporaryFile()
367 f = tempfile.NamedTemporaryFile()
361 # print 'fname:',f.name
368 # print 'fname:',f.name
362 alt_dirs = path.get_ipython_dir()
369 alt_dirs = path.get_ipython_dir()
363 t = path.filefind(f.name, alt_dirs)
370 t = path.filefind(f.name, alt_dirs)
364 # print 'found:',t
371 # print 'found:',t
365
372
366
373
367 def test_get_ipython_package_dir():
374 def test_get_ipython_package_dir():
368 ipdir = path.get_ipython_package_dir()
375 ipdir = path.get_ipython_package_dir()
369 nt.assert_true(os.path.isdir(ipdir))
376 nt.assert_true(os.path.isdir(ipdir))
370
377
371
378
372 def test_get_ipython_module_path():
379 def test_get_ipython_module_path():
373 ipapp_path = path.get_ipython_module_path('IPython.frontend.terminal.ipapp')
380 ipapp_path = path.get_ipython_module_path('IPython.frontend.terminal.ipapp')
374 nt.assert_true(os.path.isfile(ipapp_path))
381 nt.assert_true(os.path.isfile(ipapp_path))
375
382
376
383
377 @dec.skip_if_not_win32
384 @dec.skip_if_not_win32
378 def test_get_long_path_name_win32():
385 def test_get_long_path_name_win32():
379 p = path.get_long_path_name('c:\\docume~1')
386 p = path.get_long_path_name('c:\\docume~1')
380 nt.assert_equals(p,u'c:\\Documents and Settings')
387 nt.assert_equals(p,u'c:\\Documents and Settings')
381
388
382
389
383 @dec.skip_win32
390 @dec.skip_win32
384 def test_get_long_path_name():
391 def test_get_long_path_name():
385 p = path.get_long_path_name('/usr/local')
392 p = path.get_long_path_name('/usr/local')
386 nt.assert_equals(p,'/usr/local')
393 nt.assert_equals(p,'/usr/local')
387
394
388 @dec.skip_win32 # can't create not-user-writable dir on win
395 @dec.skip_win32 # can't create not-user-writable dir on win
389 @with_environment
396 @with_environment
390 def test_not_writable_ipdir():
397 def test_not_writable_ipdir():
391 tmpdir = tempfile.mkdtemp()
398 tmpdir = tempfile.mkdtemp()
392 os.name = "posix"
399 os.name = "posix"
393 env.pop('IPYTHON_DIR', None)
400 env.pop('IPYTHON_DIR', None)
394 env.pop('IPYTHONDIR', None)
401 env.pop('IPYTHONDIR', None)
395 env.pop('XDG_CONFIG_HOME', None)
402 env.pop('XDG_CONFIG_HOME', None)
396 env['HOME'] = tmpdir
403 env['HOME'] = tmpdir
397 ipdir = os.path.join(tmpdir, '.ipython')
404 ipdir = os.path.join(tmpdir, '.ipython')
398 os.mkdir(ipdir)
405 os.mkdir(ipdir)
399 os.chmod(ipdir, 600)
406 os.chmod(ipdir, 600)
400 stderr = io.stderr
407 stderr = io.stderr
401 pipe = StringIO.StringIO()
408 pipe = StringIO.StringIO()
402 io.stderr = pipe
409 io.stderr = pipe
403 ipdir = path.get_ipython_dir()
410 ipdir = path.get_ipython_dir()
404 io.stderr.flush()
411 io.stderr.flush()
405 io.stderr = stderr
412 io.stderr = stderr
406 nt.assert_true('WARNING' in pipe.getvalue())
413 nt.assert_true('WARNING' in pipe.getvalue())
407 env.pop('IPYTHON_DIR', None)
414 env.pop('IPYTHON_DIR', None)
408
415
409 def test_unquote_filename():
416 def test_unquote_filename():
410 for win32 in (True, False):
417 for win32 in (True, False):
411 nt.assert_equals(path.unquote_filename('foo.py', win32=win32), 'foo.py')
418 nt.assert_equals(path.unquote_filename('foo.py', win32=win32), 'foo.py')
412 nt.assert_equals(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
419 nt.assert_equals(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
413 nt.assert_equals(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
420 nt.assert_equals(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
414 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
421 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
415 nt.assert_equals(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
422 nt.assert_equals(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
416 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
423 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
417 nt.assert_equals(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
424 nt.assert_equals(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
418 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
425 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
419 nt.assert_equals(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
426 nt.assert_equals(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
420 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
427 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
421
428
422 @with_environment
429 @with_environment
423 def test_get_py_filename():
430 def test_get_py_filename():
424 os.chdir(TMP_TEST_DIR)
431 os.chdir(TMP_TEST_DIR)
425 for win32 in (True, False):
432 for win32 in (True, False):
426 with make_tempfile('foo.py'):
433 with make_tempfile('foo.py'):
427 nt.assert_equals(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
434 nt.assert_equals(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
428 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo.py')
435 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo.py')
429 with make_tempfile('foo'):
436 with make_tempfile('foo'):
430 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo')
437 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo')
431 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
438 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
432 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
439 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
433 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
440 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
434 true_fn = 'foo with spaces.py'
441 true_fn = 'foo with spaces.py'
435 with make_tempfile(true_fn):
442 with make_tempfile(true_fn):
436 nt.assert_equals(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
443 nt.assert_equals(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
437 nt.assert_equals(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
444 nt.assert_equals(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
438 if win32:
445 if win32:
439 nt.assert_equals(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
446 nt.assert_equals(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
440 nt.assert_equals(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
447 nt.assert_equals(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
441 else:
448 else:
442 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
449 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
443 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
450 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
@@ -1,98 +1,98
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for platutils.py
3 Tests for platutils.py
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import sys
17 import sys
18 from unittest import TestCase
18 from unittest import TestCase
19
19
20 import nose.tools as nt
20 import nose.tools as nt
21
21
22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
23 system, getoutput, getoutputerror)
23 system, getoutput, getoutputerror)
24 from IPython.testing import decorators as dec
24 from IPython.testing import decorators as dec
25 from IPython.testing import tools as tt
25 from IPython.testing import tools as tt
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Tests
28 # Tests
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 def test_find_cmd_python():
31 def test_find_cmd_python():
32 """Make sure we find sys.exectable for python."""
32 """Make sure we find sys.exectable for python."""
33 nt.assert_equals(find_cmd('python'), sys.executable)
33 nt.assert_equals(find_cmd('python'), sys.executable)
34
34
35
35
36 @dec.skip_win32
36 @dec.skip_win32
37 def test_find_cmd_ls():
37 def test_find_cmd_ls():
38 """Make sure we can find the full path to ls."""
38 """Make sure we can find the full path to ls."""
39 path = find_cmd('ls')
39 path = find_cmd('ls')
40 nt.assert_true(path.endswith('ls'))
40 nt.assert_true(path.endswith(b'ls'))
41
41
42
42
43 def has_pywin32():
43 def has_pywin32():
44 try:
44 try:
45 import win32api
45 import win32api
46 except ImportError:
46 except ImportError:
47 return False
47 return False
48 return True
48 return True
49
49
50
50
51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
52 def test_find_cmd_pythonw():
52 def test_find_cmd_pythonw():
53 """Try to find pythonw on Windows."""
53 """Try to find pythonw on Windows."""
54 path = find_cmd('pythonw')
54 path = find_cmd('pythonw')
55 nt.assert_true(path.endswith('pythonw.exe'))
55 nt.assert_true(path.endswith('pythonw.exe'))
56
56
57
57
58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
59 "This test runs on posix or in win32 with win32api installed")
59 "This test runs on posix or in win32 with win32api installed")
60 def test_find_cmd_fail():
60 def test_find_cmd_fail():
61 """Make sure that FindCmdError is raised if we can't find the cmd."""
61 """Make sure that FindCmdError is raised if we can't find the cmd."""
62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
63
63
64
64
65 def test_arg_split():
65 def test_arg_split():
66 """Ensure that argument lines are correctly split like in a shell."""
66 """Ensure that argument lines are correctly split like in a shell."""
67 tests = [['hi', ['hi']],
67 tests = [['hi', ['hi']],
68 [u'hi', [u'hi']],
68 [u'hi', [u'hi']],
69 ['hello there', ['hello', 'there']],
69 ['hello there', ['hello', 'there']],
70 [u'h\N{LATIN SMALL LETTER A WITH CARON}llo', [u'h\N{LATIN SMALL LETTER A WITH CARON}llo']],
70 [u'h\N{LATIN SMALL LETTER A WITH CARON}llo', [u'h\N{LATIN SMALL LETTER A WITH CARON}llo']],
71 ['something "with quotes"', ['something', '"with quotes"']],
71 ['something "with quotes"', ['something', '"with quotes"']],
72 ]
72 ]
73 for argstr, argv in tests:
73 for argstr, argv in tests:
74 nt.assert_equal(arg_split(argstr), argv)
74 nt.assert_equal(arg_split(argstr), argv)
75
75
76
76
77 class SubProcessTestCase(TestCase, tt.TempFileMixin):
77 class SubProcessTestCase(TestCase, tt.TempFileMixin):
78 def setUp(self):
78 def setUp(self):
79 """Make a valid python temp file."""
79 """Make a valid python temp file."""
80 lines = ["from __future__ import print_function",
80 lines = ["from __future__ import print_function",
81 "import sys",
81 "import sys",
82 "print('on stdout', end='', file=sys.stdout)",
82 "print('on stdout', end='', file=sys.stdout)",
83 "print('on stderr', end='', file=sys.stderr)",
83 "print('on stderr', end='', file=sys.stderr)",
84 "sys.stdout.flush()",
84 "sys.stdout.flush()",
85 "sys.stderr.flush()"]
85 "sys.stderr.flush()"]
86 self.mktmp('\n'.join(lines))
86 self.mktmp('\n'.join(lines))
87
87
88 def test_system(self):
88 def test_system(self):
89 system('python "%s"' % self.fname)
89 system('python "%s"' % self.fname)
90
90
91 def test_getoutput(self):
91 def test_getoutput(self):
92 out = getoutput('python "%s"' % self.fname)
92 out = getoutput('python "%s"' % self.fname)
93 self.assertEquals(out, 'on stdout')
93 self.assertEquals(out, 'on stdout')
94
94
95 def test_getoutput(self):
95 def test_getoutput(self):
96 out, err = getoutputerror('python "%s"' % self.fname)
96 out, err = getoutputerror('python "%s"' % self.fname)
97 self.assertEquals(out, 'on stdout')
97 self.assertEquals(out, 'on stdout')
98 self.assertEquals(err, 'on stderr')
98 self.assertEquals(err, 'on stderr')
@@ -1,715 +1,715
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import __main__
17 import __main__
18
18
19 import os
19 import os
20 import re
20 import re
21 import shutil
21 import shutil
22 import textwrap
22 import textwrap
23 from string import Formatter
23 from string import Formatter
24
24
25 from IPython.external.path import path
25 from IPython.external.path import path
26 from IPython.utils import py3compat
26 from IPython.utils import py3compat
27 from IPython.utils.io import nlprint
27 from IPython.utils.io import nlprint
28 from IPython.utils.data import flatten
28 from IPython.utils.data import flatten
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Code
31 # Code
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34
34
35 def unquote_ends(istr):
35 def unquote_ends(istr):
36 """Remove a single pair of quotes from the endpoints of a string."""
36 """Remove a single pair of quotes from the endpoints of a string."""
37
37
38 if not istr:
38 if not istr:
39 return istr
39 return istr
40 if (istr[0]=="'" and istr[-1]=="'") or \
40 if (istr[0]=="'" and istr[-1]=="'") or \
41 (istr[0]=='"' and istr[-1]=='"'):
41 (istr[0]=='"' and istr[-1]=='"'):
42 return istr[1:-1]
42 return istr[1:-1]
43 else:
43 else:
44 return istr
44 return istr
45
45
46
46
47 class LSString(str):
47 class LSString(str):
48 """String derivative with a special access attributes.
48 """String derivative with a special access attributes.
49
49
50 These are normal strings, but with the special attributes:
50 These are normal strings, but with the special attributes:
51
51
52 .l (or .list) : value as list (split on newlines).
52 .l (or .list) : value as list (split on newlines).
53 .n (or .nlstr): original value (the string itself).
53 .n (or .nlstr): original value (the string itself).
54 .s (or .spstr): value as whitespace-separated string.
54 .s (or .spstr): value as whitespace-separated string.
55 .p (or .paths): list of path objects
55 .p (or .paths): list of path objects
56
56
57 Any values which require transformations are computed only once and
57 Any values which require transformations are computed only once and
58 cached.
58 cached.
59
59
60 Such strings are very useful to efficiently interact with the shell, which
60 Such strings are very useful to efficiently interact with the shell, which
61 typically only understands whitespace-separated options for commands."""
61 typically only understands whitespace-separated options for commands."""
62
62
63 def get_list(self):
63 def get_list(self):
64 try:
64 try:
65 return self.__list
65 return self.__list
66 except AttributeError:
66 except AttributeError:
67 self.__list = self.split('\n')
67 self.__list = self.split('\n')
68 return self.__list
68 return self.__list
69
69
70 l = list = property(get_list)
70 l = list = property(get_list)
71
71
72 def get_spstr(self):
72 def get_spstr(self):
73 try:
73 try:
74 return self.__spstr
74 return self.__spstr
75 except AttributeError:
75 except AttributeError:
76 self.__spstr = self.replace('\n',' ')
76 self.__spstr = self.replace('\n',' ')
77 return self.__spstr
77 return self.__spstr
78
78
79 s = spstr = property(get_spstr)
79 s = spstr = property(get_spstr)
80
80
81 def get_nlstr(self):
81 def get_nlstr(self):
82 return self
82 return self
83
83
84 n = nlstr = property(get_nlstr)
84 n = nlstr = property(get_nlstr)
85
85
86 def get_paths(self):
86 def get_paths(self):
87 try:
87 try:
88 return self.__paths
88 return self.__paths
89 except AttributeError:
89 except AttributeError:
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 return self.__paths
91 return self.__paths
92
92
93 p = paths = property(get_paths)
93 p = paths = property(get_paths)
94
94
95 # FIXME: We need to reimplement type specific displayhook and then add this
95 # FIXME: We need to reimplement type specific displayhook and then add this
96 # back as a custom printer. This should also be moved outside utils into the
96 # back as a custom printer. This should also be moved outside utils into the
97 # core.
97 # core.
98
98
99 # def print_lsstring(arg):
99 # def print_lsstring(arg):
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 # print "LSString (.p, .n, .l, .s available). Value:"
101 # print "LSString (.p, .n, .l, .s available). Value:"
102 # print arg
102 # print arg
103 #
103 #
104 #
104 #
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106
106
107
107
108 class SList(list):
108 class SList(list):
109 """List derivative with a special access attributes.
109 """List derivative with a special access attributes.
110
110
111 These are normal lists, but with the special attributes:
111 These are normal lists, but with the special attributes:
112
112
113 .l (or .list) : value as list (the list itself).
113 .l (or .list) : value as list (the list itself).
114 .n (or .nlstr): value as a string, joined on newlines.
114 .n (or .nlstr): value as a string, joined on newlines.
115 .s (or .spstr): value as a string, joined on spaces.
115 .s (or .spstr): value as a string, joined on spaces.
116 .p (or .paths): list of path objects
116 .p (or .paths): list of path objects
117
117
118 Any values which require transformations are computed only once and
118 Any values which require transformations are computed only once and
119 cached."""
119 cached."""
120
120
121 def get_list(self):
121 def get_list(self):
122 return self
122 return self
123
123
124 l = list = property(get_list)
124 l = list = property(get_list)
125
125
126 def get_spstr(self):
126 def get_spstr(self):
127 try:
127 try:
128 return self.__spstr
128 return self.__spstr
129 except AttributeError:
129 except AttributeError:
130 self.__spstr = ' '.join(self)
130 self.__spstr = ' '.join(self)
131 return self.__spstr
131 return self.__spstr
132
132
133 s = spstr = property(get_spstr)
133 s = spstr = property(get_spstr)
134
134
135 def get_nlstr(self):
135 def get_nlstr(self):
136 try:
136 try:
137 return self.__nlstr
137 return self.__nlstr
138 except AttributeError:
138 except AttributeError:
139 self.__nlstr = '\n'.join(self)
139 self.__nlstr = '\n'.join(self)
140 return self.__nlstr
140 return self.__nlstr
141
141
142 n = nlstr = property(get_nlstr)
142 n = nlstr = property(get_nlstr)
143
143
144 def get_paths(self):
144 def get_paths(self):
145 try:
145 try:
146 return self.__paths
146 return self.__paths
147 except AttributeError:
147 except AttributeError:
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 return self.__paths
149 return self.__paths
150
150
151 p = paths = property(get_paths)
151 p = paths = property(get_paths)
152
152
153 def grep(self, pattern, prune = False, field = None):
153 def grep(self, pattern, prune = False, field = None):
154 """ Return all strings matching 'pattern' (a regex or callable)
154 """ Return all strings matching 'pattern' (a regex or callable)
155
155
156 This is case-insensitive. If prune is true, return all items
156 This is case-insensitive. If prune is true, return all items
157 NOT matching the pattern.
157 NOT matching the pattern.
158
158
159 If field is specified, the match must occur in the specified
159 If field is specified, the match must occur in the specified
160 whitespace-separated field.
160 whitespace-separated field.
161
161
162 Examples::
162 Examples::
163
163
164 a.grep( lambda x: x.startswith('C') )
164 a.grep( lambda x: x.startswith('C') )
165 a.grep('Cha.*log', prune=1)
165 a.grep('Cha.*log', prune=1)
166 a.grep('chm', field=-1)
166 a.grep('chm', field=-1)
167 """
167 """
168
168
169 def match_target(s):
169 def match_target(s):
170 if field is None:
170 if field is None:
171 return s
171 return s
172 parts = s.split()
172 parts = s.split()
173 try:
173 try:
174 tgt = parts[field]
174 tgt = parts[field]
175 return tgt
175 return tgt
176 except IndexError:
176 except IndexError:
177 return ""
177 return ""
178
178
179 if isinstance(pattern, basestring):
179 if isinstance(pattern, basestring):
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 else:
181 else:
182 pred = pattern
182 pred = pattern
183 if not prune:
183 if not prune:
184 return SList([el for el in self if pred(match_target(el))])
184 return SList([el for el in self if pred(match_target(el))])
185 else:
185 else:
186 return SList([el for el in self if not pred(match_target(el))])
186 return SList([el for el in self if not pred(match_target(el))])
187
187
188 def fields(self, *fields):
188 def fields(self, *fields):
189 """ Collect whitespace-separated fields from string list
189 """ Collect whitespace-separated fields from string list
190
190
191 Allows quick awk-like usage of string lists.
191 Allows quick awk-like usage of string lists.
192
192
193 Example data (in var a, created by 'a = !ls -l')::
193 Example data (in var a, created by 'a = !ls -l')::
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196
196
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 (note the joining by space).
199 (note the joining by space).
200 a.fields(-1) is ['ChangeLog', 'IPython']
200 a.fields(-1) is ['ChangeLog', 'IPython']
201
201
202 IndexErrors are ignored.
202 IndexErrors are ignored.
203
203
204 Without args, fields() just split()'s the strings.
204 Without args, fields() just split()'s the strings.
205 """
205 """
206 if len(fields) == 0:
206 if len(fields) == 0:
207 return [el.split() for el in self]
207 return [el.split() for el in self]
208
208
209 res = SList()
209 res = SList()
210 for el in [f.split() for f in self]:
210 for el in [f.split() for f in self]:
211 lineparts = []
211 lineparts = []
212
212
213 for fd in fields:
213 for fd in fields:
214 try:
214 try:
215 lineparts.append(el[fd])
215 lineparts.append(el[fd])
216 except IndexError:
216 except IndexError:
217 pass
217 pass
218 if lineparts:
218 if lineparts:
219 res.append(" ".join(lineparts))
219 res.append(" ".join(lineparts))
220
220
221 return res
221 return res
222
222
223 def sort(self,field= None, nums = False):
223 def sort(self,field= None, nums = False):
224 """ sort by specified fields (see fields())
224 """ sort by specified fields (see fields())
225
225
226 Example::
226 Example::
227 a.sort(1, nums = True)
227 a.sort(1, nums = True)
228
228
229 Sorts a by second field, in numerical order (so that 21 > 3)
229 Sorts a by second field, in numerical order (so that 21 > 3)
230
230
231 """
231 """
232
232
233 #decorate, sort, undecorate
233 #decorate, sort, undecorate
234 if field is not None:
234 if field is not None:
235 dsu = [[SList([line]).fields(field), line] for line in self]
235 dsu = [[SList([line]).fields(field), line] for line in self]
236 else:
236 else:
237 dsu = [[line, line] for line in self]
237 dsu = [[line, line] for line in self]
238 if nums:
238 if nums:
239 for i in range(len(dsu)):
239 for i in range(len(dsu)):
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 try:
241 try:
242 n = int(numstr)
242 n = int(numstr)
243 except ValueError:
243 except ValueError:
244 n = 0;
244 n = 0;
245 dsu[i][0] = n
245 dsu[i][0] = n
246
246
247
247
248 dsu.sort()
248 dsu.sort()
249 return SList([t[1] for t in dsu])
249 return SList([t[1] for t in dsu])
250
250
251
251
252 # FIXME: We need to reimplement type specific displayhook and then add this
252 # FIXME: We need to reimplement type specific displayhook and then add this
253 # back as a custom printer. This should also be moved outside utils into the
253 # back as a custom printer. This should also be moved outside utils into the
254 # core.
254 # core.
255
255
256 # def print_slist(arg):
256 # def print_slist(arg):
257 # """ Prettier (non-repr-like) and more informative printer for SList """
257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 # arg.hideonce = False
260 # arg.hideonce = False
261 # return
261 # return
262 #
262 #
263 # nlprint(arg)
263 # nlprint(arg)
264 #
264 #
265 # print_slist = result_display.when_type(SList)(print_slist)
265 # print_slist = result_display.when_type(SList)(print_slist)
266
266
267
267
268 def esc_quotes(strng):
268 def esc_quotes(strng):
269 """Return the input string with single and double quotes escaped out"""
269 """Return the input string with single and double quotes escaped out"""
270
270
271 return strng.replace('"','\\"').replace("'","\\'")
271 return strng.replace('"','\\"').replace("'","\\'")
272
272
273
273
274 def make_quoted_expr(s):
274 def make_quoted_expr(s):
275 """Return string s in appropriate quotes, using raw string if possible.
275 """Return string s in appropriate quotes, using raw string if possible.
276
276
277 XXX - example removed because it caused encoding errors in documentation
277 XXX - example removed because it caused encoding errors in documentation
278 generation. We need a new example that doesn't contain invalid chars.
278 generation. We need a new example that doesn't contain invalid chars.
279
279
280 Note the use of raw string and padding at the end to allow trailing
280 Note the use of raw string and padding at the end to allow trailing
281 backslash.
281 backslash.
282 """
282 """
283
283
284 tail = ''
284 tail = ''
285 tailpadding = ''
285 tailpadding = ''
286 raw = ''
286 raw = ''
287 ucode = '' if py3compat.PY3 else 'u'
287 ucode = '' if py3compat.PY3 else 'u'
288 if "\\" in s:
288 if "\\" in s:
289 raw = 'r'
289 raw = 'r'
290 if s.endswith('\\'):
290 if s.endswith('\\'):
291 tail = '[:-1]'
291 tail = '[:-1]'
292 tailpadding = '_'
292 tailpadding = '_'
293 if '"' not in s:
293 if '"' not in s:
294 quote = '"'
294 quote = '"'
295 elif "'" not in s:
295 elif "'" not in s:
296 quote = "'"
296 quote = "'"
297 elif '"""' not in s and not s.endswith('"'):
297 elif '"""' not in s and not s.endswith('"'):
298 quote = '"""'
298 quote = '"""'
299 elif "'''" not in s and not s.endswith("'"):
299 elif "'''" not in s and not s.endswith("'"):
300 quote = "'''"
300 quote = "'''"
301 else:
301 else:
302 # give up, backslash-escaped string will do
302 # give up, backslash-escaped string will do
303 return '"%s"' % esc_quotes(s)
303 return '"%s"' % esc_quotes(s)
304 res = ucode + raw + quote + s + tailpadding + quote + tail
304 res = ucode + raw + quote + s + tailpadding + quote + tail
305 return res
305 return res
306
306
307
307
308 def qw(words,flat=0,sep=None,maxsplit=-1):
308 def qw(words,flat=0,sep=None,maxsplit=-1):
309 """Similar to Perl's qw() operator, but with some more options.
309 """Similar to Perl's qw() operator, but with some more options.
310
310
311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
312
312
313 words can also be a list itself, and with flat=1, the output will be
313 words can also be a list itself, and with flat=1, the output will be
314 recursively flattened.
314 recursively flattened.
315
315
316 Examples:
316 Examples:
317
317
318 >>> qw('1 2')
318 >>> qw('1 2')
319 ['1', '2']
319 ['1', '2']
320
320
321 >>> qw(['a b','1 2',['m n','p q']])
321 >>> qw(['a b','1 2',['m n','p q']])
322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
323
323
324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
326 """
326 """
327
327
328 if isinstance(words, basestring):
328 if isinstance(words, basestring):
329 return [word.strip() for word in words.split(sep,maxsplit)
329 return [word.strip() for word in words.split(sep,maxsplit)
330 if word and not word.isspace() ]
330 if word and not word.isspace() ]
331 if flat:
331 if flat:
332 return flatten(map(qw,words,[1]*len(words)))
332 return flatten(map(qw,words,[1]*len(words)))
333 return map(qw,words)
333 return map(qw,words)
334
334
335
335
336 def qwflat(words,sep=None,maxsplit=-1):
336 def qwflat(words,sep=None,maxsplit=-1):
337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
338 return qw(words,1,sep,maxsplit)
338 return qw(words,1,sep,maxsplit)
339
339
340
340
341 def qw_lol(indata):
341 def qw_lol(indata):
342 """qw_lol('a b') -> [['a','b']],
342 """qw_lol('a b') -> [['a','b']],
343 otherwise it's just a call to qw().
343 otherwise it's just a call to qw().
344
344
345 We need this to make sure the modules_some keys *always* end up as a
345 We need this to make sure the modules_some keys *always* end up as a
346 list of lists."""
346 list of lists."""
347
347
348 if isinstance(indata, basestring):
348 if isinstance(indata, basestring):
349 return [qw(indata)]
349 return [qw(indata)]
350 else:
350 else:
351 return qw(indata)
351 return qw(indata)
352
352
353
353
354 def grep(pat,list,case=1):
354 def grep(pat,list,case=1):
355 """Simple minded grep-like function.
355 """Simple minded grep-like function.
356 grep(pat,list) returns occurrences of pat in list, None on failure.
356 grep(pat,list) returns occurrences of pat in list, None on failure.
357
357
358 It only does simple string matching, with no support for regexps. Use the
358 It only does simple string matching, with no support for regexps. Use the
359 option case=0 for case-insensitive matching."""
359 option case=0 for case-insensitive matching."""
360
360
361 # This is pretty crude. At least it should implement copying only references
361 # This is pretty crude. At least it should implement copying only references
362 # to the original data in case it's big. Now it copies the data for output.
362 # to the original data in case it's big. Now it copies the data for output.
363 out=[]
363 out=[]
364 if case:
364 if case:
365 for term in list:
365 for term in list:
366 if term.find(pat)>-1: out.append(term)
366 if term.find(pat)>-1: out.append(term)
367 else:
367 else:
368 lpat=pat.lower()
368 lpat=pat.lower()
369 for term in list:
369 for term in list:
370 if term.lower().find(lpat)>-1: out.append(term)
370 if term.lower().find(lpat)>-1: out.append(term)
371
371
372 if len(out): return out
372 if len(out): return out
373 else: return None
373 else: return None
374
374
375
375
376 def dgrep(pat,*opts):
376 def dgrep(pat,*opts):
377 """Return grep() on dir()+dir(__builtins__).
377 """Return grep() on dir()+dir(__builtins__).
378
378
379 A very common use of grep() when working interactively."""
379 A very common use of grep() when working interactively."""
380
380
381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
382
382
383
383
384 def idgrep(pat):
384 def idgrep(pat):
385 """Case-insensitive dgrep()"""
385 """Case-insensitive dgrep()"""
386
386
387 return dgrep(pat,0)
387 return dgrep(pat,0)
388
388
389
389
390 def igrep(pat,list):
390 def igrep(pat,list):
391 """Synonym for case-insensitive grep."""
391 """Synonym for case-insensitive grep."""
392
392
393 return grep(pat,list,case=0)
393 return grep(pat,list,case=0)
394
394
395
395
396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
397 """Indent a string a given number of spaces or tabstops.
397 """Indent a string a given number of spaces or tabstops.
398
398
399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
400
400
401 Parameters
401 Parameters
402 ----------
402 ----------
403
403
404 instr : basestring
404 instr : basestring
405 The string to be indented.
405 The string to be indented.
406 nspaces : int (default: 4)
406 nspaces : int (default: 4)
407 The number of spaces to be indented.
407 The number of spaces to be indented.
408 ntabs : int (default: 0)
408 ntabs : int (default: 0)
409 The number of tabs to be indented.
409 The number of tabs to be indented.
410 flatten : bool (default: False)
410 flatten : bool (default: False)
411 Whether to scrub existing indentation. If True, all lines will be
411 Whether to scrub existing indentation. If True, all lines will be
412 aligned to the same indentation. If False, existing indentation will
412 aligned to the same indentation. If False, existing indentation will
413 be strictly increased.
413 be strictly increased.
414
414
415 Returns
415 Returns
416 -------
416 -------
417
417
418 str|unicode : string indented by ntabs and nspaces.
418 str|unicode : string indented by ntabs and nspaces.
419
419
420 """
420 """
421 if instr is None:
421 if instr is None:
422 return
422 return
423 ind = '\t'*ntabs+' '*nspaces
423 ind = '\t'*ntabs+' '*nspaces
424 if flatten:
424 if flatten:
425 pat = re.compile(r'^\s*', re.MULTILINE)
425 pat = re.compile(r'^\s*', re.MULTILINE)
426 else:
426 else:
427 pat = re.compile(r'^', re.MULTILINE)
427 pat = re.compile(r'^', re.MULTILINE)
428 outstr = re.sub(pat, ind, instr)
428 outstr = re.sub(pat, ind, instr)
429 if outstr.endswith(os.linesep+ind):
429 if outstr.endswith(os.linesep+ind):
430 return outstr[:-len(ind)]
430 return outstr[:-len(ind)]
431 else:
431 else:
432 return outstr
432 return outstr
433
433
434 def native_line_ends(filename,backup=1):
434 def native_line_ends(filename,backup=1):
435 """Convert (in-place) a file to line-ends native to the current OS.
435 """Convert (in-place) a file to line-ends native to the current OS.
436
436
437 If the optional backup argument is given as false, no backup of the
437 If the optional backup argument is given as false, no backup of the
438 original file is left. """
438 original file is left. """
439
439
440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
441
441
442 bak_filename = filename + backup_suffixes[os.name]
442 bak_filename = filename + backup_suffixes[os.name]
443
443
444 original = open(filename).read()
444 original = open(filename).read()
445 shutil.copy2(filename,bak_filename)
445 shutil.copy2(filename,bak_filename)
446 try:
446 try:
447 new = open(filename,'wb')
447 new = open(filename,'wb')
448 new.write(os.linesep.join(original.splitlines()))
448 new.write(os.linesep.join(original.splitlines()))
449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
450 new.close()
450 new.close()
451 except:
451 except:
452 os.rename(bak_filename,filename)
452 os.rename(bak_filename,filename)
453 if not backup:
453 if not backup:
454 try:
454 try:
455 os.remove(bak_filename)
455 os.remove(bak_filename)
456 except:
456 except:
457 pass
457 pass
458
458
459
459
460 def list_strings(arg):
460 def list_strings(arg):
461 """Always return a list of strings, given a string or list of strings
461 """Always return a list of strings, given a string or list of strings
462 as input.
462 as input.
463
463
464 :Examples:
464 :Examples:
465
465
466 In [7]: list_strings('A single string')
466 In [7]: list_strings('A single string')
467 Out[7]: ['A single string']
467 Out[7]: ['A single string']
468
468
469 In [8]: list_strings(['A single string in a list'])
469 In [8]: list_strings(['A single string in a list'])
470 Out[8]: ['A single string in a list']
470 Out[8]: ['A single string in a list']
471
471
472 In [9]: list_strings(['A','list','of','strings'])
472 In [9]: list_strings(['A','list','of','strings'])
473 Out[9]: ['A', 'list', 'of', 'strings']
473 Out[9]: ['A', 'list', 'of', 'strings']
474 """
474 """
475
475
476 if isinstance(arg,basestring): return [arg]
476 if isinstance(arg,basestring): return [arg]
477 else: return arg
477 else: return arg
478
478
479
479
480 def marquee(txt='',width=78,mark='*'):
480 def marquee(txt='',width=78,mark='*'):
481 """Return the input string centered in a 'marquee'.
481 """Return the input string centered in a 'marquee'.
482
482
483 :Examples:
483 :Examples:
484
484
485 In [16]: marquee('A test',40)
485 In [16]: marquee('A test',40)
486 Out[16]: '**************** A test ****************'
486 Out[16]: '**************** A test ****************'
487
487
488 In [17]: marquee('A test',40,'-')
488 In [17]: marquee('A test',40,'-')
489 Out[17]: '---------------- A test ----------------'
489 Out[17]: '---------------- A test ----------------'
490
490
491 In [18]: marquee('A test',40,' ')
491 In [18]: marquee('A test',40,' ')
492 Out[18]: ' A test '
492 Out[18]: ' A test '
493
493
494 """
494 """
495 if not txt:
495 if not txt:
496 return (mark*width)[:width]
496 return (mark*width)[:width]
497 nmark = (width-len(txt)-2)//len(mark)//2
497 nmark = (width-len(txt)-2)//len(mark)//2
498 if nmark < 0: nmark =0
498 if nmark < 0: nmark =0
499 marks = mark*nmark
499 marks = mark*nmark
500 return '%s %s %s' % (marks,txt,marks)
500 return '%s %s %s' % (marks,txt,marks)
501
501
502
502
503 ini_spaces_re = re.compile(r'^(\s+)')
503 ini_spaces_re = re.compile(r'^(\s+)')
504
504
505 def num_ini_spaces(strng):
505 def num_ini_spaces(strng):
506 """Return the number of initial spaces in a string"""
506 """Return the number of initial spaces in a string"""
507
507
508 ini_spaces = ini_spaces_re.match(strng)
508 ini_spaces = ini_spaces_re.match(strng)
509 if ini_spaces:
509 if ini_spaces:
510 return ini_spaces.end()
510 return ini_spaces.end()
511 else:
511 else:
512 return 0
512 return 0
513
513
514
514
515 def format_screen(strng):
515 def format_screen(strng):
516 """Format a string for screen printing.
516 """Format a string for screen printing.
517
517
518 This removes some latex-type format codes."""
518 This removes some latex-type format codes."""
519 # Paragraph continue
519 # Paragraph continue
520 par_re = re.compile(r'\\$',re.MULTILINE)
520 par_re = re.compile(r'\\$',re.MULTILINE)
521 strng = par_re.sub('',strng)
521 strng = par_re.sub('',strng)
522 return strng
522 return strng
523
523
524 def dedent(text):
524 def dedent(text):
525 """Equivalent of textwrap.dedent that ignores unindented first line.
525 """Equivalent of textwrap.dedent that ignores unindented first line.
526
526
527 This means it will still dedent strings like:
527 This means it will still dedent strings like:
528 '''foo
528 '''foo
529 is a bar
529 is a bar
530 '''
530 '''
531
531
532 For use in wrap_paragraphs.
532 For use in wrap_paragraphs.
533 """
533 """
534
534
535 if text.startswith('\n'):
535 if text.startswith('\n'):
536 # text starts with blank line, don't ignore the first line
536 # text starts with blank line, don't ignore the first line
537 return textwrap.dedent(text)
537 return textwrap.dedent(text)
538
538
539 # split first line
539 # split first line
540 splits = text.split('\n',1)
540 splits = text.split('\n',1)
541 if len(splits) == 1:
541 if len(splits) == 1:
542 # only one line
542 # only one line
543 return textwrap.dedent(text)
543 return textwrap.dedent(text)
544
544
545 first, rest = splits
545 first, rest = splits
546 # dedent everything but the first line
546 # dedent everything but the first line
547 rest = textwrap.dedent(rest)
547 rest = textwrap.dedent(rest)
548 return '\n'.join([first, rest])
548 return '\n'.join([first, rest])
549
549
550 def wrap_paragraphs(text, ncols=80):
550 def wrap_paragraphs(text, ncols=80):
551 """Wrap multiple paragraphs to fit a specified width.
551 """Wrap multiple paragraphs to fit a specified width.
552
552
553 This is equivalent to textwrap.wrap, but with support for multiple
553 This is equivalent to textwrap.wrap, but with support for multiple
554 paragraphs, as separated by empty lines.
554 paragraphs, as separated by empty lines.
555
555
556 Returns
556 Returns
557 -------
557 -------
558
558
559 list of complete paragraphs, wrapped to fill `ncols` columns.
559 list of complete paragraphs, wrapped to fill `ncols` columns.
560 """
560 """
561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
562 text = dedent(text).strip()
562 text = dedent(text).strip()
563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
564 out_ps = []
564 out_ps = []
565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
566 for p in paragraphs:
566 for p in paragraphs:
567 # presume indentation that survives dedent is meaningful formatting,
567 # presume indentation that survives dedent is meaningful formatting,
568 # so don't fill unless text is flush.
568 # so don't fill unless text is flush.
569 if indent_re.search(p) is None:
569 if indent_re.search(p) is None:
570 # wrap paragraph
570 # wrap paragraph
571 p = textwrap.fill(p, ncols)
571 p = textwrap.fill(p, ncols)
572 out_ps.append(p)
572 out_ps.append(p)
573 return out_ps
573 return out_ps
574
574
575
575
576
576
577 class EvalFormatter(Formatter):
577 class EvalFormatter(Formatter):
578 """A String Formatter that allows evaluation of simple expressions.
578 """A String Formatter that allows evaluation of simple expressions.
579
579
580 Any time a format key is not found in the kwargs,
580 Any time a format key is not found in the kwargs,
581 it will be tried as an expression in the kwargs namespace.
581 it will be tried as an expression in the kwargs namespace.
582
582
583 This is to be used in templating cases, such as the parallel batch
583 This is to be used in templating cases, such as the parallel batch
584 script templates, where simple arithmetic on arguments is useful.
584 script templates, where simple arithmetic on arguments is useful.
585
585
586 Examples
586 Examples
587 --------
587 --------
588
588
589 In [1]: f = EvalFormatter()
589 In [1]: f = EvalFormatter()
590 In [2]: f.format('{n/4}', n=8)
590 In [2]: f.format('{n//4}', n=8)
591 Out[2]: '2'
591 Out[2]: '2'
592
592
593 In [3]: f.format('{range(3)}')
593 In [3]: f.format('{list(range(3))}')
594 Out[3]: '[0, 1, 2]'
594 Out[3]: '[0, 1, 2]'
595
595
596 In [4]: f.format('{3*2}')
596 In [4]: f.format('{3*2}')
597 Out[4]: '6'
597 Out[4]: '6'
598 """
598 """
599
599
600 # should we allow slicing by disabling the format_spec feature?
600 # should we allow slicing by disabling the format_spec feature?
601 allow_slicing = True
601 allow_slicing = True
602
602
603 # copied from Formatter._vformat with minor changes to allow eval
603 # copied from Formatter._vformat with minor changes to allow eval
604 # and replace the format_spec code with slicing
604 # and replace the format_spec code with slicing
605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
606 if recursion_depth < 0:
606 if recursion_depth < 0:
607 raise ValueError('Max string recursion exceeded')
607 raise ValueError('Max string recursion exceeded')
608 result = []
608 result = []
609 for literal_text, field_name, format_spec, conversion in \
609 for literal_text, field_name, format_spec, conversion in \
610 self.parse(format_string):
610 self.parse(format_string):
611
611
612 # output the literal text
612 # output the literal text
613 if literal_text:
613 if literal_text:
614 result.append(literal_text)
614 result.append(literal_text)
615
615
616 # if there's a field, output it
616 # if there's a field, output it
617 if field_name is not None:
617 if field_name is not None:
618 # this is some markup, find the object and do
618 # this is some markup, find the object and do
619 # the formatting
619 # the formatting
620
620
621 if self.allow_slicing and format_spec:
621 if self.allow_slicing and format_spec:
622 # override format spec, to allow slicing:
622 # override format spec, to allow slicing:
623 field_name = ':'.join([field_name, format_spec])
623 field_name = ':'.join([field_name, format_spec])
624 format_spec = ''
624 format_spec = ''
625
625
626 # eval the contents of the field for the object
626 # eval the contents of the field for the object
627 # to be formatted
627 # to be formatted
628 obj = eval(field_name, kwargs)
628 obj = eval(field_name, kwargs)
629
629
630 # do any conversion on the resulting object
630 # do any conversion on the resulting object
631 obj = self.convert_field(obj, conversion)
631 obj = self.convert_field(obj, conversion)
632
632
633 # expand the format spec, if needed
633 # expand the format spec, if needed
634 format_spec = self._vformat(format_spec, args, kwargs,
634 format_spec = self._vformat(format_spec, args, kwargs,
635 used_args, recursion_depth-1)
635 used_args, recursion_depth-1)
636
636
637 # format the object and append to the result
637 # format the object and append to the result
638 result.append(self.format_field(obj, format_spec))
638 result.append(self.format_field(obj, format_spec))
639
639
640 return ''.join(result)
640 return ''.join(result)
641
641
642
642
643 def columnize(items, separator=' ', displaywidth=80):
643 def columnize(items, separator=' ', displaywidth=80):
644 """ Transform a list of strings into a single string with columns.
644 """ Transform a list of strings into a single string with columns.
645
645
646 Parameters
646 Parameters
647 ----------
647 ----------
648 items : sequence of strings
648 items : sequence of strings
649 The strings to process.
649 The strings to process.
650
650
651 separator : str, optional [default is two spaces]
651 separator : str, optional [default is two spaces]
652 The string that separates columns.
652 The string that separates columns.
653
653
654 displaywidth : int, optional [default is 80]
654 displaywidth : int, optional [default is 80]
655 Width of the display in number of characters.
655 Width of the display in number of characters.
656
656
657 Returns
657 Returns
658 -------
658 -------
659 The formatted string.
659 The formatted string.
660 """
660 """
661 # Note: this code is adapted from columnize 0.3.2.
661 # Note: this code is adapted from columnize 0.3.2.
662 # See http://code.google.com/p/pycolumnize/
662 # See http://code.google.com/p/pycolumnize/
663
663
664 # Some degenerate cases.
664 # Some degenerate cases.
665 size = len(items)
665 size = len(items)
666 if size == 0:
666 if size == 0:
667 return '\n'
667 return '\n'
668 elif size == 1:
668 elif size == 1:
669 return '%s\n' % items[0]
669 return '%s\n' % items[0]
670
670
671 # Special case: if any item is longer than the maximum width, there's no
671 # Special case: if any item is longer than the maximum width, there's no
672 # point in triggering the logic below...
672 # point in triggering the logic below...
673 item_len = map(len, items) # save these, we can reuse them below
673 item_len = map(len, items) # save these, we can reuse them below
674 longest = max(item_len)
674 longest = max(item_len)
675 if longest >= displaywidth:
675 if longest >= displaywidth:
676 return '\n'.join(items+[''])
676 return '\n'.join(items+[''])
677
677
678 # Try every row count from 1 upwards
678 # Try every row count from 1 upwards
679 array_index = lambda nrows, row, col: nrows*col + row
679 array_index = lambda nrows, row, col: nrows*col + row
680 for nrows in range(1, size):
680 for nrows in range(1, size):
681 ncols = (size + nrows - 1) // nrows
681 ncols = (size + nrows - 1) // nrows
682 colwidths = []
682 colwidths = []
683 totwidth = -len(separator)
683 totwidth = -len(separator)
684 for col in range(ncols):
684 for col in range(ncols):
685 # Get max column width for this column
685 # Get max column width for this column
686 colwidth = 0
686 colwidth = 0
687 for row in range(nrows):
687 for row in range(nrows):
688 i = array_index(nrows, row, col)
688 i = array_index(nrows, row, col)
689 if i >= size: break
689 if i >= size: break
690 x, len_x = items[i], item_len[i]
690 x, len_x = items[i], item_len[i]
691 colwidth = max(colwidth, len_x)
691 colwidth = max(colwidth, len_x)
692 colwidths.append(colwidth)
692 colwidths.append(colwidth)
693 totwidth += colwidth + len(separator)
693 totwidth += colwidth + len(separator)
694 if totwidth > displaywidth:
694 if totwidth > displaywidth:
695 break
695 break
696 if totwidth <= displaywidth:
696 if totwidth <= displaywidth:
697 break
697 break
698
698
699 # The smallest number of rows computed and the max widths for each
699 # The smallest number of rows computed and the max widths for each
700 # column has been obtained. Now we just have to format each of the rows.
700 # column has been obtained. Now we just have to format each of the rows.
701 string = ''
701 string = ''
702 for row in range(nrows):
702 for row in range(nrows):
703 texts = []
703 texts = []
704 for col in range(ncols):
704 for col in range(ncols):
705 i = row + nrows*col
705 i = row + nrows*col
706 if i >= size:
706 if i >= size:
707 texts.append('')
707 texts.append('')
708 else:
708 else:
709 texts.append(items[i])
709 texts.append(items[i])
710 while texts and not texts[-1]:
710 while texts and not texts[-1]:
711 del texts[-1]
711 del texts[-1]
712 for col in range(len(texts)):
712 for col in range(len(texts)):
713 texts[col] = texts[col].ljust(colwidths[col])
713 texts[col] = texts[col].ljust(colwidths[col])
714 string += '%s\n' % separator.join(texts)
714 string += '%s\n' % separator.join(texts)
715 return string
715 return string
@@ -1,1392 +1,1394
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Authors:
31 Authors:
32
32
33 * Brian Granger
33 * Brian Granger
34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
35 and is licensed under the BSD license. Also, many of the ideas also come
35 and is licensed under the BSD license. Also, many of the ideas also come
36 from enthought.traits even though our implementation is very different.
36 from enthought.traits even though our implementation is very different.
37 """
37 """
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Copyright (C) 2008-2009 The IPython Development Team
40 # Copyright (C) 2008-2009 The IPython Development Team
41 #
41 #
42 # Distributed under the terms of the BSD License. The full license is in
42 # Distributed under the terms of the BSD License. The full license is in
43 # the file COPYING, distributed as part of this software.
43 # the file COPYING, distributed as part of this software.
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Imports
47 # Imports
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50
50
51 import inspect
51 import inspect
52 import re
52 import re
53 import sys
53 import sys
54 import types
54 import types
55 from types import FunctionType
55 from types import FunctionType
56 try:
56 try:
57 from types import ClassType, InstanceType
57 from types import ClassType, InstanceType
58 ClassTypes = (ClassType, type)
58 ClassTypes = (ClassType, type)
59 except:
59 except:
60 ClassTypes = (type,)
60 ClassTypes = (type,)
61
61
62 from .importstring import import_item
62 from .importstring import import_item
63 from IPython.utils import py3compat
63 from IPython.utils import py3compat
64
64
65 SequenceTypes = (list, tuple, set, frozenset)
65 SequenceTypes = (list, tuple, set, frozenset)
66
66
67 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
68 # Basic classes
68 # Basic classes
69 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
70
70
71
71
72 class NoDefaultSpecified ( object ): pass
72 class NoDefaultSpecified ( object ): pass
73 NoDefaultSpecified = NoDefaultSpecified()
73 NoDefaultSpecified = NoDefaultSpecified()
74
74
75
75
76 class Undefined ( object ): pass
76 class Undefined ( object ): pass
77 Undefined = Undefined()
77 Undefined = Undefined()
78
78
79 class TraitError(Exception):
79 class TraitError(Exception):
80 pass
80 pass
81
81
82 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
83 # Utilities
83 # Utilities
84 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
85
85
86
86
87 def class_of ( object ):
87 def class_of ( object ):
88 """ Returns a string containing the class name of an object with the
88 """ Returns a string containing the class name of an object with the
89 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
89 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
90 'a PlotValue').
90 'a PlotValue').
91 """
91 """
92 if isinstance( object, basestring ):
92 if isinstance( object, basestring ):
93 return add_article( object )
93 return add_article( object )
94
94
95 return add_article( object.__class__.__name__ )
95 return add_article( object.__class__.__name__ )
96
96
97
97
98 def add_article ( name ):
98 def add_article ( name ):
99 """ Returns a string containing the correct indefinite article ('a' or 'an')
99 """ Returns a string containing the correct indefinite article ('a' or 'an')
100 prefixed to the specified string.
100 prefixed to the specified string.
101 """
101 """
102 if name[:1].lower() in 'aeiou':
102 if name[:1].lower() in 'aeiou':
103 return 'an ' + name
103 return 'an ' + name
104
104
105 return 'a ' + name
105 return 'a ' + name
106
106
107
107
108 def repr_type(obj):
108 def repr_type(obj):
109 """ Return a string representation of a value and its type for readable
109 """ Return a string representation of a value and its type for readable
110 error messages.
110 error messages.
111 """
111 """
112 the_type = type(obj)
112 the_type = type(obj)
113 if (not py3compat.PY3) and the_type is InstanceType:
113 if (not py3compat.PY3) and the_type is InstanceType:
114 # Old-style class.
114 # Old-style class.
115 the_type = obj.__class__
115 the_type = obj.__class__
116 msg = '%r %r' % (obj, the_type)
116 msg = '%r %r' % (obj, the_type)
117 return msg
117 return msg
118
118
119
119
120 def parse_notifier_name(name):
120 def parse_notifier_name(name):
121 """Convert the name argument to a list of names.
121 """Convert the name argument to a list of names.
122
122
123 Examples
123 Examples
124 --------
124 --------
125
125
126 >>> parse_notifier_name('a')
126 >>> parse_notifier_name('a')
127 ['a']
127 ['a']
128 >>> parse_notifier_name(['a','b'])
128 >>> parse_notifier_name(['a','b'])
129 ['a', 'b']
129 ['a', 'b']
130 >>> parse_notifier_name(None)
130 >>> parse_notifier_name(None)
131 ['anytrait']
131 ['anytrait']
132 """
132 """
133 if isinstance(name, str):
133 if isinstance(name, str):
134 return [name]
134 return [name]
135 elif name is None:
135 elif name is None:
136 return ['anytrait']
136 return ['anytrait']
137 elif isinstance(name, (list, tuple)):
137 elif isinstance(name, (list, tuple)):
138 for n in name:
138 for n in name:
139 assert isinstance(n, str), "names must be strings"
139 assert isinstance(n, str), "names must be strings"
140 return name
140 return name
141
141
142
142
143 class _SimpleTest:
143 class _SimpleTest:
144 def __init__ ( self, value ): self.value = value
144 def __init__ ( self, value ): self.value = value
145 def __call__ ( self, test ):
145 def __call__ ( self, test ):
146 return test == self.value
146 return test == self.value
147 def __repr__(self):
147 def __repr__(self):
148 return "<SimpleTest(%r)" % self.value
148 return "<SimpleTest(%r)" % self.value
149 def __str__(self):
149 def __str__(self):
150 return self.__repr__()
150 return self.__repr__()
151
151
152
152
153 def getmembers(object, predicate=None):
153 def getmembers(object, predicate=None):
154 """A safe version of inspect.getmembers that handles missing attributes.
154 """A safe version of inspect.getmembers that handles missing attributes.
155
155
156 This is useful when there are descriptor based attributes that for
156 This is useful when there are descriptor based attributes that for
157 some reason raise AttributeError even though they exist. This happens
157 some reason raise AttributeError even though they exist. This happens
158 in zope.inteface with the __provides__ attribute.
158 in zope.inteface with the __provides__ attribute.
159 """
159 """
160 results = []
160 results = []
161 for key in dir(object):
161 for key in dir(object):
162 try:
162 try:
163 value = getattr(object, key)
163 value = getattr(object, key)
164 except AttributeError:
164 except AttributeError:
165 pass
165 pass
166 else:
166 else:
167 if not predicate or predicate(value):
167 if not predicate or predicate(value):
168 results.append((key, value))
168 results.append((key, value))
169 results.sort()
169 results.sort()
170 return results
170 return results
171
171
172
172
173 #-----------------------------------------------------------------------------
173 #-----------------------------------------------------------------------------
174 # Base TraitType for all traits
174 # Base TraitType for all traits
175 #-----------------------------------------------------------------------------
175 #-----------------------------------------------------------------------------
176
176
177
177
178 class TraitType(object):
178 class TraitType(object):
179 """A base class for all trait descriptors.
179 """A base class for all trait descriptors.
180
180
181 Notes
181 Notes
182 -----
182 -----
183 Our implementation of traits is based on Python's descriptor
183 Our implementation of traits is based on Python's descriptor
184 prototol. This class is the base class for all such descriptors. The
184 prototol. This class is the base class for all such descriptors. The
185 only magic we use is a custom metaclass for the main :class:`HasTraits`
185 only magic we use is a custom metaclass for the main :class:`HasTraits`
186 class that does the following:
186 class that does the following:
187
187
188 1. Sets the :attr:`name` attribute of every :class:`TraitType`
188 1. Sets the :attr:`name` attribute of every :class:`TraitType`
189 instance in the class dict to the name of the attribute.
189 instance in the class dict to the name of the attribute.
190 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
190 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
191 instance in the class dict to the *class* that declared the trait.
191 instance in the class dict to the *class* that declared the trait.
192 This is used by the :class:`This` trait to allow subclasses to
192 This is used by the :class:`This` trait to allow subclasses to
193 accept superclasses for :class:`This` values.
193 accept superclasses for :class:`This` values.
194 """
194 """
195
195
196
196
197 metadata = {}
197 metadata = {}
198 default_value = Undefined
198 default_value = Undefined
199 info_text = 'any value'
199 info_text = 'any value'
200
200
201 def __init__(self, default_value=NoDefaultSpecified, **metadata):
201 def __init__(self, default_value=NoDefaultSpecified, **metadata):
202 """Create a TraitType.
202 """Create a TraitType.
203 """
203 """
204 if default_value is not NoDefaultSpecified:
204 if default_value is not NoDefaultSpecified:
205 self.default_value = default_value
205 self.default_value = default_value
206
206
207 if len(metadata) > 0:
207 if len(metadata) > 0:
208 if len(self.metadata) > 0:
208 if len(self.metadata) > 0:
209 self._metadata = self.metadata.copy()
209 self._metadata = self.metadata.copy()
210 self._metadata.update(metadata)
210 self._metadata.update(metadata)
211 else:
211 else:
212 self._metadata = metadata
212 self._metadata = metadata
213 else:
213 else:
214 self._metadata = self.metadata
214 self._metadata = self.metadata
215
215
216 self.init()
216 self.init()
217
217
218 def init(self):
218 def init(self):
219 pass
219 pass
220
220
221 def get_default_value(self):
221 def get_default_value(self):
222 """Create a new instance of the default value."""
222 """Create a new instance of the default value."""
223 return self.default_value
223 return self.default_value
224
224
225 def instance_init(self, obj):
225 def instance_init(self, obj):
226 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
226 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
227
227
228 Some stages of initialization must be delayed until the parent
228 Some stages of initialization must be delayed until the parent
229 :class:`HasTraits` instance has been created. This method is
229 :class:`HasTraits` instance has been created. This method is
230 called in :meth:`HasTraits.__new__` after the instance has been
230 called in :meth:`HasTraits.__new__` after the instance has been
231 created.
231 created.
232
232
233 This method trigger the creation and validation of default values
233 This method trigger the creation and validation of default values
234 and also things like the resolution of str given class names in
234 and also things like the resolution of str given class names in
235 :class:`Type` and :class`Instance`.
235 :class:`Type` and :class`Instance`.
236
236
237 Parameters
237 Parameters
238 ----------
238 ----------
239 obj : :class:`HasTraits` instance
239 obj : :class:`HasTraits` instance
240 The parent :class:`HasTraits` instance that has just been
240 The parent :class:`HasTraits` instance that has just been
241 created.
241 created.
242 """
242 """
243 self.set_default_value(obj)
243 self.set_default_value(obj)
244
244
245 def set_default_value(self, obj):
245 def set_default_value(self, obj):
246 """Set the default value on a per instance basis.
246 """Set the default value on a per instance basis.
247
247
248 This method is called by :meth:`instance_init` to create and
248 This method is called by :meth:`instance_init` to create and
249 validate the default value. The creation and validation of
249 validate the default value. The creation and validation of
250 default values must be delayed until the parent :class:`HasTraits`
250 default values must be delayed until the parent :class:`HasTraits`
251 class has been instantiated.
251 class has been instantiated.
252 """
252 """
253 # Check for a deferred initializer defined in the same class as the
253 # Check for a deferred initializer defined in the same class as the
254 # trait declaration or above.
254 # trait declaration or above.
255 mro = type(obj).mro()
255 mro = type(obj).mro()
256 meth_name = '_%s_default' % self.name
256 meth_name = '_%s_default' % self.name
257 for cls in mro[:mro.index(self.this_class)+1]:
257 for cls in mro[:mro.index(self.this_class)+1]:
258 if meth_name in cls.__dict__:
258 if meth_name in cls.__dict__:
259 break
259 break
260 else:
260 else:
261 # We didn't find one. Do static initialization.
261 # We didn't find one. Do static initialization.
262 dv = self.get_default_value()
262 dv = self.get_default_value()
263 newdv = self._validate(obj, dv)
263 newdv = self._validate(obj, dv)
264 obj._trait_values[self.name] = newdv
264 obj._trait_values[self.name] = newdv
265 return
265 return
266 # Complete the dynamic initialization.
266 # Complete the dynamic initialization.
267 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
267 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
268
268
269 def __get__(self, obj, cls=None):
269 def __get__(self, obj, cls=None):
270 """Get the value of the trait by self.name for the instance.
270 """Get the value of the trait by self.name for the instance.
271
271
272 Default values are instantiated when :meth:`HasTraits.__new__`
272 Default values are instantiated when :meth:`HasTraits.__new__`
273 is called. Thus by the time this method gets called either the
273 is called. Thus by the time this method gets called either the
274 default value or a user defined value (they called :meth:`__set__`)
274 default value or a user defined value (they called :meth:`__set__`)
275 is in the :class:`HasTraits` instance.
275 is in the :class:`HasTraits` instance.
276 """
276 """
277 if obj is None:
277 if obj is None:
278 return self
278 return self
279 else:
279 else:
280 try:
280 try:
281 value = obj._trait_values[self.name]
281 value = obj._trait_values[self.name]
282 except KeyError:
282 except KeyError:
283 # Check for a dynamic initializer.
283 # Check for a dynamic initializer.
284 if self.name in obj._trait_dyn_inits:
284 if self.name in obj._trait_dyn_inits:
285 value = obj._trait_dyn_inits[self.name](obj)
285 value = obj._trait_dyn_inits[self.name](obj)
286 # FIXME: Do we really validate here?
286 # FIXME: Do we really validate here?
287 value = self._validate(obj, value)
287 value = self._validate(obj, value)
288 obj._trait_values[self.name] = value
288 obj._trait_values[self.name] = value
289 return value
289 return value
290 else:
290 else:
291 raise TraitError('Unexpected error in TraitType: '
291 raise TraitError('Unexpected error in TraitType: '
292 'both default value and dynamic initializer are '
292 'both default value and dynamic initializer are '
293 'absent.')
293 'absent.')
294 except Exception:
294 except Exception:
295 # HasTraits should call set_default_value to populate
295 # HasTraits should call set_default_value to populate
296 # this. So this should never be reached.
296 # this. So this should never be reached.
297 raise TraitError('Unexpected error in TraitType: '
297 raise TraitError('Unexpected error in TraitType: '
298 'default value not set properly')
298 'default value not set properly')
299 else:
299 else:
300 return value
300 return value
301
301
302 def __set__(self, obj, value):
302 def __set__(self, obj, value):
303 new_value = self._validate(obj, value)
303 new_value = self._validate(obj, value)
304 old_value = self.__get__(obj)
304 old_value = self.__get__(obj)
305 if old_value != new_value:
305 if old_value != new_value:
306 obj._trait_values[self.name] = new_value
306 obj._trait_values[self.name] = new_value
307 obj._notify_trait(self.name, old_value, new_value)
307 obj._notify_trait(self.name, old_value, new_value)
308
308
309 def _validate(self, obj, value):
309 def _validate(self, obj, value):
310 if hasattr(self, 'validate'):
310 if hasattr(self, 'validate'):
311 return self.validate(obj, value)
311 return self.validate(obj, value)
312 elif hasattr(self, 'is_valid_for'):
312 elif hasattr(self, 'is_valid_for'):
313 valid = self.is_valid_for(value)
313 valid = self.is_valid_for(value)
314 if valid:
314 if valid:
315 return value
315 return value
316 else:
316 else:
317 raise TraitError('invalid value for type: %r' % value)
317 raise TraitError('invalid value for type: %r' % value)
318 elif hasattr(self, 'value_for'):
318 elif hasattr(self, 'value_for'):
319 return self.value_for(value)
319 return self.value_for(value)
320 else:
320 else:
321 return value
321 return value
322
322
323 def info(self):
323 def info(self):
324 return self.info_text
324 return self.info_text
325
325
326 def error(self, obj, value):
326 def error(self, obj, value):
327 if obj is not None:
327 if obj is not None:
328 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
328 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
329 % (self.name, class_of(obj),
329 % (self.name, class_of(obj),
330 self.info(), repr_type(value))
330 self.info(), repr_type(value))
331 else:
331 else:
332 e = "The '%s' trait must be %s, but a value of %r was specified." \
332 e = "The '%s' trait must be %s, but a value of %r was specified." \
333 % (self.name, self.info(), repr_type(value))
333 % (self.name, self.info(), repr_type(value))
334 raise TraitError(e)
334 raise TraitError(e)
335
335
336 def get_metadata(self, key):
336 def get_metadata(self, key):
337 return getattr(self, '_metadata', {}).get(key, None)
337 return getattr(self, '_metadata', {}).get(key, None)
338
338
339 def set_metadata(self, key, value):
339 def set_metadata(self, key, value):
340 getattr(self, '_metadata', {})[key] = value
340 getattr(self, '_metadata', {})[key] = value
341
341
342
342
343 #-----------------------------------------------------------------------------
343 #-----------------------------------------------------------------------------
344 # The HasTraits implementation
344 # The HasTraits implementation
345 #-----------------------------------------------------------------------------
345 #-----------------------------------------------------------------------------
346
346
347
347
348 class MetaHasTraits(type):
348 class MetaHasTraits(type):
349 """A metaclass for HasTraits.
349 """A metaclass for HasTraits.
350
350
351 This metaclass makes sure that any TraitType class attributes are
351 This metaclass makes sure that any TraitType class attributes are
352 instantiated and sets their name attribute.
352 instantiated and sets their name attribute.
353 """
353 """
354
354
355 def __new__(mcls, name, bases, classdict):
355 def __new__(mcls, name, bases, classdict):
356 """Create the HasTraits class.
356 """Create the HasTraits class.
357
357
358 This instantiates all TraitTypes in the class dict and sets their
358 This instantiates all TraitTypes in the class dict and sets their
359 :attr:`name` attribute.
359 :attr:`name` attribute.
360 """
360 """
361 # print "MetaHasTraitlets (mcls, name): ", mcls, name
361 # print "MetaHasTraitlets (mcls, name): ", mcls, name
362 # print "MetaHasTraitlets (bases): ", bases
362 # print "MetaHasTraitlets (bases): ", bases
363 # print "MetaHasTraitlets (classdict): ", classdict
363 # print "MetaHasTraitlets (classdict): ", classdict
364 for k,v in classdict.iteritems():
364 for k,v in classdict.iteritems():
365 if isinstance(v, TraitType):
365 if isinstance(v, TraitType):
366 v.name = k
366 v.name = k
367 elif inspect.isclass(v):
367 elif inspect.isclass(v):
368 if issubclass(v, TraitType):
368 if issubclass(v, TraitType):
369 vinst = v()
369 vinst = v()
370 vinst.name = k
370 vinst.name = k
371 classdict[k] = vinst
371 classdict[k] = vinst
372 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
372 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
373
373
374 def __init__(cls, name, bases, classdict):
374 def __init__(cls, name, bases, classdict):
375 """Finish initializing the HasTraits class.
375 """Finish initializing the HasTraits class.
376
376
377 This sets the :attr:`this_class` attribute of each TraitType in the
377 This sets the :attr:`this_class` attribute of each TraitType in the
378 class dict to the newly created class ``cls``.
378 class dict to the newly created class ``cls``.
379 """
379 """
380 for k, v in classdict.iteritems():
380 for k, v in classdict.iteritems():
381 if isinstance(v, TraitType):
381 if isinstance(v, TraitType):
382 v.this_class = cls
382 v.this_class = cls
383 super(MetaHasTraits, cls).__init__(name, bases, classdict)
383 super(MetaHasTraits, cls).__init__(name, bases, classdict)
384
384
385 class HasTraits(object):
385 class HasTraits(object):
386
386
387 __metaclass__ = MetaHasTraits
387 __metaclass__ = MetaHasTraits
388
388
389 def __new__(cls, **kw):
389 def __new__(cls, **kw):
390 # This is needed because in Python 2.6 object.__new__ only accepts
390 # This is needed because in Python 2.6 object.__new__ only accepts
391 # the cls argument.
391 # the cls argument.
392 new_meth = super(HasTraits, cls).__new__
392 new_meth = super(HasTraits, cls).__new__
393 if new_meth is object.__new__:
393 if new_meth is object.__new__:
394 inst = new_meth(cls)
394 inst = new_meth(cls)
395 else:
395 else:
396 inst = new_meth(cls, **kw)
396 inst = new_meth(cls, **kw)
397 inst._trait_values = {}
397 inst._trait_values = {}
398 inst._trait_notifiers = {}
398 inst._trait_notifiers = {}
399 inst._trait_dyn_inits = {}
399 inst._trait_dyn_inits = {}
400 # Here we tell all the TraitType instances to set their default
400 # Here we tell all the TraitType instances to set their default
401 # values on the instance.
401 # values on the instance.
402 for key in dir(cls):
402 for key in dir(cls):
403 # Some descriptors raise AttributeError like zope.interface's
403 # Some descriptors raise AttributeError like zope.interface's
404 # __provides__ attributes even though they exist. This causes
404 # __provides__ attributes even though they exist. This causes
405 # AttributeErrors even though they are listed in dir(cls).
405 # AttributeErrors even though they are listed in dir(cls).
406 try:
406 try:
407 value = getattr(cls, key)
407 value = getattr(cls, key)
408 except AttributeError:
408 except AttributeError:
409 pass
409 pass
410 else:
410 else:
411 if isinstance(value, TraitType):
411 if isinstance(value, TraitType):
412 value.instance_init(inst)
412 value.instance_init(inst)
413
413
414 return inst
414 return inst
415
415
416 def __init__(self, **kw):
416 def __init__(self, **kw):
417 # Allow trait values to be set using keyword arguments.
417 # Allow trait values to be set using keyword arguments.
418 # We need to use setattr for this to trigger validation and
418 # We need to use setattr for this to trigger validation and
419 # notifications.
419 # notifications.
420 for key, value in kw.iteritems():
420 for key, value in kw.iteritems():
421 setattr(self, key, value)
421 setattr(self, key, value)
422
422
423 def _notify_trait(self, name, old_value, new_value):
423 def _notify_trait(self, name, old_value, new_value):
424
424
425 # First dynamic ones
425 # First dynamic ones
426 callables = self._trait_notifiers.get(name,[])
426 callables = self._trait_notifiers.get(name,[])
427 more_callables = self._trait_notifiers.get('anytrait',[])
427 more_callables = self._trait_notifiers.get('anytrait',[])
428 callables.extend(more_callables)
428 callables.extend(more_callables)
429
429
430 # Now static ones
430 # Now static ones
431 try:
431 try:
432 cb = getattr(self, '_%s_changed' % name)
432 cb = getattr(self, '_%s_changed' % name)
433 except:
433 except:
434 pass
434 pass
435 else:
435 else:
436 callables.append(cb)
436 callables.append(cb)
437
437
438 # Call them all now
438 # Call them all now
439 for c in callables:
439 for c in callables:
440 # Traits catches and logs errors here. I allow them to raise
440 # Traits catches and logs errors here. I allow them to raise
441 if callable(c):
441 if callable(c):
442 argspec = inspect.getargspec(c)
442 argspec = inspect.getargspec(c)
443 nargs = len(argspec[0])
443 nargs = len(argspec[0])
444 # Bound methods have an additional 'self' argument
444 # Bound methods have an additional 'self' argument
445 # I don't know how to treat unbound methods, but they
445 # I don't know how to treat unbound methods, but they
446 # can't really be used for callbacks.
446 # can't really be used for callbacks.
447 if isinstance(c, types.MethodType):
447 if isinstance(c, types.MethodType):
448 offset = -1
448 offset = -1
449 else:
449 else:
450 offset = 0
450 offset = 0
451 if nargs + offset == 0:
451 if nargs + offset == 0:
452 c()
452 c()
453 elif nargs + offset == 1:
453 elif nargs + offset == 1:
454 c(name)
454 c(name)
455 elif nargs + offset == 2:
455 elif nargs + offset == 2:
456 c(name, new_value)
456 c(name, new_value)
457 elif nargs + offset == 3:
457 elif nargs + offset == 3:
458 c(name, old_value, new_value)
458 c(name, old_value, new_value)
459 else:
459 else:
460 raise TraitError('a trait changed callback '
460 raise TraitError('a trait changed callback '
461 'must have 0-3 arguments.')
461 'must have 0-3 arguments.')
462 else:
462 else:
463 raise TraitError('a trait changed callback '
463 raise TraitError('a trait changed callback '
464 'must be callable.')
464 'must be callable.')
465
465
466
466
467 def _add_notifiers(self, handler, name):
467 def _add_notifiers(self, handler, name):
468 if not self._trait_notifiers.has_key(name):
468 if not self._trait_notifiers.has_key(name):
469 nlist = []
469 nlist = []
470 self._trait_notifiers[name] = nlist
470 self._trait_notifiers[name] = nlist
471 else:
471 else:
472 nlist = self._trait_notifiers[name]
472 nlist = self._trait_notifiers[name]
473 if handler not in nlist:
473 if handler not in nlist:
474 nlist.append(handler)
474 nlist.append(handler)
475
475
476 def _remove_notifiers(self, handler, name):
476 def _remove_notifiers(self, handler, name):
477 if self._trait_notifiers.has_key(name):
477 if self._trait_notifiers.has_key(name):
478 nlist = self._trait_notifiers[name]
478 nlist = self._trait_notifiers[name]
479 try:
479 try:
480 index = nlist.index(handler)
480 index = nlist.index(handler)
481 except ValueError:
481 except ValueError:
482 pass
482 pass
483 else:
483 else:
484 del nlist[index]
484 del nlist[index]
485
485
486 def on_trait_change(self, handler, name=None, remove=False):
486 def on_trait_change(self, handler, name=None, remove=False):
487 """Setup a handler to be called when a trait changes.
487 """Setup a handler to be called when a trait changes.
488
488
489 This is used to setup dynamic notifications of trait changes.
489 This is used to setup dynamic notifications of trait changes.
490
490
491 Static handlers can be created by creating methods on a HasTraits
491 Static handlers can be created by creating methods on a HasTraits
492 subclass with the naming convention '_[traitname]_changed'. Thus,
492 subclass with the naming convention '_[traitname]_changed'. Thus,
493 to create static handler for the trait 'a', create the method
493 to create static handler for the trait 'a', create the method
494 _a_changed(self, name, old, new) (fewer arguments can be used, see
494 _a_changed(self, name, old, new) (fewer arguments can be used, see
495 below).
495 below).
496
496
497 Parameters
497 Parameters
498 ----------
498 ----------
499 handler : callable
499 handler : callable
500 A callable that is called when a trait changes. Its
500 A callable that is called when a trait changes. Its
501 signature can be handler(), handler(name), handler(name, new)
501 signature can be handler(), handler(name), handler(name, new)
502 or handler(name, old, new).
502 or handler(name, old, new).
503 name : list, str, None
503 name : list, str, None
504 If None, the handler will apply to all traits. If a list
504 If None, the handler will apply to all traits. If a list
505 of str, handler will apply to all names in the list. If a
505 of str, handler will apply to all names in the list. If a
506 str, the handler will apply just to that name.
506 str, the handler will apply just to that name.
507 remove : bool
507 remove : bool
508 If False (the default), then install the handler. If True
508 If False (the default), then install the handler. If True
509 then unintall it.
509 then unintall it.
510 """
510 """
511 if remove:
511 if remove:
512 names = parse_notifier_name(name)
512 names = parse_notifier_name(name)
513 for n in names:
513 for n in names:
514 self._remove_notifiers(handler, n)
514 self._remove_notifiers(handler, n)
515 else:
515 else:
516 names = parse_notifier_name(name)
516 names = parse_notifier_name(name)
517 for n in names:
517 for n in names:
518 self._add_notifiers(handler, n)
518 self._add_notifiers(handler, n)
519
519
520 @classmethod
520 @classmethod
521 def class_trait_names(cls, **metadata):
521 def class_trait_names(cls, **metadata):
522 """Get a list of all the names of this classes traits.
522 """Get a list of all the names of this classes traits.
523
523
524 This method is just like the :meth:`trait_names` method, but is unbound.
524 This method is just like the :meth:`trait_names` method, but is unbound.
525 """
525 """
526 return cls.class_traits(**metadata).keys()
526 return cls.class_traits(**metadata).keys()
527
527
528 @classmethod
528 @classmethod
529 def class_traits(cls, **metadata):
529 def class_traits(cls, **metadata):
530 """Get a list of all the traits of this class.
530 """Get a list of all the traits of this class.
531
531
532 This method is just like the :meth:`traits` method, but is unbound.
532 This method is just like the :meth:`traits` method, but is unbound.
533
533
534 The TraitTypes returned don't know anything about the values
534 The TraitTypes returned don't know anything about the values
535 that the various HasTrait's instances are holding.
535 that the various HasTrait's instances are holding.
536
536
537 This follows the same algorithm as traits does and does not allow
537 This follows the same algorithm as traits does and does not allow
538 for any simple way of specifying merely that a metadata name
538 for any simple way of specifying merely that a metadata name
539 exists, but has any value. This is because get_metadata returns
539 exists, but has any value. This is because get_metadata returns
540 None if a metadata key doesn't exist.
540 None if a metadata key doesn't exist.
541 """
541 """
542 traits = dict([memb for memb in getmembers(cls) if \
542 traits = dict([memb for memb in getmembers(cls) if \
543 isinstance(memb[1], TraitType)])
543 isinstance(memb[1], TraitType)])
544
544
545 if len(metadata) == 0:
545 if len(metadata) == 0:
546 return traits
546 return traits
547
547
548 for meta_name, meta_eval in metadata.items():
548 for meta_name, meta_eval in metadata.items():
549 if type(meta_eval) is not FunctionType:
549 if type(meta_eval) is not FunctionType:
550 metadata[meta_name] = _SimpleTest(meta_eval)
550 metadata[meta_name] = _SimpleTest(meta_eval)
551
551
552 result = {}
552 result = {}
553 for name, trait in traits.items():
553 for name, trait in traits.items():
554 for meta_name, meta_eval in metadata.items():
554 for meta_name, meta_eval in metadata.items():
555 if not meta_eval(trait.get_metadata(meta_name)):
555 if not meta_eval(trait.get_metadata(meta_name)):
556 break
556 break
557 else:
557 else:
558 result[name] = trait
558 result[name] = trait
559
559
560 return result
560 return result
561
561
562 def trait_names(self, **metadata):
562 def trait_names(self, **metadata):
563 """Get a list of all the names of this classes traits."""
563 """Get a list of all the names of this classes traits."""
564 return self.traits(**metadata).keys()
564 return self.traits(**metadata).keys()
565
565
566 def traits(self, **metadata):
566 def traits(self, **metadata):
567 """Get a list of all the traits of this class.
567 """Get a list of all the traits of this class.
568
568
569 The TraitTypes returned don't know anything about the values
569 The TraitTypes returned don't know anything about the values
570 that the various HasTrait's instances are holding.
570 that the various HasTrait's instances are holding.
571
571
572 This follows the same algorithm as traits does and does not allow
572 This follows the same algorithm as traits does and does not allow
573 for any simple way of specifying merely that a metadata name
573 for any simple way of specifying merely that a metadata name
574 exists, but has any value. This is because get_metadata returns
574 exists, but has any value. This is because get_metadata returns
575 None if a metadata key doesn't exist.
575 None if a metadata key doesn't exist.
576 """
576 """
577 traits = dict([memb for memb in getmembers(self.__class__) if \
577 traits = dict([memb for memb in getmembers(self.__class__) if \
578 isinstance(memb[1], TraitType)])
578 isinstance(memb[1], TraitType)])
579
579
580 if len(metadata) == 0:
580 if len(metadata) == 0:
581 return traits
581 return traits
582
582
583 for meta_name, meta_eval in metadata.items():
583 for meta_name, meta_eval in metadata.items():
584 if type(meta_eval) is not FunctionType:
584 if type(meta_eval) is not FunctionType:
585 metadata[meta_name] = _SimpleTest(meta_eval)
585 metadata[meta_name] = _SimpleTest(meta_eval)
586
586
587 result = {}
587 result = {}
588 for name, trait in traits.items():
588 for name, trait in traits.items():
589 for meta_name, meta_eval in metadata.items():
589 for meta_name, meta_eval in metadata.items():
590 if not meta_eval(trait.get_metadata(meta_name)):
590 if not meta_eval(trait.get_metadata(meta_name)):
591 break
591 break
592 else:
592 else:
593 result[name] = trait
593 result[name] = trait
594
594
595 return result
595 return result
596
596
597 def trait_metadata(self, traitname, key):
597 def trait_metadata(self, traitname, key):
598 """Get metadata values for trait by key."""
598 """Get metadata values for trait by key."""
599 try:
599 try:
600 trait = getattr(self.__class__, traitname)
600 trait = getattr(self.__class__, traitname)
601 except AttributeError:
601 except AttributeError:
602 raise TraitError("Class %s does not have a trait named %s" %
602 raise TraitError("Class %s does not have a trait named %s" %
603 (self.__class__.__name__, traitname))
603 (self.__class__.__name__, traitname))
604 else:
604 else:
605 return trait.get_metadata(key)
605 return trait.get_metadata(key)
606
606
607 #-----------------------------------------------------------------------------
607 #-----------------------------------------------------------------------------
608 # Actual TraitTypes implementations/subclasses
608 # Actual TraitTypes implementations/subclasses
609 #-----------------------------------------------------------------------------
609 #-----------------------------------------------------------------------------
610
610
611 #-----------------------------------------------------------------------------
611 #-----------------------------------------------------------------------------
612 # TraitTypes subclasses for handling classes and instances of classes
612 # TraitTypes subclasses for handling classes and instances of classes
613 #-----------------------------------------------------------------------------
613 #-----------------------------------------------------------------------------
614
614
615
615
616 class ClassBasedTraitType(TraitType):
616 class ClassBasedTraitType(TraitType):
617 """A trait with error reporting for Type, Instance and This."""
617 """A trait with error reporting for Type, Instance and This."""
618
618
619 def error(self, obj, value):
619 def error(self, obj, value):
620 kind = type(value)
620 kind = type(value)
621 if (not py3compat.PY3) and kind is InstanceType:
621 if (not py3compat.PY3) and kind is InstanceType:
622 msg = 'class %s' % value.__class__.__name__
622 msg = 'class %s' % value.__class__.__name__
623 else:
623 else:
624 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
624 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
625
625
626 if obj is not None:
626 if obj is not None:
627 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
627 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
628 % (self.name, class_of(obj),
628 % (self.name, class_of(obj),
629 self.info(), msg)
629 self.info(), msg)
630 else:
630 else:
631 e = "The '%s' trait must be %s, but a value of %r was specified." \
631 e = "The '%s' trait must be %s, but a value of %r was specified." \
632 % (self.name, self.info(), msg)
632 % (self.name, self.info(), msg)
633
633
634 raise TraitError(e)
634 raise TraitError(e)
635
635
636
636
637 class Type(ClassBasedTraitType):
637 class Type(ClassBasedTraitType):
638 """A trait whose value must be a subclass of a specified class."""
638 """A trait whose value must be a subclass of a specified class."""
639
639
640 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
640 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
641 """Construct a Type trait
641 """Construct a Type trait
642
642
643 A Type trait specifies that its values must be subclasses of
643 A Type trait specifies that its values must be subclasses of
644 a particular class.
644 a particular class.
645
645
646 If only ``default_value`` is given, it is used for the ``klass`` as
646 If only ``default_value`` is given, it is used for the ``klass`` as
647 well.
647 well.
648
648
649 Parameters
649 Parameters
650 ----------
650 ----------
651 default_value : class, str or None
651 default_value : class, str or None
652 The default value must be a subclass of klass. If an str,
652 The default value must be a subclass of klass. If an str,
653 the str must be a fully specified class name, like 'foo.bar.Bah'.
653 the str must be a fully specified class name, like 'foo.bar.Bah'.
654 The string is resolved into real class, when the parent
654 The string is resolved into real class, when the parent
655 :class:`HasTraits` class is instantiated.
655 :class:`HasTraits` class is instantiated.
656 klass : class, str, None
656 klass : class, str, None
657 Values of this trait must be a subclass of klass. The klass
657 Values of this trait must be a subclass of klass. The klass
658 may be specified in a string like: 'foo.bar.MyClass'.
658 may be specified in a string like: 'foo.bar.MyClass'.
659 The string is resolved into real class, when the parent
659 The string is resolved into real class, when the parent
660 :class:`HasTraits` class is instantiated.
660 :class:`HasTraits` class is instantiated.
661 allow_none : boolean
661 allow_none : boolean
662 Indicates whether None is allowed as an assignable value. Even if
662 Indicates whether None is allowed as an assignable value. Even if
663 ``False``, the default value may be ``None``.
663 ``False``, the default value may be ``None``.
664 """
664 """
665 if default_value is None:
665 if default_value is None:
666 if klass is None:
666 if klass is None:
667 klass = object
667 klass = object
668 elif klass is None:
668 elif klass is None:
669 klass = default_value
669 klass = default_value
670
670
671 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
671 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
672 raise TraitError("A Type trait must specify a class.")
672 raise TraitError("A Type trait must specify a class.")
673
673
674 self.klass = klass
674 self.klass = klass
675 self._allow_none = allow_none
675 self._allow_none = allow_none
676
676
677 super(Type, self).__init__(default_value, **metadata)
677 super(Type, self).__init__(default_value, **metadata)
678
678
679 def validate(self, obj, value):
679 def validate(self, obj, value):
680 """Validates that the value is a valid object instance."""
680 """Validates that the value is a valid object instance."""
681 try:
681 try:
682 if issubclass(value, self.klass):
682 if issubclass(value, self.klass):
683 return value
683 return value
684 except:
684 except:
685 if (value is None) and (self._allow_none):
685 if (value is None) and (self._allow_none):
686 return value
686 return value
687
687
688 self.error(obj, value)
688 self.error(obj, value)
689
689
690 def info(self):
690 def info(self):
691 """ Returns a description of the trait."""
691 """ Returns a description of the trait."""
692 if isinstance(self.klass, basestring):
692 if isinstance(self.klass, basestring):
693 klass = self.klass
693 klass = self.klass
694 else:
694 else:
695 klass = self.klass.__name__
695 klass = self.klass.__name__
696 result = 'a subclass of ' + klass
696 result = 'a subclass of ' + klass
697 if self._allow_none:
697 if self._allow_none:
698 return result + ' or None'
698 return result + ' or None'
699 return result
699 return result
700
700
701 def instance_init(self, obj):
701 def instance_init(self, obj):
702 self._resolve_classes()
702 self._resolve_classes()
703 super(Type, self).instance_init(obj)
703 super(Type, self).instance_init(obj)
704
704
705 def _resolve_classes(self):
705 def _resolve_classes(self):
706 if isinstance(self.klass, basestring):
706 if isinstance(self.klass, basestring):
707 self.klass = import_item(self.klass)
707 self.klass = import_item(self.klass)
708 if isinstance(self.default_value, basestring):
708 if isinstance(self.default_value, basestring):
709 self.default_value = import_item(self.default_value)
709 self.default_value = import_item(self.default_value)
710
710
711 def get_default_value(self):
711 def get_default_value(self):
712 return self.default_value
712 return self.default_value
713
713
714
714
715 class DefaultValueGenerator(object):
715 class DefaultValueGenerator(object):
716 """A class for generating new default value instances."""
716 """A class for generating new default value instances."""
717
717
718 def __init__(self, *args, **kw):
718 def __init__(self, *args, **kw):
719 self.args = args
719 self.args = args
720 self.kw = kw
720 self.kw = kw
721
721
722 def generate(self, klass):
722 def generate(self, klass):
723 return klass(*self.args, **self.kw)
723 return klass(*self.args, **self.kw)
724
724
725
725
726 class Instance(ClassBasedTraitType):
726 class Instance(ClassBasedTraitType):
727 """A trait whose value must be an instance of a specified class.
727 """A trait whose value must be an instance of a specified class.
728
728
729 The value can also be an instance of a subclass of the specified class.
729 The value can also be an instance of a subclass of the specified class.
730 """
730 """
731
731
732 def __init__(self, klass=None, args=None, kw=None,
732 def __init__(self, klass=None, args=None, kw=None,
733 allow_none=True, **metadata ):
733 allow_none=True, **metadata ):
734 """Construct an Instance trait.
734 """Construct an Instance trait.
735
735
736 This trait allows values that are instances of a particular
736 This trait allows values that are instances of a particular
737 class or its sublclasses. Our implementation is quite different
737 class or its sublclasses. Our implementation is quite different
738 from that of enthough.traits as we don't allow instances to be used
738 from that of enthough.traits as we don't allow instances to be used
739 for klass and we handle the ``args`` and ``kw`` arguments differently.
739 for klass and we handle the ``args`` and ``kw`` arguments differently.
740
740
741 Parameters
741 Parameters
742 ----------
742 ----------
743 klass : class, str
743 klass : class, str
744 The class that forms the basis for the trait. Class names
744 The class that forms the basis for the trait. Class names
745 can also be specified as strings, like 'foo.bar.Bar'.
745 can also be specified as strings, like 'foo.bar.Bar'.
746 args : tuple
746 args : tuple
747 Positional arguments for generating the default value.
747 Positional arguments for generating the default value.
748 kw : dict
748 kw : dict
749 Keyword arguments for generating the default value.
749 Keyword arguments for generating the default value.
750 allow_none : bool
750 allow_none : bool
751 Indicates whether None is allowed as a value.
751 Indicates whether None is allowed as a value.
752
752
753 Default Value
753 Default Value
754 -------------
754 -------------
755 If both ``args`` and ``kw`` are None, then the default value is None.
755 If both ``args`` and ``kw`` are None, then the default value is None.
756 If ``args`` is a tuple and ``kw`` is a dict, then the default is
756 If ``args`` is a tuple and ``kw`` is a dict, then the default is
757 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
757 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
758 not (but not both), None is replace by ``()`` or ``{}``.
758 not (but not both), None is replace by ``()`` or ``{}``.
759 """
759 """
760
760
761 self._allow_none = allow_none
761 self._allow_none = allow_none
762
762
763 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
763 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
764 raise TraitError('The klass argument must be a class'
764 raise TraitError('The klass argument must be a class'
765 ' you gave: %r' % klass)
765 ' you gave: %r' % klass)
766 self.klass = klass
766 self.klass = klass
767
767
768 # self.klass is a class, so handle default_value
768 # self.klass is a class, so handle default_value
769 if args is None and kw is None:
769 if args is None and kw is None:
770 default_value = None
770 default_value = None
771 else:
771 else:
772 if args is None:
772 if args is None:
773 # kw is not None
773 # kw is not None
774 args = ()
774 args = ()
775 elif kw is None:
775 elif kw is None:
776 # args is not None
776 # args is not None
777 kw = {}
777 kw = {}
778
778
779 if not isinstance(kw, dict):
779 if not isinstance(kw, dict):
780 raise TraitError("The 'kw' argument must be a dict or None.")
780 raise TraitError("The 'kw' argument must be a dict or None.")
781 if not isinstance(args, tuple):
781 if not isinstance(args, tuple):
782 raise TraitError("The 'args' argument must be a tuple or None.")
782 raise TraitError("The 'args' argument must be a tuple or None.")
783
783
784 default_value = DefaultValueGenerator(*args, **kw)
784 default_value = DefaultValueGenerator(*args, **kw)
785
785
786 super(Instance, self).__init__(default_value, **metadata)
786 super(Instance, self).__init__(default_value, **metadata)
787
787
788 def validate(self, obj, value):
788 def validate(self, obj, value):
789 if value is None:
789 if value is None:
790 if self._allow_none:
790 if self._allow_none:
791 return value
791 return value
792 self.error(obj, value)
792 self.error(obj, value)
793
793
794 if isinstance(value, self.klass):
794 if isinstance(value, self.klass):
795 return value
795 return value
796 else:
796 else:
797 self.error(obj, value)
797 self.error(obj, value)
798
798
799 def info(self):
799 def info(self):
800 if isinstance(self.klass, basestring):
800 if isinstance(self.klass, basestring):
801 klass = self.klass
801 klass = self.klass
802 else:
802 else:
803 klass = self.klass.__name__
803 klass = self.klass.__name__
804 result = class_of(klass)
804 result = class_of(klass)
805 if self._allow_none:
805 if self._allow_none:
806 return result + ' or None'
806 return result + ' or None'
807
807
808 return result
808 return result
809
809
810 def instance_init(self, obj):
810 def instance_init(self, obj):
811 self._resolve_classes()
811 self._resolve_classes()
812 super(Instance, self).instance_init(obj)
812 super(Instance, self).instance_init(obj)
813
813
814 def _resolve_classes(self):
814 def _resolve_classes(self):
815 if isinstance(self.klass, basestring):
815 if isinstance(self.klass, basestring):
816 self.klass = import_item(self.klass)
816 self.klass = import_item(self.klass)
817
817
818 def get_default_value(self):
818 def get_default_value(self):
819 """Instantiate a default value instance.
819 """Instantiate a default value instance.
820
820
821 This is called when the containing HasTraits classes'
821 This is called when the containing HasTraits classes'
822 :meth:`__new__` method is called to ensure that a unique instance
822 :meth:`__new__` method is called to ensure that a unique instance
823 is created for each HasTraits instance.
823 is created for each HasTraits instance.
824 """
824 """
825 dv = self.default_value
825 dv = self.default_value
826 if isinstance(dv, DefaultValueGenerator):
826 if isinstance(dv, DefaultValueGenerator):
827 return dv.generate(self.klass)
827 return dv.generate(self.klass)
828 else:
828 else:
829 return dv
829 return dv
830
830
831
831
832 class This(ClassBasedTraitType):
832 class This(ClassBasedTraitType):
833 """A trait for instances of the class containing this trait.
833 """A trait for instances of the class containing this trait.
834
834
835 Because how how and when class bodies are executed, the ``This``
835 Because how how and when class bodies are executed, the ``This``
836 trait can only have a default value of None. This, and because we
836 trait can only have a default value of None. This, and because we
837 always validate default values, ``allow_none`` is *always* true.
837 always validate default values, ``allow_none`` is *always* true.
838 """
838 """
839
839
840 info_text = 'an instance of the same type as the receiver or None'
840 info_text = 'an instance of the same type as the receiver or None'
841
841
842 def __init__(self, **metadata):
842 def __init__(self, **metadata):
843 super(This, self).__init__(None, **metadata)
843 super(This, self).__init__(None, **metadata)
844
844
845 def validate(self, obj, value):
845 def validate(self, obj, value):
846 # What if value is a superclass of obj.__class__? This is
846 # What if value is a superclass of obj.__class__? This is
847 # complicated if it was the superclass that defined the This
847 # complicated if it was the superclass that defined the This
848 # trait.
848 # trait.
849 if isinstance(value, self.this_class) or (value is None):
849 if isinstance(value, self.this_class) or (value is None):
850 return value
850 return value
851 else:
851 else:
852 self.error(obj, value)
852 self.error(obj, value)
853
853
854
854
855 #-----------------------------------------------------------------------------
855 #-----------------------------------------------------------------------------
856 # Basic TraitTypes implementations/subclasses
856 # Basic TraitTypes implementations/subclasses
857 #-----------------------------------------------------------------------------
857 #-----------------------------------------------------------------------------
858
858
859
859
860 class Any(TraitType):
860 class Any(TraitType):
861 default_value = None
861 default_value = None
862 info_text = 'any value'
862 info_text = 'any value'
863
863
864
864
865 class Int(TraitType):
865 class Int(TraitType):
866 """A integer trait."""
866 """A integer trait."""
867
867
868 default_value = 0
868 default_value = 0
869 info_text = 'an integer'
869 info_text = 'an integer'
870
870
871 def validate(self, obj, value):
871 def validate(self, obj, value):
872 if isinstance(value, int):
872 if isinstance(value, int):
873 return value
873 return value
874 self.error(obj, value)
874 self.error(obj, value)
875
875
876 class CInt(Int):
876 class CInt(Int):
877 """A casting version of the int trait."""
877 """A casting version of the int trait."""
878
878
879 def validate(self, obj, value):
879 def validate(self, obj, value):
880 try:
880 try:
881 return int(value)
881 return int(value)
882 except:
882 except:
883 self.error(obj, value)
883 self.error(obj, value)
884
884
885 if not py3compat.PY3:
885 if py3compat.PY3:
886 Long, CLong = Int, CInt
887 else:
886 class Long(TraitType):
888 class Long(TraitType):
887 """A long integer trait."""
889 """A long integer trait."""
888
890
889 default_value = 0L
891 default_value = 0L
890 info_text = 'a long'
892 info_text = 'a long'
891
893
892 def validate(self, obj, value):
894 def validate(self, obj, value):
893 if isinstance(value, long):
895 if isinstance(value, long):
894 return value
896 return value
895 if isinstance(value, int):
897 if isinstance(value, int):
896 return long(value)
898 return long(value)
897 self.error(obj, value)
899 self.error(obj, value)
898
900
899
901
900 class CLong(Long):
902 class CLong(Long):
901 """A casting version of the long integer trait."""
903 """A casting version of the long integer trait."""
902
904
903 def validate(self, obj, value):
905 def validate(self, obj, value):
904 try:
906 try:
905 return long(value)
907 return long(value)
906 except:
908 except:
907 self.error(obj, value)
909 self.error(obj, value)
908
910
909
911
910 class Float(TraitType):
912 class Float(TraitType):
911 """A float trait."""
913 """A float trait."""
912
914
913 default_value = 0.0
915 default_value = 0.0
914 info_text = 'a float'
916 info_text = 'a float'
915
917
916 def validate(self, obj, value):
918 def validate(self, obj, value):
917 if isinstance(value, float):
919 if isinstance(value, float):
918 return value
920 return value
919 if isinstance(value, int):
921 if isinstance(value, int):
920 return float(value)
922 return float(value)
921 self.error(obj, value)
923 self.error(obj, value)
922
924
923
925
924 class CFloat(Float):
926 class CFloat(Float):
925 """A casting version of the float trait."""
927 """A casting version of the float trait."""
926
928
927 def validate(self, obj, value):
929 def validate(self, obj, value):
928 try:
930 try:
929 return float(value)
931 return float(value)
930 except:
932 except:
931 self.error(obj, value)
933 self.error(obj, value)
932
934
933 class Complex(TraitType):
935 class Complex(TraitType):
934 """A trait for complex numbers."""
936 """A trait for complex numbers."""
935
937
936 default_value = 0.0 + 0.0j
938 default_value = 0.0 + 0.0j
937 info_text = 'a complex number'
939 info_text = 'a complex number'
938
940
939 def validate(self, obj, value):
941 def validate(self, obj, value):
940 if isinstance(value, complex):
942 if isinstance(value, complex):
941 return value
943 return value
942 if isinstance(value, (float, int)):
944 if isinstance(value, (float, int)):
943 return complex(value)
945 return complex(value)
944 self.error(obj, value)
946 self.error(obj, value)
945
947
946
948
947 class CComplex(Complex):
949 class CComplex(Complex):
948 """A casting version of the complex number trait."""
950 """A casting version of the complex number trait."""
949
951
950 def validate (self, obj, value):
952 def validate (self, obj, value):
951 try:
953 try:
952 return complex(value)
954 return complex(value)
953 except:
955 except:
954 self.error(obj, value)
956 self.error(obj, value)
955
957
956 # We should always be explicit about whether we're using bytes or unicode, both
958 # We should always be explicit about whether we're using bytes or unicode, both
957 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
959 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
958 # we don't have a Str type.
960 # we don't have a Str type.
959 class Bytes(TraitType):
961 class Bytes(TraitType):
960 """A trait for byte strings."""
962 """A trait for byte strings."""
961
963
962 default_value = ''
964 default_value = ''
963 info_text = 'a string'
965 info_text = 'a string'
964
966
965 def validate(self, obj, value):
967 def validate(self, obj, value):
966 if isinstance(value, bytes):
968 if isinstance(value, bytes):
967 return value
969 return value
968 self.error(obj, value)
970 self.error(obj, value)
969
971
970
972
971 class CBytes(Bytes):
973 class CBytes(Bytes):
972 """A casting version of the byte string trait."""
974 """A casting version of the byte string trait."""
973
975
974 def validate(self, obj, value):
976 def validate(self, obj, value):
975 try:
977 try:
976 return bytes(value)
978 return bytes(value)
977 except:
979 except:
978 self.error(obj, value)
980 self.error(obj, value)
979
981
980
982
981 class Unicode(TraitType):
983 class Unicode(TraitType):
982 """A trait for unicode strings."""
984 """A trait for unicode strings."""
983
985
984 default_value = u''
986 default_value = u''
985 info_text = 'a unicode string'
987 info_text = 'a unicode string'
986
988
987 def validate(self, obj, value):
989 def validate(self, obj, value):
988 if isinstance(value, unicode):
990 if isinstance(value, unicode):
989 return value
991 return value
990 if isinstance(value, bytes):
992 if isinstance(value, bytes):
991 return unicode(value)
993 return unicode(value)
992 self.error(obj, value)
994 self.error(obj, value)
993
995
994
996
995 class CUnicode(Unicode):
997 class CUnicode(Unicode):
996 """A casting version of the unicode trait."""
998 """A casting version of the unicode trait."""
997
999
998 def validate(self, obj, value):
1000 def validate(self, obj, value):
999 try:
1001 try:
1000 return unicode(value)
1002 return unicode(value)
1001 except:
1003 except:
1002 self.error(obj, value)
1004 self.error(obj, value)
1003
1005
1004
1006
1005 class ObjectName(TraitType):
1007 class ObjectName(TraitType):
1006 """A string holding a valid object name in this version of Python.
1008 """A string holding a valid object name in this version of Python.
1007
1009
1008 This does not check that the name exists in any scope."""
1010 This does not check that the name exists in any scope."""
1009 info_text = "a valid object identifier in Python"
1011 info_text = "a valid object identifier in Python"
1010
1012
1011 if py3compat.PY3:
1013 if py3compat.PY3:
1012 # Python 3:
1014 # Python 3:
1013 coerce_str = staticmethod(lambda _,s: s)
1015 coerce_str = staticmethod(lambda _,s: s)
1014
1016
1015 else:
1017 else:
1016 # Python 2:
1018 # Python 2:
1017 def coerce_str(self, obj, value):
1019 def coerce_str(self, obj, value):
1018 "In Python 2, coerce ascii-only unicode to str"
1020 "In Python 2, coerce ascii-only unicode to str"
1019 if isinstance(value, unicode):
1021 if isinstance(value, unicode):
1020 try:
1022 try:
1021 return str(value)
1023 return str(value)
1022 except UnicodeEncodeError:
1024 except UnicodeEncodeError:
1023 self.error(obj, value)
1025 self.error(obj, value)
1024 return value
1026 return value
1025
1027
1026 def validate(self, obj, value):
1028 def validate(self, obj, value):
1027 value = self.coerce_str(obj, value)
1029 value = self.coerce_str(obj, value)
1028
1030
1029 if isinstance(value, str) and py3compat.isidentifier(value):
1031 if isinstance(value, str) and py3compat.isidentifier(value):
1030 return value
1032 return value
1031 self.error(obj, value)
1033 self.error(obj, value)
1032
1034
1033 class DottedObjectName(ObjectName):
1035 class DottedObjectName(ObjectName):
1034 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1036 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1035 def validate(self, obj, value):
1037 def validate(self, obj, value):
1036 value = self.coerce_str(obj, value)
1038 value = self.coerce_str(obj, value)
1037
1039
1038 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1040 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1039 return value
1041 return value
1040 self.error(obj, value)
1042 self.error(obj, value)
1041
1043
1042
1044
1043 class Bool(TraitType):
1045 class Bool(TraitType):
1044 """A boolean (True, False) trait."""
1046 """A boolean (True, False) trait."""
1045
1047
1046 default_value = False
1048 default_value = False
1047 info_text = 'a boolean'
1049 info_text = 'a boolean'
1048
1050
1049 def validate(self, obj, value):
1051 def validate(self, obj, value):
1050 if isinstance(value, bool):
1052 if isinstance(value, bool):
1051 return value
1053 return value
1052 self.error(obj, value)
1054 self.error(obj, value)
1053
1055
1054
1056
1055 class CBool(Bool):
1057 class CBool(Bool):
1056 """A casting version of the boolean trait."""
1058 """A casting version of the boolean trait."""
1057
1059
1058 def validate(self, obj, value):
1060 def validate(self, obj, value):
1059 try:
1061 try:
1060 return bool(value)
1062 return bool(value)
1061 except:
1063 except:
1062 self.error(obj, value)
1064 self.error(obj, value)
1063
1065
1064
1066
1065 class Enum(TraitType):
1067 class Enum(TraitType):
1066 """An enum that whose value must be in a given sequence."""
1068 """An enum that whose value must be in a given sequence."""
1067
1069
1068 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1070 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1069 self.values = values
1071 self.values = values
1070 self._allow_none = allow_none
1072 self._allow_none = allow_none
1071 super(Enum, self).__init__(default_value, **metadata)
1073 super(Enum, self).__init__(default_value, **metadata)
1072
1074
1073 def validate(self, obj, value):
1075 def validate(self, obj, value):
1074 if value is None:
1076 if value is None:
1075 if self._allow_none:
1077 if self._allow_none:
1076 return value
1078 return value
1077
1079
1078 if value in self.values:
1080 if value in self.values:
1079 return value
1081 return value
1080 self.error(obj, value)
1082 self.error(obj, value)
1081
1083
1082 def info(self):
1084 def info(self):
1083 """ Returns a description of the trait."""
1085 """ Returns a description of the trait."""
1084 result = 'any of ' + repr(self.values)
1086 result = 'any of ' + repr(self.values)
1085 if self._allow_none:
1087 if self._allow_none:
1086 return result + ' or None'
1088 return result + ' or None'
1087 return result
1089 return result
1088
1090
1089 class CaselessStrEnum(Enum):
1091 class CaselessStrEnum(Enum):
1090 """An enum of strings that are caseless in validate."""
1092 """An enum of strings that are caseless in validate."""
1091
1093
1092 def validate(self, obj, value):
1094 def validate(self, obj, value):
1093 if value is None:
1095 if value is None:
1094 if self._allow_none:
1096 if self._allow_none:
1095 return value
1097 return value
1096
1098
1097 if not isinstance(value, basestring):
1099 if not isinstance(value, basestring):
1098 self.error(obj, value)
1100 self.error(obj, value)
1099
1101
1100 for v in self.values:
1102 for v in self.values:
1101 if v.lower() == value.lower():
1103 if v.lower() == value.lower():
1102 return v
1104 return v
1103 self.error(obj, value)
1105 self.error(obj, value)
1104
1106
1105 class Container(Instance):
1107 class Container(Instance):
1106 """An instance of a container (list, set, etc.)
1108 """An instance of a container (list, set, etc.)
1107
1109
1108 To be subclassed by overriding klass.
1110 To be subclassed by overriding klass.
1109 """
1111 """
1110 klass = None
1112 klass = None
1111 _valid_defaults = SequenceTypes
1113 _valid_defaults = SequenceTypes
1112 _trait = None
1114 _trait = None
1113
1115
1114 def __init__(self, trait=None, default_value=None, allow_none=True,
1116 def __init__(self, trait=None, default_value=None, allow_none=True,
1115 **metadata):
1117 **metadata):
1116 """Create a container trait type from a list, set, or tuple.
1118 """Create a container trait type from a list, set, or tuple.
1117
1119
1118 The default value is created by doing ``List(default_value)``,
1120 The default value is created by doing ``List(default_value)``,
1119 which creates a copy of the ``default_value``.
1121 which creates a copy of the ``default_value``.
1120
1122
1121 ``trait`` can be specified, which restricts the type of elements
1123 ``trait`` can be specified, which restricts the type of elements
1122 in the container to that TraitType.
1124 in the container to that TraitType.
1123
1125
1124 If only one arg is given and it is not a Trait, it is taken as
1126 If only one arg is given and it is not a Trait, it is taken as
1125 ``default_value``:
1127 ``default_value``:
1126
1128
1127 ``c = List([1,2,3])``
1129 ``c = List([1,2,3])``
1128
1130
1129 Parameters
1131 Parameters
1130 ----------
1132 ----------
1131
1133
1132 trait : TraitType [ optional ]
1134 trait : TraitType [ optional ]
1133 the type for restricting the contents of the Container. If unspecified,
1135 the type for restricting the contents of the Container. If unspecified,
1134 types are not checked.
1136 types are not checked.
1135
1137
1136 default_value : SequenceType [ optional ]
1138 default_value : SequenceType [ optional ]
1137 The default value for the Trait. Must be list/tuple/set, and
1139 The default value for the Trait. Must be list/tuple/set, and
1138 will be cast to the container type.
1140 will be cast to the container type.
1139
1141
1140 allow_none : Bool [ default True ]
1142 allow_none : Bool [ default True ]
1141 Whether to allow the value to be None
1143 Whether to allow the value to be None
1142
1144
1143 **metadata : any
1145 **metadata : any
1144 further keys for extensions to the Trait (e.g. config)
1146 further keys for extensions to the Trait (e.g. config)
1145
1147
1146 """
1148 """
1147 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1149 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1148
1150
1149 # allow List([values]):
1151 # allow List([values]):
1150 if default_value is None and not istrait(trait):
1152 if default_value is None and not istrait(trait):
1151 default_value = trait
1153 default_value = trait
1152 trait = None
1154 trait = None
1153
1155
1154 if default_value is None:
1156 if default_value is None:
1155 args = ()
1157 args = ()
1156 elif isinstance(default_value, self._valid_defaults):
1158 elif isinstance(default_value, self._valid_defaults):
1157 args = (default_value,)
1159 args = (default_value,)
1158 else:
1160 else:
1159 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1161 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1160
1162
1161 if istrait(trait):
1163 if istrait(trait):
1162 self._trait = trait()
1164 self._trait = trait()
1163 self._trait.name = 'element'
1165 self._trait.name = 'element'
1164 elif trait is not None:
1166 elif trait is not None:
1165 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1167 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1166
1168
1167 super(Container,self).__init__(klass=self.klass, args=args,
1169 super(Container,self).__init__(klass=self.klass, args=args,
1168 allow_none=allow_none, **metadata)
1170 allow_none=allow_none, **metadata)
1169
1171
1170 def element_error(self, obj, element, validator):
1172 def element_error(self, obj, element, validator):
1171 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1173 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1172 % (self.name, class_of(obj), validator.info(), repr_type(element))
1174 % (self.name, class_of(obj), validator.info(), repr_type(element))
1173 raise TraitError(e)
1175 raise TraitError(e)
1174
1176
1175 def validate(self, obj, value):
1177 def validate(self, obj, value):
1176 value = super(Container, self).validate(obj, value)
1178 value = super(Container, self).validate(obj, value)
1177 if value is None:
1179 if value is None:
1178 return value
1180 return value
1179
1181
1180 value = self.validate_elements(obj, value)
1182 value = self.validate_elements(obj, value)
1181
1183
1182 return value
1184 return value
1183
1185
1184 def validate_elements(self, obj, value):
1186 def validate_elements(self, obj, value):
1185 validated = []
1187 validated = []
1186 if self._trait is None or isinstance(self._trait, Any):
1188 if self._trait is None or isinstance(self._trait, Any):
1187 return value
1189 return value
1188 for v in value:
1190 for v in value:
1189 try:
1191 try:
1190 v = self._trait.validate(obj, v)
1192 v = self._trait.validate(obj, v)
1191 except TraitError:
1193 except TraitError:
1192 self.element_error(obj, v, self._trait)
1194 self.element_error(obj, v, self._trait)
1193 else:
1195 else:
1194 validated.append(v)
1196 validated.append(v)
1195 return self.klass(validated)
1197 return self.klass(validated)
1196
1198
1197
1199
1198 class List(Container):
1200 class List(Container):
1199 """An instance of a Python list."""
1201 """An instance of a Python list."""
1200 klass = list
1202 klass = list
1201
1203
1202 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1204 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1203 allow_none=True, **metadata):
1205 allow_none=True, **metadata):
1204 """Create a List trait type from a list, set, or tuple.
1206 """Create a List trait type from a list, set, or tuple.
1205
1207
1206 The default value is created by doing ``List(default_value)``,
1208 The default value is created by doing ``List(default_value)``,
1207 which creates a copy of the ``default_value``.
1209 which creates a copy of the ``default_value``.
1208
1210
1209 ``trait`` can be specified, which restricts the type of elements
1211 ``trait`` can be specified, which restricts the type of elements
1210 in the container to that TraitType.
1212 in the container to that TraitType.
1211
1213
1212 If only one arg is given and it is not a Trait, it is taken as
1214 If only one arg is given and it is not a Trait, it is taken as
1213 ``default_value``:
1215 ``default_value``:
1214
1216
1215 ``c = List([1,2,3])``
1217 ``c = List([1,2,3])``
1216
1218
1217 Parameters
1219 Parameters
1218 ----------
1220 ----------
1219
1221
1220 trait : TraitType [ optional ]
1222 trait : TraitType [ optional ]
1221 the type for restricting the contents of the Container. If unspecified,
1223 the type for restricting the contents of the Container. If unspecified,
1222 types are not checked.
1224 types are not checked.
1223
1225
1224 default_value : SequenceType [ optional ]
1226 default_value : SequenceType [ optional ]
1225 The default value for the Trait. Must be list/tuple/set, and
1227 The default value for the Trait. Must be list/tuple/set, and
1226 will be cast to the container type.
1228 will be cast to the container type.
1227
1229
1228 minlen : Int [ default 0 ]
1230 minlen : Int [ default 0 ]
1229 The minimum length of the input list
1231 The minimum length of the input list
1230
1232
1231 maxlen : Int [ default sys.maxint ]
1233 maxlen : Int [ default sys.maxint ]
1232 The maximum length of the input list
1234 The maximum length of the input list
1233
1235
1234 allow_none : Bool [ default True ]
1236 allow_none : Bool [ default True ]
1235 Whether to allow the value to be None
1237 Whether to allow the value to be None
1236
1238
1237 **metadata : any
1239 **metadata : any
1238 further keys for extensions to the Trait (e.g. config)
1240 further keys for extensions to the Trait (e.g. config)
1239
1241
1240 """
1242 """
1241 self._minlen = minlen
1243 self._minlen = minlen
1242 self._maxlen = maxlen
1244 self._maxlen = maxlen
1243 super(List, self).__init__(trait=trait, default_value=default_value,
1245 super(List, self).__init__(trait=trait, default_value=default_value,
1244 allow_none=allow_none, **metadata)
1246 allow_none=allow_none, **metadata)
1245
1247
1246 def length_error(self, obj, value):
1248 def length_error(self, obj, value):
1247 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1249 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1248 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1250 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1249 raise TraitError(e)
1251 raise TraitError(e)
1250
1252
1251 def validate_elements(self, obj, value):
1253 def validate_elements(self, obj, value):
1252 length = len(value)
1254 length = len(value)
1253 if length < self._minlen or length > self._maxlen:
1255 if length < self._minlen or length > self._maxlen:
1254 self.length_error(obj, value)
1256 self.length_error(obj, value)
1255
1257
1256 return super(List, self).validate_elements(obj, value)
1258 return super(List, self).validate_elements(obj, value)
1257
1259
1258
1260
1259 class Set(Container):
1261 class Set(Container):
1260 """An instance of a Python set."""
1262 """An instance of a Python set."""
1261 klass = set
1263 klass = set
1262
1264
1263 class Tuple(Container):
1265 class Tuple(Container):
1264 """An instance of a Python tuple."""
1266 """An instance of a Python tuple."""
1265 klass = tuple
1267 klass = tuple
1266
1268
1267 def __init__(self, *traits, **metadata):
1269 def __init__(self, *traits, **metadata):
1268 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1270 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1269
1271
1270 Create a tuple from a list, set, or tuple.
1272 Create a tuple from a list, set, or tuple.
1271
1273
1272 Create a fixed-type tuple with Traits:
1274 Create a fixed-type tuple with Traits:
1273
1275
1274 ``t = Tuple(Int, Str, CStr)``
1276 ``t = Tuple(Int, Str, CStr)``
1275
1277
1276 would be length 3, with Int,Str,CStr for each element.
1278 would be length 3, with Int,Str,CStr for each element.
1277
1279
1278 If only one arg is given and it is not a Trait, it is taken as
1280 If only one arg is given and it is not a Trait, it is taken as
1279 default_value:
1281 default_value:
1280
1282
1281 ``t = Tuple((1,2,3))``
1283 ``t = Tuple((1,2,3))``
1282
1284
1283 Otherwise, ``default_value`` *must* be specified by keyword.
1285 Otherwise, ``default_value`` *must* be specified by keyword.
1284
1286
1285 Parameters
1287 Parameters
1286 ----------
1288 ----------
1287
1289
1288 *traits : TraitTypes [ optional ]
1290 *traits : TraitTypes [ optional ]
1289 the tsype for restricting the contents of the Tuple. If unspecified,
1291 the tsype for restricting the contents of the Tuple. If unspecified,
1290 types are not checked. If specified, then each positional argument
1292 types are not checked. If specified, then each positional argument
1291 corresponds to an element of the tuple. Tuples defined with traits
1293 corresponds to an element of the tuple. Tuples defined with traits
1292 are of fixed length.
1294 are of fixed length.
1293
1295
1294 default_value : SequenceType [ optional ]
1296 default_value : SequenceType [ optional ]
1295 The default value for the Tuple. Must be list/tuple/set, and
1297 The default value for the Tuple. Must be list/tuple/set, and
1296 will be cast to a tuple. If `traits` are specified, the
1298 will be cast to a tuple. If `traits` are specified, the
1297 `default_value` must conform to the shape and type they specify.
1299 `default_value` must conform to the shape and type they specify.
1298
1300
1299 allow_none : Bool [ default True ]
1301 allow_none : Bool [ default True ]
1300 Whether to allow the value to be None
1302 Whether to allow the value to be None
1301
1303
1302 **metadata : any
1304 **metadata : any
1303 further keys for extensions to the Trait (e.g. config)
1305 further keys for extensions to the Trait (e.g. config)
1304
1306
1305 """
1307 """
1306 default_value = metadata.pop('default_value', None)
1308 default_value = metadata.pop('default_value', None)
1307 allow_none = metadata.pop('allow_none', True)
1309 allow_none = metadata.pop('allow_none', True)
1308
1310
1309 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1311 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1310
1312
1311 # allow Tuple((values,)):
1313 # allow Tuple((values,)):
1312 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1314 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1313 default_value = traits[0]
1315 default_value = traits[0]
1314 traits = ()
1316 traits = ()
1315
1317
1316 if default_value is None:
1318 if default_value is None:
1317 args = ()
1319 args = ()
1318 elif isinstance(default_value, self._valid_defaults):
1320 elif isinstance(default_value, self._valid_defaults):
1319 args = (default_value,)
1321 args = (default_value,)
1320 else:
1322 else:
1321 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1323 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1322
1324
1323 self._traits = []
1325 self._traits = []
1324 for trait in traits:
1326 for trait in traits:
1325 t = trait()
1327 t = trait()
1326 t.name = 'element'
1328 t.name = 'element'
1327 self._traits.append(t)
1329 self._traits.append(t)
1328
1330
1329 if self._traits and default_value is None:
1331 if self._traits and default_value is None:
1330 # don't allow default to be an empty container if length is specified
1332 # don't allow default to be an empty container if length is specified
1331 args = None
1333 args = None
1332 super(Container,self).__init__(klass=self.klass, args=args,
1334 super(Container,self).__init__(klass=self.klass, args=args,
1333 allow_none=allow_none, **metadata)
1335 allow_none=allow_none, **metadata)
1334
1336
1335 def validate_elements(self, obj, value):
1337 def validate_elements(self, obj, value):
1336 if not self._traits:
1338 if not self._traits:
1337 # nothing to validate
1339 # nothing to validate
1338 return value
1340 return value
1339 if len(value) != len(self._traits):
1341 if len(value) != len(self._traits):
1340 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1342 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1341 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1343 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1342 raise TraitError(e)
1344 raise TraitError(e)
1343
1345
1344 validated = []
1346 validated = []
1345 for t,v in zip(self._traits, value):
1347 for t,v in zip(self._traits, value):
1346 try:
1348 try:
1347 v = t.validate(obj, v)
1349 v = t.validate(obj, v)
1348 except TraitError:
1350 except TraitError:
1349 self.element_error(obj, v, t)
1351 self.element_error(obj, v, t)
1350 else:
1352 else:
1351 validated.append(v)
1353 validated.append(v)
1352 return tuple(validated)
1354 return tuple(validated)
1353
1355
1354
1356
1355 class Dict(Instance):
1357 class Dict(Instance):
1356 """An instance of a Python dict."""
1358 """An instance of a Python dict."""
1357
1359
1358 def __init__(self, default_value=None, allow_none=True, **metadata):
1360 def __init__(self, default_value=None, allow_none=True, **metadata):
1359 """Create a dict trait type from a dict.
1361 """Create a dict trait type from a dict.
1360
1362
1361 The default value is created by doing ``dict(default_value)``,
1363 The default value is created by doing ``dict(default_value)``,
1362 which creates a copy of the ``default_value``.
1364 which creates a copy of the ``default_value``.
1363 """
1365 """
1364 if default_value is None:
1366 if default_value is None:
1365 args = ((),)
1367 args = ((),)
1366 elif isinstance(default_value, dict):
1368 elif isinstance(default_value, dict):
1367 args = (default_value,)
1369 args = (default_value,)
1368 elif isinstance(default_value, SequenceTypes):
1370 elif isinstance(default_value, SequenceTypes):
1369 args = (default_value,)
1371 args = (default_value,)
1370 else:
1372 else:
1371 raise TypeError('default value of Dict was %s' % default_value)
1373 raise TypeError('default value of Dict was %s' % default_value)
1372
1374
1373 super(Dict,self).__init__(klass=dict, args=args,
1375 super(Dict,self).__init__(klass=dict, args=args,
1374 allow_none=allow_none, **metadata)
1376 allow_none=allow_none, **metadata)
1375
1377
1376 class TCPAddress(TraitType):
1378 class TCPAddress(TraitType):
1377 """A trait for an (ip, port) tuple.
1379 """A trait for an (ip, port) tuple.
1378
1380
1379 This allows for both IPv4 IP addresses as well as hostnames.
1381 This allows for both IPv4 IP addresses as well as hostnames.
1380 """
1382 """
1381
1383
1382 default_value = ('127.0.0.1', 0)
1384 default_value = ('127.0.0.1', 0)
1383 info_text = 'an (ip, port) tuple'
1385 info_text = 'an (ip, port) tuple'
1384
1386
1385 def validate(self, obj, value):
1387 def validate(self, obj, value):
1386 if isinstance(value, tuple):
1388 if isinstance(value, tuple):
1387 if len(value) == 2:
1389 if len(value) == 2:
1388 if isinstance(value[0], basestring) and isinstance(value[1], int):
1390 if isinstance(value[0], basestring) and isinstance(value[1], int):
1389 port = value[1]
1391 port = value[1]
1390 if port >= 0 and port <= 65535:
1392 if port >= 0 and port <= 65535:
1391 return value
1393 return value
1392 self.error(obj, value)
1394 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now