##// END OF EJS Templates
Fix tests in utils
Thomas Kluyver -
Show More
@@ -1,764 +1,764 b''
1 """Nose Plugin that supports IPython doctests.
1 """Nose Plugin that supports IPython doctests.
2
2
3 Limitations:
3 Limitations:
4
4
5 - When generating examples for use as doctests, make sure that you have
5 - When generating examples for use as doctests, make sure that you have
6 pretty-printing OFF. This can be done either by setting the
6 pretty-printing OFF. This can be done either by setting the
7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
8 by interactively disabling it with %Pprint. This is required so that IPython
8 by interactively disabling it with %Pprint. This is required so that IPython
9 output matches that of normal Python, which is used by doctest for internal
9 output matches that of normal Python, which is used by doctest for internal
10 execution.
10 execution.
11
11
12 - Do not rely on specific prompt numbers for results (such as using
12 - Do not rely on specific prompt numbers for results (such as using
13 '_34==True', for example). For IPython tests run via an external process the
13 '_34==True', for example). For IPython tests run via an external process the
14 prompt numbers may be different, and IPython tests run as normal python code
14 prompt numbers may be different, and IPython tests run as normal python code
15 won't even have these special _NN variables set at all.
15 won't even have these special _NN variables set at all.
16 """
16 """
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Module imports
19 # Module imports
20
20
21 # From the standard library
21 # From the standard library
22 import doctest
22 import doctest
23 import inspect
23 import inspect
24 import logging
24 import logging
25 import os
25 import os
26 import re
26 import re
27 import sys
27 import sys
28 import traceback
28 import traceback
29 import unittest
29 import unittest
30
30
31 from inspect import getmodule
31 from inspect import getmodule
32
32
33 # We are overriding the default doctest runner, so we need to import a few
33 # We are overriding the default doctest runner, so we need to import a few
34 # things from doctest directly
34 # things from doctest directly
35 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
35 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
36 _unittest_reportflags, DocTestRunner,
36 _unittest_reportflags, DocTestRunner,
37 _extract_future_flags, pdb, _OutputRedirectingPdb,
37 _extract_future_flags, pdb, _OutputRedirectingPdb,
38 _exception_traceback,
38 _exception_traceback,
39 linecache)
39 linecache)
40
40
41 # Third-party modules
41 # Third-party modules
42 import nose.core
42 import nose.core
43
43
44 from nose.plugins import doctests, Plugin
44 from nose.plugins import doctests, Plugin
45 from nose.util import anyp, getpackage, test_address, resolve_name, tolist
45 from nose.util import anyp, getpackage, test_address, resolve_name, tolist
46
46
47 # Our own imports
47 # Our own imports
48 from IPython.utils.py3compat import builtin_mod, PY3
48 from IPython.utils.py3compat import builtin_mod, PY3
49
49
50 if PY3:
50 if PY3:
51 from io import StringIO
51 from io import StringIO
52 else:
52 else:
53 from StringIO import StringIO
53 from StringIO import StringIO
54
54
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56 # Module globals and other constants
56 # Module globals and other constants
57 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
58
58
59 log = logging.getLogger(__name__)
59 log = logging.getLogger(__name__)
60
60
61
61
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63 # Classes and functions
63 # Classes and functions
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65
65
66 def is_extension_module(filename):
66 def is_extension_module(filename):
67 """Return whether the given filename is an extension module.
67 """Return whether the given filename is an extension module.
68
68
69 This simply checks that the extension is either .so or .pyd.
69 This simply checks that the extension is either .so or .pyd.
70 """
70 """
71 return os.path.splitext(filename)[1].lower() in ('.so','.pyd')
71 return os.path.splitext(filename)[1].lower() in ('.so','.pyd')
72
72
73
73
74 class DocTestSkip(object):
74 class DocTestSkip(object):
75 """Object wrapper for doctests to be skipped."""
75 """Object wrapper for doctests to be skipped."""
76
76
77 ds_skip = """Doctest to skip.
77 ds_skip = """Doctest to skip.
78 >>> 1 #doctest: +SKIP
78 >>> 1 #doctest: +SKIP
79 """
79 """
80
80
81 def __init__(self,obj):
81 def __init__(self,obj):
82 self.obj = obj
82 self.obj = obj
83
83
84 def __getattribute__(self,key):
84 def __getattribute__(self,key):
85 if key == '__doc__':
85 if key == '__doc__':
86 return DocTestSkip.ds_skip
86 return DocTestSkip.ds_skip
87 else:
87 else:
88 return getattr(object.__getattribute__(self,'obj'),key)
88 return getattr(object.__getattribute__(self,'obj'),key)
89
89
90 # Modified version of the one in the stdlib, that fixes a python bug (doctests
90 # Modified version of the one in the stdlib, that fixes a python bug (doctests
91 # not found in extension modules, http://bugs.python.org/issue3158)
91 # not found in extension modules, http://bugs.python.org/issue3158)
92 class DocTestFinder(doctest.DocTestFinder):
92 class DocTestFinder(doctest.DocTestFinder):
93
93
94 def _from_module(self, module, object):
94 def _from_module(self, module, object):
95 """
95 """
96 Return true if the given object is defined in the given
96 Return true if the given object is defined in the given
97 module.
97 module.
98 """
98 """
99 if module is None:
99 if module is None:
100 return True
100 return True
101 elif inspect.isfunction(object):
101 elif inspect.isfunction(object):
102 return module.__dict__ is object.__globals__
102 return module.__dict__ is object.__globals__
103 elif inspect.isbuiltin(object):
103 elif inspect.isbuiltin(object):
104 return module.__name__ == object.__module__
104 return module.__name__ == object.__module__
105 elif inspect.isclass(object):
105 elif inspect.isclass(object):
106 return module.__name__ == object.__module__
106 return module.__name__ == object.__module__
107 elif inspect.ismethod(object):
107 elif inspect.ismethod(object):
108 # This one may be a bug in cython that fails to correctly set the
108 # This one may be a bug in cython that fails to correctly set the
109 # __module__ attribute of methods, but since the same error is easy
109 # __module__ attribute of methods, but since the same error is easy
110 # to make by extension code writers, having this safety in place
110 # to make by extension code writers, having this safety in place
111 # isn't such a bad idea
111 # isn't such a bad idea
112 return module.__name__ == object.__self__.__class__.__module__
112 return module.__name__ == object.__self__.__class__.__module__
113 elif inspect.getmodule(object) is not None:
113 elif inspect.getmodule(object) is not None:
114 return module is inspect.getmodule(object)
114 return module is inspect.getmodule(object)
115 elif hasattr(object, '__module__'):
115 elif hasattr(object, '__module__'):
116 return module.__name__ == object.__module__
116 return module.__name__ == object.__module__
117 elif isinstance(object, property):
117 elif isinstance(object, property):
118 return True # [XX] no way not be sure.
118 return True # [XX] no way not be sure.
119 else:
119 else:
120 raise ValueError("object must be a class or function")
120 raise ValueError("object must be a class or function, got %r" % object)
121
121
122 def _find(self, tests, obj, name, module, source_lines, globs, seen):
122 def _find(self, tests, obj, name, module, source_lines, globs, seen):
123 """
123 """
124 Find tests for the given object and any contained objects, and
124 Find tests for the given object and any contained objects, and
125 add them to `tests`.
125 add them to `tests`.
126 """
126 """
127 #print '_find for:', obj, name, module # dbg
127 #print '_find for:', obj, name, module # dbg
128 if hasattr(obj,"skip_doctest"):
128 if hasattr(obj,"skip_doctest"):
129 #print 'SKIPPING DOCTEST FOR:',obj # dbg
129 #print 'SKIPPING DOCTEST FOR:',obj # dbg
130 obj = DocTestSkip(obj)
130 obj = DocTestSkip(obj)
131
131
132 doctest.DocTestFinder._find(self,tests, obj, name, module,
132 doctest.DocTestFinder._find(self,tests, obj, name, module,
133 source_lines, globs, seen)
133 source_lines, globs, seen)
134
134
135 # Below we re-run pieces of the above method with manual modifications,
135 # Below we re-run pieces of the above method with manual modifications,
136 # because the original code is buggy and fails to correctly identify
136 # because the original code is buggy and fails to correctly identify
137 # doctests in extension modules.
137 # doctests in extension modules.
138
138
139 # Local shorthands
139 # Local shorthands
140 from inspect import isroutine, isclass, ismodule
140 from inspect import isroutine, isclass, ismodule
141
141
142 # Look for tests in a module's contained objects.
142 # Look for tests in a module's contained objects.
143 if inspect.ismodule(obj) and self._recurse:
143 if inspect.ismodule(obj) and self._recurse:
144 for valname, val in obj.__dict__.items():
144 for valname, val in obj.__dict__.items():
145 valname1 = '%s.%s' % (name, valname)
145 valname1 = '%s.%s' % (name, valname)
146 if ( (isroutine(val) or isclass(val))
146 if ( (isroutine(val) or isclass(val))
147 and self._from_module(module, val) ):
147 and self._from_module(module, val) ):
148
148
149 self._find(tests, val, valname1, module, source_lines,
149 self._find(tests, val, valname1, module, source_lines,
150 globs, seen)
150 globs, seen)
151
151
152 # Look for tests in a class's contained objects.
152 # Look for tests in a class's contained objects.
153 if inspect.isclass(obj) and self._recurse:
153 if inspect.isclass(obj) and self._recurse:
154 #print 'RECURSE into class:',obj # dbg
154 #print 'RECURSE into class:',obj # dbg
155 for valname, val in obj.__dict__.items():
155 for valname, val in obj.__dict__.items():
156 # Special handling for staticmethod/classmethod.
156 # Special handling for staticmethod/classmethod.
157 if isinstance(val, staticmethod):
157 if isinstance(val, staticmethod):
158 val = getattr(obj, valname)
158 val = getattr(obj, valname)
159 if isinstance(val, classmethod):
159 if isinstance(val, classmethod):
160 val = getattr(obj, valname).__func__
160 val = getattr(obj, valname).__func__
161
161
162 # Recurse to methods, properties, and nested classes.
162 # Recurse to methods, properties, and nested classes.
163 if ((inspect.isfunction(val) or inspect.isclass(val) or
163 if ((inspect.isfunction(val) or inspect.isclass(val) or
164 inspect.ismethod(val) or
164 inspect.ismethod(val) or
165 isinstance(val, property)) and
165 isinstance(val, property)) and
166 self._from_module(module, val)):
166 self._from_module(module, val)):
167 valname = '%s.%s' % (name, valname)
167 valname = '%s.%s' % (name, valname)
168 self._find(tests, val, valname, module, source_lines,
168 self._find(tests, val, valname, module, source_lines,
169 globs, seen)
169 globs, seen)
170
170
171
171
172 class IPDoctestOutputChecker(doctest.OutputChecker):
172 class IPDoctestOutputChecker(doctest.OutputChecker):
173 """Second-chance checker with support for random tests.
173 """Second-chance checker with support for random tests.
174
174
175 If the default comparison doesn't pass, this checker looks in the expected
175 If the default comparison doesn't pass, this checker looks in the expected
176 output string for flags that tell us to ignore the output.
176 output string for flags that tell us to ignore the output.
177 """
177 """
178
178
179 random_re = re.compile(r'#\s*random\s+')
179 random_re = re.compile(r'#\s*random\s+')
180
180
181 def check_output(self, want, got, optionflags):
181 def check_output(self, want, got, optionflags):
182 """Check output, accepting special markers embedded in the output.
182 """Check output, accepting special markers embedded in the output.
183
183
184 If the output didn't pass the default validation but the special string
184 If the output didn't pass the default validation but the special string
185 '#random' is included, we accept it."""
185 '#random' is included, we accept it."""
186
186
187 # Let the original tester verify first, in case people have valid tests
187 # Let the original tester verify first, in case people have valid tests
188 # that happen to have a comment saying '#random' embedded in.
188 # that happen to have a comment saying '#random' embedded in.
189 ret = doctest.OutputChecker.check_output(self, want, got,
189 ret = doctest.OutputChecker.check_output(self, want, got,
190 optionflags)
190 optionflags)
191 if not ret and self.random_re.search(want):
191 if not ret and self.random_re.search(want):
192 #print >> sys.stderr, 'RANDOM OK:',want # dbg
192 #print >> sys.stderr, 'RANDOM OK:',want # dbg
193 return True
193 return True
194
194
195 return ret
195 return ret
196
196
197
197
198 class DocTestCase(doctests.DocTestCase):
198 class DocTestCase(doctests.DocTestCase):
199 """Proxy for DocTestCase: provides an address() method that
199 """Proxy for DocTestCase: provides an address() method that
200 returns the correct address for the doctest case. Otherwise
200 returns the correct address for the doctest case. Otherwise
201 acts as a proxy to the test case. To provide hints for address(),
201 acts as a proxy to the test case. To provide hints for address(),
202 an obj may also be passed -- this will be used as the test object
202 an obj may also be passed -- this will be used as the test object
203 for purposes of determining the test address, if it is provided.
203 for purposes of determining the test address, if it is provided.
204 """
204 """
205
205
206 # Note: this method was taken from numpy's nosetester module.
206 # Note: this method was taken from numpy's nosetester module.
207
207
208 # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
208 # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
209 # its constructor that blocks non-default arguments from being passed
209 # its constructor that blocks non-default arguments from being passed
210 # down into doctest.DocTestCase
210 # down into doctest.DocTestCase
211
211
212 def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
212 def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
213 checker=None, obj=None, result_var='_'):
213 checker=None, obj=None, result_var='_'):
214 self._result_var = result_var
214 self._result_var = result_var
215 doctests.DocTestCase.__init__(self, test,
215 doctests.DocTestCase.__init__(self, test,
216 optionflags=optionflags,
216 optionflags=optionflags,
217 setUp=setUp, tearDown=tearDown,
217 setUp=setUp, tearDown=tearDown,
218 checker=checker)
218 checker=checker)
219 # Now we must actually copy the original constructor from the stdlib
219 # Now we must actually copy the original constructor from the stdlib
220 # doctest class, because we can't call it directly and a bug in nose
220 # doctest class, because we can't call it directly and a bug in nose
221 # means it never gets passed the right arguments.
221 # means it never gets passed the right arguments.
222
222
223 self._dt_optionflags = optionflags
223 self._dt_optionflags = optionflags
224 self._dt_checker = checker
224 self._dt_checker = checker
225 self._dt_test = test
225 self._dt_test = test
226 self._dt_test_globs_ori = test.globs
226 self._dt_test_globs_ori = test.globs
227 self._dt_setUp = setUp
227 self._dt_setUp = setUp
228 self._dt_tearDown = tearDown
228 self._dt_tearDown = tearDown
229
229
230 # XXX - store this runner once in the object!
230 # XXX - store this runner once in the object!
231 runner = IPDocTestRunner(optionflags=optionflags,
231 runner = IPDocTestRunner(optionflags=optionflags,
232 checker=checker, verbose=False)
232 checker=checker, verbose=False)
233 self._dt_runner = runner
233 self._dt_runner = runner
234
234
235
235
236 # Each doctest should remember the directory it was loaded from, so
236 # Each doctest should remember the directory it was loaded from, so
237 # things like %run work without too many contortions
237 # things like %run work without too many contortions
238 self._ori_dir = os.path.dirname(test.filename)
238 self._ori_dir = os.path.dirname(test.filename)
239
239
240 # Modified runTest from the default stdlib
240 # Modified runTest from the default stdlib
241 def runTest(self):
241 def runTest(self):
242 test = self._dt_test
242 test = self._dt_test
243 runner = self._dt_runner
243 runner = self._dt_runner
244
244
245 old = sys.stdout
245 old = sys.stdout
246 new = StringIO()
246 new = StringIO()
247 optionflags = self._dt_optionflags
247 optionflags = self._dt_optionflags
248
248
249 if not (optionflags & REPORTING_FLAGS):
249 if not (optionflags & REPORTING_FLAGS):
250 # The option flags don't include any reporting flags,
250 # The option flags don't include any reporting flags,
251 # so add the default reporting flags
251 # so add the default reporting flags
252 optionflags |= _unittest_reportflags
252 optionflags |= _unittest_reportflags
253
253
254 try:
254 try:
255 # Save our current directory and switch out to the one where the
255 # Save our current directory and switch out to the one where the
256 # test was originally created, in case another doctest did a
256 # test was originally created, in case another doctest did a
257 # directory change. We'll restore this in the finally clause.
257 # directory change. We'll restore this in the finally clause.
258 curdir = os.getcwdu()
258 curdir = os.getcwdu()
259 #print 'runTest in dir:', self._ori_dir # dbg
259 #print 'runTest in dir:', self._ori_dir # dbg
260 os.chdir(self._ori_dir)
260 os.chdir(self._ori_dir)
261
261
262 runner.DIVIDER = "-"*70
262 runner.DIVIDER = "-"*70
263 failures, tries = runner.run(test,out=new.write,
263 failures, tries = runner.run(test,out=new.write,
264 clear_globs=False)
264 clear_globs=False)
265 finally:
265 finally:
266 sys.stdout = old
266 sys.stdout = old
267 os.chdir(curdir)
267 os.chdir(curdir)
268
268
269 if failures:
269 if failures:
270 raise self.failureException(self.format_failure(new.getvalue()))
270 raise self.failureException(self.format_failure(new.getvalue()))
271
271
272 def setUp(self):
272 def setUp(self):
273 """Modified test setup that syncs with ipython namespace"""
273 """Modified test setup that syncs with ipython namespace"""
274 #print "setUp test", self._dt_test.examples # dbg
274 #print "setUp test", self._dt_test.examples # dbg
275 if isinstance(self._dt_test.examples[0], IPExample):
275 if isinstance(self._dt_test.examples[0], IPExample):
276 # for IPython examples *only*, we swap the globals with the ipython
276 # for IPython examples *only*, we swap the globals with the ipython
277 # namespace, after updating it with the globals (which doctest
277 # namespace, after updating it with the globals (which doctest
278 # fills with the necessary info from the module being tested).
278 # fills with the necessary info from the module being tested).
279 self.user_ns_orig = {}
279 self.user_ns_orig = {}
280 self.user_ns_orig.update(_ip.user_ns)
280 self.user_ns_orig.update(_ip.user_ns)
281 _ip.user_ns.update(self._dt_test.globs)
281 _ip.user_ns.update(self._dt_test.globs)
282 # We must remove the _ key in the namespace, so that Python's
282 # We must remove the _ key in the namespace, so that Python's
283 # doctest code sets it naturally
283 # doctest code sets it naturally
284 _ip.user_ns.pop('_', None)
284 _ip.user_ns.pop('_', None)
285 _ip.user_ns['__builtins__'] = builtin_mod
285 _ip.user_ns['__builtins__'] = builtin_mod
286 self._dt_test.globs = _ip.user_ns
286 self._dt_test.globs = _ip.user_ns
287
287
288 super(DocTestCase, self).setUp()
288 super(DocTestCase, self).setUp()
289
289
290 def tearDown(self):
290 def tearDown(self):
291
291
292 # Undo the test.globs reassignment we made, so that the parent class
292 # Undo the test.globs reassignment we made, so that the parent class
293 # teardown doesn't destroy the ipython namespace
293 # teardown doesn't destroy the ipython namespace
294 if isinstance(self._dt_test.examples[0], IPExample):
294 if isinstance(self._dt_test.examples[0], IPExample):
295 self._dt_test.globs = self._dt_test_globs_ori
295 self._dt_test.globs = self._dt_test_globs_ori
296 _ip.user_ns.clear()
296 _ip.user_ns.clear()
297 _ip.user_ns.update(self.user_ns_orig)
297 _ip.user_ns.update(self.user_ns_orig)
298
298
299 # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
299 # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
300 # it does look like one to me: its tearDown method tries to run
300 # it does look like one to me: its tearDown method tries to run
301 #
301 #
302 # delattr(builtin_mod, self._result_var)
302 # delattr(builtin_mod, self._result_var)
303 #
303 #
304 # without checking that the attribute really is there; it implicitly
304 # without checking that the attribute really is there; it implicitly
305 # assumes it should have been set via displayhook. But if the
305 # assumes it should have been set via displayhook. But if the
306 # displayhook was never called, this doesn't necessarily happen. I
306 # displayhook was never called, this doesn't necessarily happen. I
307 # haven't been able to find a little self-contained example outside of
307 # haven't been able to find a little self-contained example outside of
308 # ipython that would show the problem so I can report it to the nose
308 # ipython that would show the problem so I can report it to the nose
309 # team, but it does happen a lot in our code.
309 # team, but it does happen a lot in our code.
310 #
310 #
311 # So here, we just protect as narrowly as possible by trapping an
311 # So here, we just protect as narrowly as possible by trapping an
312 # attribute error whose message would be the name of self._result_var,
312 # attribute error whose message would be the name of self._result_var,
313 # and letting any other error propagate.
313 # and letting any other error propagate.
314 try:
314 try:
315 super(DocTestCase, self).tearDown()
315 super(DocTestCase, self).tearDown()
316 except AttributeError as exc:
316 except AttributeError as exc:
317 if exc.args[0] != self._result_var:
317 if exc.args[0] != self._result_var:
318 raise
318 raise
319
319
320
320
321 # A simple subclassing of the original with a different class name, so we can
321 # A simple subclassing of the original with a different class name, so we can
322 # distinguish and treat differently IPython examples from pure python ones.
322 # distinguish and treat differently IPython examples from pure python ones.
323 class IPExample(doctest.Example): pass
323 class IPExample(doctest.Example): pass
324
324
325
325
326 class IPExternalExample(doctest.Example):
326 class IPExternalExample(doctest.Example):
327 """Doctest examples to be run in an external process."""
327 """Doctest examples to be run in an external process."""
328
328
329 def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
329 def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
330 options=None):
330 options=None):
331 # Parent constructor
331 # Parent constructor
332 doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options)
332 doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options)
333
333
334 # An EXTRA newline is needed to prevent pexpect hangs
334 # An EXTRA newline is needed to prevent pexpect hangs
335 self.source += '\n'
335 self.source += '\n'
336
336
337
337
338 class IPDocTestParser(doctest.DocTestParser):
338 class IPDocTestParser(doctest.DocTestParser):
339 """
339 """
340 A class used to parse strings containing doctest examples.
340 A class used to parse strings containing doctest examples.
341
341
342 Note: This is a version modified to properly recognize IPython input and
342 Note: This is a version modified to properly recognize IPython input and
343 convert any IPython examples into valid Python ones.
343 convert any IPython examples into valid Python ones.
344 """
344 """
345 # This regular expression is used to find doctest examples in a
345 # This regular expression is used to find doctest examples in a
346 # string. It defines three groups: `source` is the source code
346 # string. It defines three groups: `source` is the source code
347 # (including leading indentation and prompts); `indent` is the
347 # (including leading indentation and prompts); `indent` is the
348 # indentation of the first (PS1) line of the source code; and
348 # indentation of the first (PS1) line of the source code; and
349 # `want` is the expected output (including leading indentation).
349 # `want` is the expected output (including leading indentation).
350
350
351 # Classic Python prompts or default IPython ones
351 # Classic Python prompts or default IPython ones
352 _PS1_PY = r'>>>'
352 _PS1_PY = r'>>>'
353 _PS2_PY = r'\.\.\.'
353 _PS2_PY = r'\.\.\.'
354
354
355 _PS1_IP = r'In\ \[\d+\]:'
355 _PS1_IP = r'In\ \[\d+\]:'
356 _PS2_IP = r'\ \ \ \.\.\.+:'
356 _PS2_IP = r'\ \ \ \.\.\.+:'
357
357
358 _RE_TPL = r'''
358 _RE_TPL = r'''
359 # Source consists of a PS1 line followed by zero or more PS2 lines.
359 # Source consists of a PS1 line followed by zero or more PS2 lines.
360 (?P<source>
360 (?P<source>
361 (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*) # PS1 line
361 (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*) # PS1 line
362 (?:\n [ ]* (?P<ps2> %s) .*)*) # PS2 lines
362 (?:\n [ ]* (?P<ps2> %s) .*)*) # PS2 lines
363 \n? # a newline
363 \n? # a newline
364 # Want consists of any non-blank lines that do not start with PS1.
364 # Want consists of any non-blank lines that do not start with PS1.
365 (?P<want> (?:(?![ ]*$) # Not a blank line
365 (?P<want> (?:(?![ ]*$) # Not a blank line
366 (?![ ]*%s) # Not a line starting with PS1
366 (?![ ]*%s) # Not a line starting with PS1
367 (?![ ]*%s) # Not a line starting with PS2
367 (?![ ]*%s) # Not a line starting with PS2
368 .*$\n? # But any other line
368 .*$\n? # But any other line
369 )*)
369 )*)
370 '''
370 '''
371
371
372 _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
372 _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
373 re.MULTILINE | re.VERBOSE)
373 re.MULTILINE | re.VERBOSE)
374
374
375 _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
375 _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
376 re.MULTILINE | re.VERBOSE)
376 re.MULTILINE | re.VERBOSE)
377
377
378 # Mark a test as being fully random. In this case, we simply append the
378 # Mark a test as being fully random. In this case, we simply append the
379 # random marker ('#random') to each individual example's output. This way
379 # random marker ('#random') to each individual example's output. This way
380 # we don't need to modify any other code.
380 # we don't need to modify any other code.
381 _RANDOM_TEST = re.compile(r'#\s*all-random\s+')
381 _RANDOM_TEST = re.compile(r'#\s*all-random\s+')
382
382
383 # Mark tests to be executed in an external process - currently unsupported.
383 # Mark tests to be executed in an external process - currently unsupported.
384 _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')
384 _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')
385
385
386 def ip2py(self,source):
386 def ip2py(self,source):
387 """Convert input IPython source into valid Python."""
387 """Convert input IPython source into valid Python."""
388 block = _ip.input_transformer_manager.transform_cell(source)
388 block = _ip.input_transformer_manager.transform_cell(source)
389 if len(block.splitlines()) == 1:
389 if len(block.splitlines()) == 1:
390 return _ip.prefilter(block)
390 return _ip.prefilter(block)
391 else:
391 else:
392 return block
392 return block
393
393
394 def parse(self, string, name='<string>'):
394 def parse(self, string, name='<string>'):
395 """
395 """
396 Divide the given string into examples and intervening text,
396 Divide the given string into examples and intervening text,
397 and return them as a list of alternating Examples and strings.
397 and return them as a list of alternating Examples and strings.
398 Line numbers for the Examples are 0-based. The optional
398 Line numbers for the Examples are 0-based. The optional
399 argument `name` is a name identifying this string, and is only
399 argument `name` is a name identifying this string, and is only
400 used for error messages.
400 used for error messages.
401 """
401 """
402
402
403 #print 'Parse string:\n',string # dbg
403 #print 'Parse string:\n',string # dbg
404
404
405 string = string.expandtabs()
405 string = string.expandtabs()
406 # If all lines begin with the same indentation, then strip it.
406 # If all lines begin with the same indentation, then strip it.
407 min_indent = self._min_indent(string)
407 min_indent = self._min_indent(string)
408 if min_indent > 0:
408 if min_indent > 0:
409 string = '\n'.join([l[min_indent:] for l in string.split('\n')])
409 string = '\n'.join([l[min_indent:] for l in string.split('\n')])
410
410
411 output = []
411 output = []
412 charno, lineno = 0, 0
412 charno, lineno = 0, 0
413
413
414 # We make 'all random' tests by adding the '# random' mark to every
414 # We make 'all random' tests by adding the '# random' mark to every
415 # block of output in the test.
415 # block of output in the test.
416 if self._RANDOM_TEST.search(string):
416 if self._RANDOM_TEST.search(string):
417 random_marker = '\n# random'
417 random_marker = '\n# random'
418 else:
418 else:
419 random_marker = ''
419 random_marker = ''
420
420
421 # Whether to convert the input from ipython to python syntax
421 # Whether to convert the input from ipython to python syntax
422 ip2py = False
422 ip2py = False
423 # Find all doctest examples in the string. First, try them as Python
423 # Find all doctest examples in the string. First, try them as Python
424 # examples, then as IPython ones
424 # examples, then as IPython ones
425 terms = list(self._EXAMPLE_RE_PY.finditer(string))
425 terms = list(self._EXAMPLE_RE_PY.finditer(string))
426 if terms:
426 if terms:
427 # Normal Python example
427 # Normal Python example
428 #print '-'*70 # dbg
428 #print '-'*70 # dbg
429 #print 'PyExample, Source:\n',string # dbg
429 #print 'PyExample, Source:\n',string # dbg
430 #print '-'*70 # dbg
430 #print '-'*70 # dbg
431 Example = doctest.Example
431 Example = doctest.Example
432 else:
432 else:
433 # It's an ipython example. Note that IPExamples are run
433 # It's an ipython example. Note that IPExamples are run
434 # in-process, so their syntax must be turned into valid python.
434 # in-process, so their syntax must be turned into valid python.
435 # IPExternalExamples are run out-of-process (via pexpect) so they
435 # IPExternalExamples are run out-of-process (via pexpect) so they
436 # don't need any filtering (a real ipython will be executing them).
436 # don't need any filtering (a real ipython will be executing them).
437 terms = list(self._EXAMPLE_RE_IP.finditer(string))
437 terms = list(self._EXAMPLE_RE_IP.finditer(string))
438 if self._EXTERNAL_IP.search(string):
438 if self._EXTERNAL_IP.search(string):
439 #print '-'*70 # dbg
439 #print '-'*70 # dbg
440 #print 'IPExternalExample, Source:\n',string # dbg
440 #print 'IPExternalExample, Source:\n',string # dbg
441 #print '-'*70 # dbg
441 #print '-'*70 # dbg
442 Example = IPExternalExample
442 Example = IPExternalExample
443 else:
443 else:
444 #print '-'*70 # dbg
444 #print '-'*70 # dbg
445 #print 'IPExample, Source:\n',string # dbg
445 #print 'IPExample, Source:\n',string # dbg
446 #print '-'*70 # dbg
446 #print '-'*70 # dbg
447 Example = IPExample
447 Example = IPExample
448 ip2py = True
448 ip2py = True
449
449
450 for m in terms:
450 for m in terms:
451 # Add the pre-example text to `output`.
451 # Add the pre-example text to `output`.
452 output.append(string[charno:m.start()])
452 output.append(string[charno:m.start()])
453 # Update lineno (lines before this example)
453 # Update lineno (lines before this example)
454 lineno += string.count('\n', charno, m.start())
454 lineno += string.count('\n', charno, m.start())
455 # Extract info from the regexp match.
455 # Extract info from the regexp match.
456 (source, options, want, exc_msg) = \
456 (source, options, want, exc_msg) = \
457 self._parse_example(m, name, lineno,ip2py)
457 self._parse_example(m, name, lineno,ip2py)
458
458
459 # Append the random-output marker (it defaults to empty in most
459 # Append the random-output marker (it defaults to empty in most
460 # cases, it's only non-empty for 'all-random' tests):
460 # cases, it's only non-empty for 'all-random' tests):
461 want += random_marker
461 want += random_marker
462
462
463 if Example is IPExternalExample:
463 if Example is IPExternalExample:
464 options[doctest.NORMALIZE_WHITESPACE] = True
464 options[doctest.NORMALIZE_WHITESPACE] = True
465 want += '\n'
465 want += '\n'
466
466
467 # Create an Example, and add it to the list.
467 # Create an Example, and add it to the list.
468 if not self._IS_BLANK_OR_COMMENT(source):
468 if not self._IS_BLANK_OR_COMMENT(source):
469 output.append(Example(source, want, exc_msg,
469 output.append(Example(source, want, exc_msg,
470 lineno=lineno,
470 lineno=lineno,
471 indent=min_indent+len(m.group('indent')),
471 indent=min_indent+len(m.group('indent')),
472 options=options))
472 options=options))
473 # Update lineno (lines inside this example)
473 # Update lineno (lines inside this example)
474 lineno += string.count('\n', m.start(), m.end())
474 lineno += string.count('\n', m.start(), m.end())
475 # Update charno.
475 # Update charno.
476 charno = m.end()
476 charno = m.end()
477 # Add any remaining post-example text to `output`.
477 # Add any remaining post-example text to `output`.
478 output.append(string[charno:])
478 output.append(string[charno:])
479 return output
479 return output
480
480
481 def _parse_example(self, m, name, lineno,ip2py=False):
481 def _parse_example(self, m, name, lineno,ip2py=False):
482 """
482 """
483 Given a regular expression match from `_EXAMPLE_RE` (`m`),
483 Given a regular expression match from `_EXAMPLE_RE` (`m`),
484 return a pair `(source, want)`, where `source` is the matched
484 return a pair `(source, want)`, where `source` is the matched
485 example's source code (with prompts and indentation stripped);
485 example's source code (with prompts and indentation stripped);
486 and `want` is the example's expected output (with indentation
486 and `want` is the example's expected output (with indentation
487 stripped).
487 stripped).
488
488
489 `name` is the string's name, and `lineno` is the line number
489 `name` is the string's name, and `lineno` is the line number
490 where the example starts; both are used for error messages.
490 where the example starts; both are used for error messages.
491
491
492 Optional:
492 Optional:
493 `ip2py`: if true, filter the input via IPython to convert the syntax
493 `ip2py`: if true, filter the input via IPython to convert the syntax
494 into valid python.
494 into valid python.
495 """
495 """
496
496
497 # Get the example's indentation level.
497 # Get the example's indentation level.
498 indent = len(m.group('indent'))
498 indent = len(m.group('indent'))
499
499
500 # Divide source into lines; check that they're properly
500 # Divide source into lines; check that they're properly
501 # indented; and then strip their indentation & prompts.
501 # indented; and then strip their indentation & prompts.
502 source_lines = m.group('source').split('\n')
502 source_lines = m.group('source').split('\n')
503
503
504 # We're using variable-length input prompts
504 # We're using variable-length input prompts
505 ps1 = m.group('ps1')
505 ps1 = m.group('ps1')
506 ps2 = m.group('ps2')
506 ps2 = m.group('ps2')
507 ps1_len = len(ps1)
507 ps1_len = len(ps1)
508
508
509 self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
509 self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
510 if ps2:
510 if ps2:
511 self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)
511 self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)
512
512
513 source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])
513 source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])
514
514
515 if ip2py:
515 if ip2py:
516 # Convert source input from IPython into valid Python syntax
516 # Convert source input from IPython into valid Python syntax
517 source = self.ip2py(source)
517 source = self.ip2py(source)
518
518
519 # Divide want into lines; check that it's properly indented; and
519 # Divide want into lines; check that it's properly indented; and
520 # then strip the indentation. Spaces before the last newline should
520 # then strip the indentation. Spaces before the last newline should
521 # be preserved, so plain rstrip() isn't good enough.
521 # be preserved, so plain rstrip() isn't good enough.
522 want = m.group('want')
522 want = m.group('want')
523 want_lines = want.split('\n')
523 want_lines = want.split('\n')
524 if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
524 if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
525 del want_lines[-1] # forget final newline & spaces after it
525 del want_lines[-1] # forget final newline & spaces after it
526 self._check_prefix(want_lines, ' '*indent, name,
526 self._check_prefix(want_lines, ' '*indent, name,
527 lineno + len(source_lines))
527 lineno + len(source_lines))
528
528
529 # Remove ipython output prompt that might be present in the first line
529 # Remove ipython output prompt that might be present in the first line
530 want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])
530 want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])
531
531
532 want = '\n'.join([wl[indent:] for wl in want_lines])
532 want = '\n'.join([wl[indent:] for wl in want_lines])
533
533
534 # If `want` contains a traceback message, then extract it.
534 # If `want` contains a traceback message, then extract it.
535 m = self._EXCEPTION_RE.match(want)
535 m = self._EXCEPTION_RE.match(want)
536 if m:
536 if m:
537 exc_msg = m.group('msg')
537 exc_msg = m.group('msg')
538 else:
538 else:
539 exc_msg = None
539 exc_msg = None
540
540
541 # Extract options from the source.
541 # Extract options from the source.
542 options = self._find_options(source, name, lineno)
542 options = self._find_options(source, name, lineno)
543
543
544 return source, options, want, exc_msg
544 return source, options, want, exc_msg
545
545
546 def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
546 def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
547 """
547 """
548 Given the lines of a source string (including prompts and
548 Given the lines of a source string (including prompts and
549 leading indentation), check to make sure that every prompt is
549 leading indentation), check to make sure that every prompt is
550 followed by a space character. If any line is not followed by
550 followed by a space character. If any line is not followed by
551 a space character, then raise ValueError.
551 a space character, then raise ValueError.
552
552
553 Note: IPython-modified version which takes the input prompt length as a
553 Note: IPython-modified version which takes the input prompt length as a
554 parameter, so that prompts of variable length can be dealt with.
554 parameter, so that prompts of variable length can be dealt with.
555 """
555 """
556 space_idx = indent+ps1_len
556 space_idx = indent+ps1_len
557 min_len = space_idx+1
557 min_len = space_idx+1
558 for i, line in enumerate(lines):
558 for i, line in enumerate(lines):
559 if len(line) >= min_len and line[space_idx] != ' ':
559 if len(line) >= min_len and line[space_idx] != ' ':
560 raise ValueError('line %r of the docstring for %s '
560 raise ValueError('line %r of the docstring for %s '
561 'lacks blank after %s: %r' %
561 'lacks blank after %s: %r' %
562 (lineno+i+1, name,
562 (lineno+i+1, name,
563 line[indent:space_idx], line))
563 line[indent:space_idx], line))
564
564
565
565
566 SKIP = doctest.register_optionflag('SKIP')
566 SKIP = doctest.register_optionflag('SKIP')
567
567
568
568
569 class IPDocTestRunner(doctest.DocTestRunner,object):
569 class IPDocTestRunner(doctest.DocTestRunner,object):
570 """Test runner that synchronizes the IPython namespace with test globals.
570 """Test runner that synchronizes the IPython namespace with test globals.
571 """
571 """
572
572
573 def run(self, test, compileflags=None, out=None, clear_globs=True):
573 def run(self, test, compileflags=None, out=None, clear_globs=True):
574
574
575 # Hack: ipython needs access to the execution context of the example,
575 # Hack: ipython needs access to the execution context of the example,
576 # so that it can propagate user variables loaded by %run into
576 # so that it can propagate user variables loaded by %run into
577 # test.globs. We put them here into our modified %run as a function
577 # test.globs. We put them here into our modified %run as a function
578 # attribute. Our new %run will then only make the namespace update
578 # attribute. Our new %run will then only make the namespace update
579 # when called (rather than unconconditionally updating test.globs here
579 # when called (rather than unconconditionally updating test.globs here
580 # for all examples, most of which won't be calling %run anyway).
580 # for all examples, most of which won't be calling %run anyway).
581 #_ip._ipdoctest_test_globs = test.globs
581 #_ip._ipdoctest_test_globs = test.globs
582 #_ip._ipdoctest_test_filename = test.filename
582 #_ip._ipdoctest_test_filename = test.filename
583
583
584 test.globs.update(_ip.user_ns)
584 test.globs.update(_ip.user_ns)
585
585
586 return super(IPDocTestRunner,self).run(test,
586 return super(IPDocTestRunner,self).run(test,
587 compileflags,out,clear_globs)
587 compileflags,out,clear_globs)
588
588
589
589
590 class DocFileCase(doctest.DocFileCase):
590 class DocFileCase(doctest.DocFileCase):
591 """Overrides to provide filename
591 """Overrides to provide filename
592 """
592 """
593 def address(self):
593 def address(self):
594 return (self._dt_test.filename, None, None)
594 return (self._dt_test.filename, None, None)
595
595
596
596
597 class ExtensionDoctest(doctests.Doctest):
597 class ExtensionDoctest(doctests.Doctest):
598 """Nose Plugin that supports doctests in extension modules.
598 """Nose Plugin that supports doctests in extension modules.
599 """
599 """
600 name = 'extdoctest' # call nosetests with --with-extdoctest
600 name = 'extdoctest' # call nosetests with --with-extdoctest
601 enabled = True
601 enabled = True
602
602
603 def options(self, parser, env=os.environ):
603 def options(self, parser, env=os.environ):
604 Plugin.options(self, parser, env)
604 Plugin.options(self, parser, env)
605 parser.add_option('--doctest-tests', action='store_true',
605 parser.add_option('--doctest-tests', action='store_true',
606 dest='doctest_tests',
606 dest='doctest_tests',
607 default=env.get('NOSE_DOCTEST_TESTS',True),
607 default=env.get('NOSE_DOCTEST_TESTS',True),
608 help="Also look for doctests in test modules. "
608 help="Also look for doctests in test modules. "
609 "Note that classes, methods and functions should "
609 "Note that classes, methods and functions should "
610 "have either doctests or non-doctest tests, "
610 "have either doctests or non-doctest tests, "
611 "not both. [NOSE_DOCTEST_TESTS]")
611 "not both. [NOSE_DOCTEST_TESTS]")
612 parser.add_option('--doctest-extension', action="append",
612 parser.add_option('--doctest-extension', action="append",
613 dest="doctestExtension",
613 dest="doctestExtension",
614 help="Also look for doctests in files with "
614 help="Also look for doctests in files with "
615 "this extension [NOSE_DOCTEST_EXTENSION]")
615 "this extension [NOSE_DOCTEST_EXTENSION]")
616 # Set the default as a list, if given in env; otherwise
616 # Set the default as a list, if given in env; otherwise
617 # an additional value set on the command line will cause
617 # an additional value set on the command line will cause
618 # an error.
618 # an error.
619 env_setting = env.get('NOSE_DOCTEST_EXTENSION')
619 env_setting = env.get('NOSE_DOCTEST_EXTENSION')
620 if env_setting is not None:
620 if env_setting is not None:
621 parser.set_defaults(doctestExtension=tolist(env_setting))
621 parser.set_defaults(doctestExtension=tolist(env_setting))
622
622
623
623
624 def configure(self, options, config):
624 def configure(self, options, config):
625 Plugin.configure(self, options, config)
625 Plugin.configure(self, options, config)
626 # Pull standard doctest plugin out of config; we will do doctesting
626 # Pull standard doctest plugin out of config; we will do doctesting
627 config.plugins.plugins = [p for p in config.plugins.plugins
627 config.plugins.plugins = [p for p in config.plugins.plugins
628 if p.name != 'doctest']
628 if p.name != 'doctest']
629 self.doctest_tests = options.doctest_tests
629 self.doctest_tests = options.doctest_tests
630 self.extension = tolist(options.doctestExtension)
630 self.extension = tolist(options.doctestExtension)
631
631
632 self.parser = doctest.DocTestParser()
632 self.parser = doctest.DocTestParser()
633 self.finder = DocTestFinder()
633 self.finder = DocTestFinder()
634 self.checker = IPDoctestOutputChecker()
634 self.checker = IPDoctestOutputChecker()
635 self.globs = None
635 self.globs = None
636 self.extraglobs = None
636 self.extraglobs = None
637
637
638
638
639 def loadTestsFromExtensionModule(self,filename):
639 def loadTestsFromExtensionModule(self,filename):
640 bpath,mod = os.path.split(filename)
640 bpath,mod = os.path.split(filename)
641 modname = os.path.splitext(mod)[0]
641 modname = os.path.splitext(mod)[0]
642 try:
642 try:
643 sys.path.append(bpath)
643 sys.path.append(bpath)
644 module = __import__(modname)
644 module = __import__(modname)
645 tests = list(self.loadTestsFromModule(module))
645 tests = list(self.loadTestsFromModule(module))
646 finally:
646 finally:
647 sys.path.pop()
647 sys.path.pop()
648 return tests
648 return tests
649
649
650 # NOTE: the method below is almost a copy of the original one in nose, with
650 # NOTE: the method below is almost a copy of the original one in nose, with
651 # a few modifications to control output checking.
651 # a few modifications to control output checking.
652
652
653 def loadTestsFromModule(self, module):
653 def loadTestsFromModule(self, module):
654 #print '*** ipdoctest - lTM',module # dbg
654 #print '*** ipdoctest - lTM',module # dbg
655
655
656 if not self.matches(module.__name__):
656 if not self.matches(module.__name__):
657 log.debug("Doctest doesn't want module %s", module)
657 log.debug("Doctest doesn't want module %s", module)
658 return
658 return
659
659
660 tests = self.finder.find(module,globs=self.globs,
660 tests = self.finder.find(module,globs=self.globs,
661 extraglobs=self.extraglobs)
661 extraglobs=self.extraglobs)
662 if not tests:
662 if not tests:
663 return
663 return
664
664
665 # always use whitespace and ellipsis options
665 # always use whitespace and ellipsis options
666 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
666 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
667
667
668 tests.sort()
668 tests.sort()
669 module_file = module.__file__
669 module_file = module.__file__
670 if module_file[-4:] in ('.pyc', '.pyo'):
670 if module_file[-4:] in ('.pyc', '.pyo'):
671 module_file = module_file[:-1]
671 module_file = module_file[:-1]
672 for test in tests:
672 for test in tests:
673 if not test.examples:
673 if not test.examples:
674 continue
674 continue
675 if not test.filename:
675 if not test.filename:
676 test.filename = module_file
676 test.filename = module_file
677
677
678 yield DocTestCase(test,
678 yield DocTestCase(test,
679 optionflags=optionflags,
679 optionflags=optionflags,
680 checker=self.checker)
680 checker=self.checker)
681
681
682
682
683 def loadTestsFromFile(self, filename):
683 def loadTestsFromFile(self, filename):
684 #print "ipdoctest - from file", filename # dbg
684 #print "ipdoctest - from file", filename # dbg
685 if is_extension_module(filename):
685 if is_extension_module(filename):
686 for t in self.loadTestsFromExtensionModule(filename):
686 for t in self.loadTestsFromExtensionModule(filename):
687 yield t
687 yield t
688 else:
688 else:
689 if self.extension and anyp(filename.endswith, self.extension):
689 if self.extension and anyp(filename.endswith, self.extension):
690 name = os.path.basename(filename)
690 name = os.path.basename(filename)
691 dh = open(filename)
691 dh = open(filename)
692 try:
692 try:
693 doc = dh.read()
693 doc = dh.read()
694 finally:
694 finally:
695 dh.close()
695 dh.close()
696 test = self.parser.get_doctest(
696 test = self.parser.get_doctest(
697 doc, globs={'__file__': filename}, name=name,
697 doc, globs={'__file__': filename}, name=name,
698 filename=filename, lineno=0)
698 filename=filename, lineno=0)
699 if test.examples:
699 if test.examples:
700 #print 'FileCase:',test.examples # dbg
700 #print 'FileCase:',test.examples # dbg
701 yield DocFileCase(test)
701 yield DocFileCase(test)
702 else:
702 else:
703 yield False # no tests to load
703 yield False # no tests to load
704
704
705
705
706 class IPythonDoctest(ExtensionDoctest):
706 class IPythonDoctest(ExtensionDoctest):
707 """Nose Plugin that supports doctests in extension modules.
707 """Nose Plugin that supports doctests in extension modules.
708 """
708 """
709 name = 'ipdoctest' # call nosetests with --with-ipdoctest
709 name = 'ipdoctest' # call nosetests with --with-ipdoctest
710 enabled = True
710 enabled = True
711
711
712 def makeTest(self, obj, parent):
712 def makeTest(self, obj, parent):
713 """Look for doctests in the given object, which will be a
713 """Look for doctests in the given object, which will be a
714 function, method or class.
714 function, method or class.
715 """
715 """
716 #print 'Plugin analyzing:', obj, parent # dbg
716 #print 'Plugin analyzing:', obj, parent # dbg
717 # always use whitespace and ellipsis options
717 # always use whitespace and ellipsis options
718 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
718 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
719
719
720 doctests = self.finder.find(obj, module=getmodule(parent))
720 doctests = self.finder.find(obj, module=getmodule(parent))
721 if doctests:
721 if doctests:
722 for test in doctests:
722 for test in doctests:
723 if len(test.examples) == 0:
723 if len(test.examples) == 0:
724 continue
724 continue
725
725
726 yield DocTestCase(test, obj=obj,
726 yield DocTestCase(test, obj=obj,
727 optionflags=optionflags,
727 optionflags=optionflags,
728 checker=self.checker)
728 checker=self.checker)
729
729
730 def options(self, parser, env=os.environ):
730 def options(self, parser, env=os.environ):
731 #print "Options for nose plugin:", self.name # dbg
731 #print "Options for nose plugin:", self.name # dbg
732 Plugin.options(self, parser, env)
732 Plugin.options(self, parser, env)
733 parser.add_option('--ipdoctest-tests', action='store_true',
733 parser.add_option('--ipdoctest-tests', action='store_true',
734 dest='ipdoctest_tests',
734 dest='ipdoctest_tests',
735 default=env.get('NOSE_IPDOCTEST_TESTS',True),
735 default=env.get('NOSE_IPDOCTEST_TESTS',True),
736 help="Also look for doctests in test modules. "
736 help="Also look for doctests in test modules. "
737 "Note that classes, methods and functions should "
737 "Note that classes, methods and functions should "
738 "have either doctests or non-doctest tests, "
738 "have either doctests or non-doctest tests, "
739 "not both. [NOSE_IPDOCTEST_TESTS]")
739 "not both. [NOSE_IPDOCTEST_TESTS]")
740 parser.add_option('--ipdoctest-extension', action="append",
740 parser.add_option('--ipdoctest-extension', action="append",
741 dest="ipdoctest_extension",
741 dest="ipdoctest_extension",
742 help="Also look for doctests in files with "
742 help="Also look for doctests in files with "
743 "this extension [NOSE_IPDOCTEST_EXTENSION]")
743 "this extension [NOSE_IPDOCTEST_EXTENSION]")
744 # Set the default as a list, if given in env; otherwise
744 # Set the default as a list, if given in env; otherwise
745 # an additional value set on the command line will cause
745 # an additional value set on the command line will cause
746 # an error.
746 # an error.
747 env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
747 env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
748 if env_setting is not None:
748 if env_setting is not None:
749 parser.set_defaults(ipdoctest_extension=tolist(env_setting))
749 parser.set_defaults(ipdoctest_extension=tolist(env_setting))
750
750
751 def configure(self, options, config):
751 def configure(self, options, config):
752 #print "Configuring nose plugin:", self.name # dbg
752 #print "Configuring nose plugin:", self.name # dbg
753 Plugin.configure(self, options, config)
753 Plugin.configure(self, options, config)
754 # Pull standard doctest plugin out of config; we will do doctesting
754 # Pull standard doctest plugin out of config; we will do doctesting
755 config.plugins.plugins = [p for p in config.plugins.plugins
755 config.plugins.plugins = [p for p in config.plugins.plugins
756 if p.name != 'doctest']
756 if p.name != 'doctest']
757 self.doctest_tests = options.ipdoctest_tests
757 self.doctest_tests = options.ipdoctest_tests
758 self.extension = tolist(options.ipdoctest_extension)
758 self.extension = tolist(options.ipdoctest_extension)
759
759
760 self.parser = IPDocTestParser()
760 self.parser = IPDocTestParser()
761 self.finder = DocTestFinder(parser=self.parser)
761 self.finder = DocTestFinder(parser=self.parser)
762 self.checker = IPDoctestOutputChecker()
762 self.checker = IPDoctestOutputChecker()
763 self.globs = None
763 self.globs = None
764 self.extraglobs = None
764 self.extraglobs = None
@@ -1,95 +1,95 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with stack frames.
3 Utilities for working with stack frames.
4 """
4 """
5 from __future__ import print_function
5 from __future__ import print_function
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2011 The IPython Development Team
8 # Copyright (C) 2008-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import sys
18 import sys
19 from IPython.utils import py3compat
19 from IPython.utils import py3compat
20
20
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # Code
22 # Code
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 @py3compat.doctest_refactor_print
25 @py3compat.doctest_refactor_print
26 def extract_vars(*names,**kw):
26 def extract_vars(*names,**kw):
27 """Extract a set of variables by name from another frame.
27 """Extract a set of variables by name from another frame.
28
28
29 :Parameters:
29 :Parameters:
30 - `*names`: strings
30 - `*names`: strings
31 One or more variable names which will be extracted from the caller's
31 One or more variable names which will be extracted from the caller's
32 frame.
32 frame.
33
33
34 :Keywords:
34 :Keywords:
35 - `depth`: integer (0)
35 - `depth`: integer (0)
36 How many frames in the stack to walk when looking for your variables.
36 How many frames in the stack to walk when looking for your variables.
37
37
38
38
39 Examples:
39 Examples:
40
40
41 In [2]: def func(x):
41 In [2]: def func(x):
42 ...: y = 1
42 ...: y = 1
43 ...: print sorted(extract_vars('x','y').items())
43 ...: print(sorted(extract_vars('x','y').items()))
44 ...:
44 ...:
45
45
46 In [3]: func('hello')
46 In [3]: func('hello')
47 [('x', 'hello'), ('y', 1)]
47 [('x', 'hello'), ('y', 1)]
48 """
48 """
49
49
50 depth = kw.get('depth',0)
50 depth = kw.get('depth',0)
51
51
52 callerNS = sys._getframe(depth+1).f_locals
52 callerNS = sys._getframe(depth+1).f_locals
53 return dict((k,callerNS[k]) for k in names)
53 return dict((k,callerNS[k]) for k in names)
54
54
55
55
56 def extract_vars_above(*names):
56 def extract_vars_above(*names):
57 """Extract a set of variables by name from another frame.
57 """Extract a set of variables by name from another frame.
58
58
59 Similar to extractVars(), but with a specified depth of 1, so that names
59 Similar to extractVars(), but with a specified depth of 1, so that names
60 are exctracted exactly from above the caller.
60 are exctracted exactly from above the caller.
61
61
62 This is simply a convenience function so that the very common case (for us)
62 This is simply a convenience function so that the very common case (for us)
63 of skipping exactly 1 frame doesn't have to construct a special dict for
63 of skipping exactly 1 frame doesn't have to construct a special dict for
64 keyword passing."""
64 keyword passing."""
65
65
66 callerNS = sys._getframe(2).f_locals
66 callerNS = sys._getframe(2).f_locals
67 return dict((k,callerNS[k]) for k in names)
67 return dict((k,callerNS[k]) for k in names)
68
68
69
69
70 def debugx(expr,pre_msg=''):
70 def debugx(expr,pre_msg=''):
71 """Print the value of an expression from the caller's frame.
71 """Print the value of an expression from the caller's frame.
72
72
73 Takes an expression, evaluates it in the caller's frame and prints both
73 Takes an expression, evaluates it in the caller's frame and prints both
74 the given expression and the resulting value (as well as a debug mark
74 the given expression and the resulting value (as well as a debug mark
75 indicating the name of the calling function. The input must be of a form
75 indicating the name of the calling function. The input must be of a form
76 suitable for eval().
76 suitable for eval().
77
77
78 An optional message can be passed, which will be prepended to the printed
78 An optional message can be passed, which will be prepended to the printed
79 expr->value pair."""
79 expr->value pair."""
80
80
81 cf = sys._getframe(1)
81 cf = sys._getframe(1)
82 print('[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
82 print('[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
83 eval(expr,cf.f_globals,cf.f_locals)))
83 eval(expr,cf.f_globals,cf.f_locals)))
84
84
85
85
86 # deactivate it by uncommenting the following line, which makes it a no-op
86 # deactivate it by uncommenting the following line, which makes it a no-op
87 #def debugx(expr,pre_msg=''): pass
87 #def debugx(expr,pre_msg=''): pass
88
88
89 def extract_module_locals(depth=0):
89 def extract_module_locals(depth=0):
90 """Returns (module, locals) of the funciton `depth` frames away from the caller"""
90 """Returns (module, locals) of the funciton `depth` frames away from the caller"""
91 f = sys._getframe(depth + 1)
91 f = sys._getframe(depth + 1)
92 global_ns = f.f_globals
92 global_ns = f.f_globals
93 module = sys.modules[global_ns['__name__']]
93 module = sys.modules[global_ns['__name__']]
94 return (module, f.f_locals)
94 return (module, f.f_locals)
95
95
@@ -1,229 +1,229 b''
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import math
14 import math
15 import re
15 import re
16 import types
16 import types
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 # base64.encodestring is deprecated in Python 3.x
20 # base64.encodestring is deprecated in Python 3.x
21 from base64 import encodebytes
21 from base64 import encodebytes
22 except ImportError:
22 except ImportError:
23 # Python 2.x
23 # Python 2.x
24 from base64 import encodestring as encodebytes
24 from base64 import encodestring as encodebytes
25
25
26 from IPython.utils import py3compat
26 from IPython.utils import py3compat
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
28 from IPython.utils.encoding import DEFAULT_ENCODING
28 from IPython.utils.encoding import DEFAULT_ENCODING
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Globals and constants
32 # Globals and constants
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 # timestamp formats
35 # timestamp formats
36 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
36 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+)Z?([\+\-]\d{2}:?\d{2})?$")
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+)Z?([\+\-]\d{2}:?\d{2})?$")
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Classes and functions
40 # Classes and functions
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 def rekey(dikt):
43 def rekey(dikt):
44 """Rekey a dict that has been forced to use str keys where there should be
44 """Rekey a dict that has been forced to use str keys where there should be
45 ints by json."""
45 ints by json."""
46 for k in dikt:
46 for k in dikt:
47 if isinstance(k, string_types):
47 if isinstance(k, string_types):
48 ik=fk=None
48 ik=fk=None
49 try:
49 try:
50 ik = int(k)
50 ik = int(k)
51 except ValueError:
51 except ValueError:
52 try:
52 try:
53 fk = float(k)
53 fk = float(k)
54 except ValueError:
54 except ValueError:
55 continue
55 continue
56 if ik is not None:
56 if ik is not None:
57 nk = ik
57 nk = ik
58 else:
58 else:
59 nk = fk
59 nk = fk
60 if nk in dikt:
60 if nk in dikt:
61 raise KeyError("already have key %r"%nk)
61 raise KeyError("already have key %r"%nk)
62 dikt[nk] = dikt.pop(k)
62 dikt[nk] = dikt.pop(k)
63 return dikt
63 return dikt
64
64
65
65
66 def extract_dates(obj):
66 def extract_dates(obj):
67 """extract ISO8601 dates from unpacked JSON"""
67 """extract ISO8601 dates from unpacked JSON"""
68 if isinstance(obj, dict):
68 if isinstance(obj, dict):
69 obj = dict(obj) # don't clobber
69 obj = dict(obj) # don't clobber
70 for k,v in iteritems(obj):
70 for k,v in iteritems(obj):
71 obj[k] = extract_dates(v)
71 obj[k] = extract_dates(v)
72 elif isinstance(obj, (list, tuple)):
72 elif isinstance(obj, (list, tuple)):
73 obj = [ extract_dates(o) for o in obj ]
73 obj = [ extract_dates(o) for o in obj ]
74 elif isinstance(obj, string_types):
74 elif isinstance(obj, string_types):
75 m = ISO8601_PAT.match(obj)
75 m = ISO8601_PAT.match(obj)
76 if m:
76 if m:
77 # FIXME: add actual timezone support
77 # FIXME: add actual timezone support
78 # this just drops the timezone info
78 # this just drops the timezone info
79 notz = m.groups()[0]
79 notz = m.groups()[0]
80 obj = datetime.strptime(notz, ISO8601)
80 obj = datetime.strptime(notz, ISO8601)
81 return obj
81 return obj
82
82
83 def squash_dates(obj):
83 def squash_dates(obj):
84 """squash datetime objects into ISO8601 strings"""
84 """squash datetime objects into ISO8601 strings"""
85 if isinstance(obj, dict):
85 if isinstance(obj, dict):
86 obj = dict(obj) # don't clobber
86 obj = dict(obj) # don't clobber
87 for k,v in iteritems(obj):
87 for k,v in iteritems(obj):
88 obj[k] = squash_dates(v)
88 obj[k] = squash_dates(v)
89 elif isinstance(obj, (list, tuple)):
89 elif isinstance(obj, (list, tuple)):
90 obj = [ squash_dates(o) for o in obj ]
90 obj = [ squash_dates(o) for o in obj ]
91 elif isinstance(obj, datetime):
91 elif isinstance(obj, datetime):
92 obj = obj.isoformat()
92 obj = obj.isoformat()
93 return obj
93 return obj
94
94
95 def date_default(obj):
95 def date_default(obj):
96 """default function for packing datetime objects in JSON."""
96 """default function for packing datetime objects in JSON."""
97 if isinstance(obj, datetime):
97 if isinstance(obj, datetime):
98 return obj.isoformat()
98 return obj.isoformat()
99 else:
99 else:
100 raise TypeError("%r is not JSON serializable"%obj)
100 raise TypeError("%r is not JSON serializable"%obj)
101
101
102
102
103 # constants for identifying png/jpeg data
103 # constants for identifying png/jpeg data
104 PNG = b'\x89PNG\r\n\x1a\n'
104 PNG = b'\x89PNG\r\n\x1a\n'
105 # front of PNG base64-encoded
105 # front of PNG base64-encoded
106 PNG64 = b'iVBORw0KG'
106 PNG64 = b'iVBORw0KG'
107 JPEG = b'\xff\xd8'
107 JPEG = b'\xff\xd8'
108 # front of JPEG base64-encoded
108 # front of JPEG base64-encoded
109 JPEG64 = b'/9'
109 JPEG64 = b'/9'
110
110
111 def encode_images(format_dict):
111 def encode_images(format_dict):
112 """b64-encodes images in a displaypub format dict
112 """b64-encodes images in a displaypub format dict
113
113
114 Perhaps this should be handled in json_clean itself?
114 Perhaps this should be handled in json_clean itself?
115
115
116 Parameters
116 Parameters
117 ----------
117 ----------
118
118
119 format_dict : dict
119 format_dict : dict
120 A dictionary of display data keyed by mime-type
120 A dictionary of display data keyed by mime-type
121
121
122 Returns
122 Returns
123 -------
123 -------
124
124
125 format_dict : dict
125 format_dict : dict
126 A copy of the same dictionary,
126 A copy of the same dictionary,
127 but binary image data ('image/png' or 'image/jpeg')
127 but binary image data ('image/png' or 'image/jpeg')
128 is base64-encoded.
128 is base64-encoded.
129
129
130 """
130 """
131 encoded = format_dict.copy()
131 encoded = format_dict.copy()
132
132
133 pngdata = format_dict.get('image/png')
133 pngdata = format_dict.get('image/png')
134 if isinstance(pngdata, bytes):
134 if isinstance(pngdata, bytes):
135 # make sure we don't double-encode
135 # make sure we don't double-encode
136 if not pngdata.startswith(PNG64):
136 if not pngdata.startswith(PNG64):
137 pngdata = encodebytes(pngdata)
137 pngdata = encodebytes(pngdata)
138 encoded['image/png'] = pngdata.decode('ascii')
138 encoded['image/png'] = pngdata.decode('ascii')
139
139
140 jpegdata = format_dict.get('image/jpeg')
140 jpegdata = format_dict.get('image/jpeg')
141 if isinstance(jpegdata, bytes):
141 if isinstance(jpegdata, bytes):
142 # make sure we don't double-encode
142 # make sure we don't double-encode
143 if not jpegdata.startswith(JPEG64):
143 if not jpegdata.startswith(JPEG64):
144 jpegdata = encodebytes(jpegdata)
144 jpegdata = encodebytes(jpegdata)
145 encoded['image/jpeg'] = jpegdata.decode('ascii')
145 encoded['image/jpeg'] = jpegdata.decode('ascii')
146
146
147 return encoded
147 return encoded
148
148
149
149
150 def json_clean(obj):
150 def json_clean(obj):
151 """Clean an object to ensure it's safe to encode in JSON.
151 """Clean an object to ensure it's safe to encode in JSON.
152
152
153 Atomic, immutable objects are returned unmodified. Sets and tuples are
153 Atomic, immutable objects are returned unmodified. Sets and tuples are
154 converted to lists, lists are copied and dicts are also copied.
154 converted to lists, lists are copied and dicts are also copied.
155
155
156 Note: dicts whose keys could cause collisions upon encoding (such as a dict
156 Note: dicts whose keys could cause collisions upon encoding (such as a dict
157 with both the number 1 and the string '1' as keys) will cause a ValueError
157 with both the number 1 and the string '1' as keys) will cause a ValueError
158 to be raised.
158 to be raised.
159
159
160 Parameters
160 Parameters
161 ----------
161 ----------
162 obj : any python object
162 obj : any python object
163
163
164 Returns
164 Returns
165 -------
165 -------
166 out : object
166 out : object
167
167
168 A version of the input which will not cause an encoding error when
168 A version of the input which will not cause an encoding error when
169 encoded as JSON. Note that this function does not *encode* its inputs,
169 encoded as JSON. Note that this function does not *encode* its inputs,
170 it simply sanitizes it so that there will be no encoding errors later.
170 it simply sanitizes it so that there will be no encoding errors later.
171
171
172 Examples
172 Examples
173 --------
173 --------
174 >>> json_clean(4)
174 >>> json_clean(4)
175 4
175 4
176 >>> json_clean(range(10))
176 >>> json_clean(list(range(10)))
177 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
177 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
178 >>> sorted(json_clean(dict(x=1, y=2)).items())
178 >>> sorted(json_clean(dict(x=1, y=2)).items())
179 [('x', 1), ('y', 2)]
179 [('x', 1), ('y', 2)]
180 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
180 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
181 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
181 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
182 >>> json_clean(True)
182 >>> json_clean(True)
183 True
183 True
184 """
184 """
185 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
185 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
186 # listed explicitly because bools pass as int instances
186 # listed explicitly because bools pass as int instances
187 atomic_ok = (unicode_type, int, type(None))
187 atomic_ok = (unicode_type, int, type(None))
188
188
189 # containers that we need to convert into lists
189 # containers that we need to convert into lists
190 container_to_list = (tuple, set, types.GeneratorType)
190 container_to_list = (tuple, set, types.GeneratorType)
191
191
192 if isinstance(obj, float):
192 if isinstance(obj, float):
193 # cast out-of-range floats to their reprs
193 # cast out-of-range floats to their reprs
194 if math.isnan(obj) or math.isinf(obj):
194 if math.isnan(obj) or math.isinf(obj):
195 return repr(obj)
195 return repr(obj)
196 return obj
196 return obj
197
197
198 if isinstance(obj, atomic_ok):
198 if isinstance(obj, atomic_ok):
199 return obj
199 return obj
200
200
201 if isinstance(obj, bytes):
201 if isinstance(obj, bytes):
202 return obj.decode(DEFAULT_ENCODING, 'replace')
202 return obj.decode(DEFAULT_ENCODING, 'replace')
203
203
204 if isinstance(obj, container_to_list) or (
204 if isinstance(obj, container_to_list) or (
205 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
205 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
206 obj = list(obj)
206 obj = list(obj)
207
207
208 if isinstance(obj, list):
208 if isinstance(obj, list):
209 return [json_clean(x) for x in obj]
209 return [json_clean(x) for x in obj]
210
210
211 if isinstance(obj, dict):
211 if isinstance(obj, dict):
212 # First, validate that the dict won't lose data in conversion due to
212 # First, validate that the dict won't lose data in conversion due to
213 # key collisions after stringification. This can happen with keys like
213 # key collisions after stringification. This can happen with keys like
214 # True and 'true' or 1 and '1', which collide in JSON.
214 # True and 'true' or 1 and '1', which collide in JSON.
215 nkeys = len(obj)
215 nkeys = len(obj)
216 nkeys_collapsed = len(set(map(str, obj)))
216 nkeys_collapsed = len(set(map(str, obj)))
217 if nkeys != nkeys_collapsed:
217 if nkeys != nkeys_collapsed:
218 raise ValueError('dict can not be safely converted to JSON: '
218 raise ValueError('dict can not be safely converted to JSON: '
219 'key collision would lead to dropped values')
219 'key collision would lead to dropped values')
220 # If all OK, proceed by making the new dict that will be json-safe
220 # If all OK, proceed by making the new dict that will be json-safe
221 out = {}
221 out = {}
222 for k,v in iteritems(obj):
222 for k,v in iteritems(obj):
223 out[str(k)] = json_clean(v)
223 out[str(k)] = json_clean(v)
224 return out
224 return out
225
225
226 # If we get here, we don't know how to handle the object, so we just get
226 # If we get here, we don't know how to handle the object, so we just get
227 # its repr and return that. This will catch lambdas, open sockets, class
227 # its repr and return that. This will catch lambdas, open sockets, class
228 # objects, and any other complicated contraption that json can't encode
228 # objects, and any other complicated contraption that json can't encode
229 return repr(obj)
229 return repr(obj)
@@ -1,239 +1,239 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
3 import functools
3 import functools
4 import sys
4 import sys
5 import re
5 import re
6 import types
6 import types
7
7
8 from .encoding import DEFAULT_ENCODING
8 from .encoding import DEFAULT_ENCODING
9
9
10 orig_open = open
10 orig_open = open
11
11
12 def no_code(x, encoding=None):
12 def no_code(x, encoding=None):
13 return x
13 return x
14
14
15 def decode(s, encoding=None):
15 def decode(s, encoding=None):
16 encoding = encoding or DEFAULT_ENCODING
16 encoding = encoding or DEFAULT_ENCODING
17 return s.decode(encoding, "replace")
17 return s.decode(encoding, "replace")
18
18
19 def encode(u, encoding=None):
19 def encode(u, encoding=None):
20 encoding = encoding or DEFAULT_ENCODING
20 encoding = encoding or DEFAULT_ENCODING
21 return u.encode(encoding, "replace")
21 return u.encode(encoding, "replace")
22
22
23
23
24 def cast_unicode(s, encoding=None):
24 def cast_unicode(s, encoding=None):
25 if isinstance(s, bytes):
25 if isinstance(s, bytes):
26 return decode(s, encoding)
26 return decode(s, encoding)
27 return s
27 return s
28
28
29 def cast_bytes(s, encoding=None):
29 def cast_bytes(s, encoding=None):
30 if not isinstance(s, bytes):
30 if not isinstance(s, bytes):
31 return encode(s, encoding)
31 return encode(s, encoding)
32 return s
32 return s
33
33
34 def _modify_str_or_docstring(str_change_func):
34 def _modify_str_or_docstring(str_change_func):
35 @functools.wraps(str_change_func)
35 @functools.wraps(str_change_func)
36 def wrapper(func_or_str):
36 def wrapper(func_or_str):
37 if isinstance(func_or_str, string_types):
37 if isinstance(func_or_str, string_types):
38 func = None
38 func = None
39 doc = func_or_str
39 doc = func_or_str
40 else:
40 else:
41 func = func_or_str
41 func = func_or_str
42 doc = func.__doc__
42 doc = func.__doc__
43
43
44 doc = str_change_func(doc)
44 doc = str_change_func(doc)
45
45
46 if func:
46 if func:
47 func.__doc__ = doc
47 func.__doc__ = doc
48 return func
48 return func
49 return doc
49 return doc
50 return wrapper
50 return wrapper
51
51
52 def safe_unicode(e):
52 def safe_unicode(e):
53 """unicode(e) with various fallbacks. Used for exceptions, which may not be
53 """unicode(e) with various fallbacks. Used for exceptions, which may not be
54 safe to call unicode() on.
54 safe to call unicode() on.
55 """
55 """
56 try:
56 try:
57 return unicode_type(e)
57 return unicode_type(e)
58 except UnicodeError:
58 except UnicodeError:
59 pass
59 pass
60
60
61 try:
61 try:
62 return str_to_unicode(str(e))
62 return str_to_unicode(str(e))
63 except UnicodeError:
63 except UnicodeError:
64 pass
64 pass
65
65
66 try:
66 try:
67 return str_to_unicode(repr(e))
67 return str_to_unicode(repr(e))
68 except UnicodeError:
68 except UnicodeError:
69 pass
69 pass
70
70
71 return u'Unrecoverably corrupt evalue'
71 return u'Unrecoverably corrupt evalue'
72
72
73 if sys.version_info[0] >= 3:
73 if sys.version_info[0] >= 3:
74 PY3 = True
74 PY3 = True
75
75
76 input = input
76 input = input
77 builtin_mod_name = "builtins"
77 builtin_mod_name = "builtins"
78 import builtins as builtin_mod
78 import builtins as builtin_mod
79
79
80 str_to_unicode = no_code
80 str_to_unicode = no_code
81 unicode_to_str = no_code
81 unicode_to_str = no_code
82 str_to_bytes = encode
82 str_to_bytes = encode
83 bytes_to_str = decode
83 bytes_to_str = decode
84 cast_bytes_py2 = no_code
84 cast_bytes_py2 = no_code
85
85
86 string_types = (str,)
86 string_types = (str,)
87 unicode_type = str
87 unicode_type = str
88
88
89 def isidentifier(s, dotted=False):
89 def isidentifier(s, dotted=False):
90 if dotted:
90 if dotted:
91 return all(isidentifier(a) for a in s.split("."))
91 return all(isidentifier(a) for a in s.split("."))
92 return s.isidentifier()
92 return s.isidentifier()
93
93
94 open = orig_open
94 open = orig_open
95 xrange = range
95 xrange = range
96 iteritems = dict.items
96 def iteritems(d): return iter(d.items())
97 itervalues = dict.values
97 def itervalues(d): return iter(d.values())
98
98
99 MethodType = types.MethodType
99 MethodType = types.MethodType
100
100
101 def execfile(fname, glob, loc=None):
101 def execfile(fname, glob, loc=None):
102 loc = loc if (loc is not None) else glob
102 loc = loc if (loc is not None) else glob
103 with open(fname, 'rb') as f:
103 with open(fname, 'rb') as f:
104 exec(compile(f.read(), fname, 'exec'), glob, loc)
104 exec(compile(f.read(), fname, 'exec'), glob, loc)
105
105
106 # Refactor print statements in doctests.
106 # Refactor print statements in doctests.
107 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
107 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
108 def _print_statement_sub(match):
108 def _print_statement_sub(match):
109 expr = match.groups('expr')
109 expr = match.groups('expr')
110 return "print(%s)" % expr
110 return "print(%s)" % expr
111
111
112 @_modify_str_or_docstring
112 @_modify_str_or_docstring
113 def doctest_refactor_print(doc):
113 def doctest_refactor_print(doc):
114 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
114 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
115 unfortunately doesn't pick up on our doctests.
115 unfortunately doesn't pick up on our doctests.
116
116
117 Can accept a string or a function, so it can be used as a decorator."""
117 Can accept a string or a function, so it can be used as a decorator."""
118 return _print_statement_re.sub(_print_statement_sub, doc)
118 return _print_statement_re.sub(_print_statement_sub, doc)
119
119
120 # Abstract u'abc' syntax:
120 # Abstract u'abc' syntax:
121 @_modify_str_or_docstring
121 @_modify_str_or_docstring
122 def u_format(s):
122 def u_format(s):
123 """"{u}'abc'" --> "'abc'" (Python 3)
123 """"{u}'abc'" --> "'abc'" (Python 3)
124
124
125 Accepts a string or a function, so it can be used as a decorator."""
125 Accepts a string or a function, so it can be used as a decorator."""
126 return s.format(u='')
126 return s.format(u='')
127
127
128 else:
128 else:
129 PY3 = False
129 PY3 = False
130
130
131 input = raw_input
131 input = raw_input
132 builtin_mod_name = "__builtin__"
132 builtin_mod_name = "__builtin__"
133 import __builtin__ as builtin_mod
133 import __builtin__ as builtin_mod
134
134
135 str_to_unicode = decode
135 str_to_unicode = decode
136 unicode_to_str = encode
136 unicode_to_str = encode
137 str_to_bytes = no_code
137 str_to_bytes = no_code
138 bytes_to_str = no_code
138 bytes_to_str = no_code
139 cast_bytes_py2 = cast_bytes
139 cast_bytes_py2 = cast_bytes
140
140
141 string_types = (str, unicode)
141 string_types = (str, unicode)
142 unicode_type = unicode
142 unicode_type = unicode
143
143
144 import re
144 import re
145 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
145 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
146 def isidentifier(s, dotted=False):
146 def isidentifier(s, dotted=False):
147 if dotted:
147 if dotted:
148 return all(isidentifier(a) for a in s.split("."))
148 return all(isidentifier(a) for a in s.split("."))
149 return bool(_name_re.match(s))
149 return bool(_name_re.match(s))
150
150
151 class open(object):
151 class open(object):
152 """Wrapper providing key part of Python 3 open() interface."""
152 """Wrapper providing key part of Python 3 open() interface."""
153 def __init__(self, fname, mode="r", encoding="utf-8"):
153 def __init__(self, fname, mode="r", encoding="utf-8"):
154 self.f = orig_open(fname, mode)
154 self.f = orig_open(fname, mode)
155 self.enc = encoding
155 self.enc = encoding
156
156
157 def write(self, s):
157 def write(self, s):
158 return self.f.write(s.encode(self.enc))
158 return self.f.write(s.encode(self.enc))
159
159
160 def read(self, size=-1):
160 def read(self, size=-1):
161 return self.f.read(size).decode(self.enc)
161 return self.f.read(size).decode(self.enc)
162
162
163 def close(self):
163 def close(self):
164 return self.f.close()
164 return self.f.close()
165
165
166 def __enter__(self):
166 def __enter__(self):
167 return self
167 return self
168
168
169 def __exit__(self, etype, value, traceback):
169 def __exit__(self, etype, value, traceback):
170 self.f.close()
170 self.f.close()
171
171
172 xrange = xrange
172 xrange = xrange
173 iteritems = dict.iteritems
173 def iteritems(d): return d.iteritems()
174 itervalues = dict.itervalues
174 def itervalues(d): return d.itervalues()
175
175
176 def MethodType(func, instance):
176 def MethodType(func, instance):
177 return types.MethodType(func, instance, type(instance))
177 return types.MethodType(func, instance, type(instance))
178
178
179 # don't override system execfile on 2.x:
179 # don't override system execfile on 2.x:
180 execfile = execfile
180 execfile = execfile
181
181
182 def doctest_refactor_print(func_or_str):
182 def doctest_refactor_print(func_or_str):
183 return func_or_str
183 return func_or_str
184
184
185
185
186 # Abstract u'abc' syntax:
186 # Abstract u'abc' syntax:
187 @_modify_str_or_docstring
187 @_modify_str_or_docstring
188 def u_format(s):
188 def u_format(s):
189 """"{u}'abc'" --> "u'abc'" (Python 2)
189 """"{u}'abc'" --> "u'abc'" (Python 2)
190
190
191 Accepts a string or a function, so it can be used as a decorator."""
191 Accepts a string or a function, so it can be used as a decorator."""
192 return s.format(u='u')
192 return s.format(u='u')
193
193
194 if sys.platform == 'win32':
194 if sys.platform == 'win32':
195 def execfile(fname, glob=None, loc=None):
195 def execfile(fname, glob=None, loc=None):
196 loc = loc if (loc is not None) else glob
196 loc = loc if (loc is not None) else glob
197 # The rstrip() is necessary b/c trailing whitespace in files will
197 # The rstrip() is necessary b/c trailing whitespace in files will
198 # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
198 # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
199 # but we still support 2.6). See issue 1027.
199 # but we still support 2.6). See issue 1027.
200 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
200 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
201 # compile converts unicode filename to str assuming
201 # compile converts unicode filename to str assuming
202 # ascii. Let's do the conversion before calling compile
202 # ascii. Let's do the conversion before calling compile
203 if isinstance(fname, unicode):
203 if isinstance(fname, unicode):
204 filename = unicode_to_str(fname)
204 filename = unicode_to_str(fname)
205 else:
205 else:
206 filename = fname
206 filename = fname
207 exec(compile(scripttext, filename, 'exec'), glob, loc)
207 exec(compile(scripttext, filename, 'exec'), glob, loc)
208 else:
208 else:
209 def execfile(fname, *where):
209 def execfile(fname, *where):
210 if isinstance(fname, unicode):
210 if isinstance(fname, unicode):
211 filename = fname.encode(sys.getfilesystemencoding())
211 filename = fname.encode(sys.getfilesystemencoding())
212 else:
212 else:
213 filename = fname
213 filename = fname
214 builtin_mod.execfile(filename, *where)
214 builtin_mod.execfile(filename, *where)
215
215
216 # Parts below taken from six:
216 # Parts below taken from six:
217 # Copyright (c) 2010-2013 Benjamin Peterson
217 # Copyright (c) 2010-2013 Benjamin Peterson
218 #
218 #
219 # Permission is hereby granted, free of charge, to any person obtaining a copy
219 # Permission is hereby granted, free of charge, to any person obtaining a copy
220 # of this software and associated documentation files (the "Software"), to deal
220 # of this software and associated documentation files (the "Software"), to deal
221 # in the Software without restriction, including without limitation the rights
221 # in the Software without restriction, including without limitation the rights
222 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
222 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
223 # copies of the Software, and to permit persons to whom the Software is
223 # copies of the Software, and to permit persons to whom the Software is
224 # furnished to do so, subject to the following conditions:
224 # furnished to do so, subject to the following conditions:
225 #
225 #
226 # The above copyright notice and this permission notice shall be included in all
226 # The above copyright notice and this permission notice shall be included in all
227 # copies or substantial portions of the Software.
227 # copies or substantial portions of the Software.
228 #
228 #
229 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
229 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
230 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
230 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
231 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
231 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
232 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
232 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
233 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
233 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
234 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
234 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
235 # SOFTWARE.
235 # SOFTWARE.
236
236
237 def with_metaclass(meta, *bases):
237 def with_metaclass(meta, *bases):
238 """Create a base class with a metaclass."""
238 """Create a base class with a metaclass."""
239 return meta("NewBase", bases, {})
239 return meta("NewBase", bases, {})
@@ -1,643 +1,643 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.path.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import with_statement
15 from __future__ import with_statement
16
16
17 import os
17 import os
18 import shutil
18 import shutil
19 import sys
19 import sys
20 import tempfile
20 import tempfile
21 from contextlib import contextmanager
21 from contextlib import contextmanager
22
22
23 from os.path import join, abspath, split
23 from os.path import join, abspath, split
24
24
25 import nose.tools as nt
25 import nose.tools as nt
26
26
27 from nose import with_setup
27 from nose import with_setup
28
28
29 import IPython
29 import IPython
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.decorators import (skip_if_not_win32, skip_win32,
31 from IPython.testing.decorators import (skip_if_not_win32, skip_win32,
32 onlyif_unicode_paths,)
32 onlyif_unicode_paths,)
33 from IPython.testing.tools import make_tempfile, AssertPrints
33 from IPython.testing.tools import make_tempfile, AssertPrints
34 from IPython.utils import path
34 from IPython.utils import path
35 from IPython.utils import py3compat
35 from IPython.utils import py3compat
36 from IPython.utils.tempdir import TemporaryDirectory
36 from IPython.utils.tempdir import TemporaryDirectory
37
37
38 # Platform-dependent imports
38 # Platform-dependent imports
39 try:
39 try:
40 import winreg as wreg # Py 3
40 import winreg as wreg # Py 3
41 except ImportError:
41 except ImportError:
42 try:
42 try:
43 import _winreg as wreg # Py 2
43 import _winreg as wreg # Py 2
44 except ImportError:
44 except ImportError:
45 #Fake _winreg module on none windows platforms
45 #Fake _winreg module on none windows platforms
46 import types
46 import types
47 wr_name = "winreg" if py3compat.PY3 else "_winreg"
47 wr_name = "winreg" if py3compat.PY3 else "_winreg"
48 sys.modules[wr_name] = types.ModuleType(wr_name)
48 sys.modules[wr_name] = types.ModuleType(wr_name)
49 try:
49 try:
50 import winreg as wreg
50 import winreg as wreg
51 except ImportError:
51 except ImportError:
52 import _winreg as wreg
52 import _winreg as wreg
53 #Add entries that needs to be stubbed by the testing code
53 #Add entries that needs to be stubbed by the testing code
54 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
54 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
55
55
56 try:
56 try:
57 reload
57 reload
58 except NameError: # Python 3
58 except NameError: # Python 3
59 from imp import reload
59 from imp import reload
60
60
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62 # Globals
62 # Globals
63 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
64 env = os.environ
64 env = os.environ
65 TEST_FILE_PATH = split(abspath(__file__))[0]
65 TEST_FILE_PATH = split(abspath(__file__))[0]
66 TMP_TEST_DIR = tempfile.mkdtemp()
66 TMP_TEST_DIR = tempfile.mkdtemp()
67 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
67 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
68 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
68 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
69 XDG_CACHE_DIR = join(HOME_TEST_DIR, "xdg_cache_dir")
69 XDG_CACHE_DIR = join(HOME_TEST_DIR, "xdg_cache_dir")
70 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
70 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
71 #
71 #
72 # Setup/teardown functions/decorators
72 # Setup/teardown functions/decorators
73 #
73 #
74
74
75 def setup():
75 def setup():
76 """Setup testenvironment for the module:
76 """Setup testenvironment for the module:
77
77
78 - Adds dummy home dir tree
78 - Adds dummy home dir tree
79 """
79 """
80 # Do not mask exceptions here. In particular, catching WindowsError is a
80 # Do not mask exceptions here. In particular, catching WindowsError is a
81 # problem because that exception is only defined on Windows...
81 # problem because that exception is only defined on Windows...
82 os.makedirs(IP_TEST_DIR)
82 os.makedirs(IP_TEST_DIR)
83 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
83 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
84 os.makedirs(os.path.join(XDG_CACHE_DIR, 'ipython'))
84 os.makedirs(os.path.join(XDG_CACHE_DIR, 'ipython'))
85
85
86
86
87 def teardown():
87 def teardown():
88 """Teardown testenvironment for the module:
88 """Teardown testenvironment for the module:
89
89
90 - Remove dummy home dir tree
90 - Remove dummy home dir tree
91 """
91 """
92 # Note: we remove the parent test dir, which is the root of all test
92 # Note: we remove the parent test dir, which is the root of all test
93 # subdirs we may have created. Use shutil instead of os.removedirs, so
93 # subdirs we may have created. Use shutil instead of os.removedirs, so
94 # that non-empty directories are all recursively removed.
94 # that non-empty directories are all recursively removed.
95 shutil.rmtree(TMP_TEST_DIR)
95 shutil.rmtree(TMP_TEST_DIR)
96
96
97
97
98 def setup_environment():
98 def setup_environment():
99 """Setup testenvironment for some functions that are tested
99 """Setup testenvironment for some functions that are tested
100 in this module. In particular this functions stores attributes
100 in this module. In particular this functions stores attributes
101 and other things that we need to stub in some test functions.
101 and other things that we need to stub in some test functions.
102 This needs to be done on a function level and not module level because
102 This needs to be done on a function level and not module level because
103 each testfunction needs a pristine environment.
103 each testfunction needs a pristine environment.
104 """
104 """
105 global oldstuff, platformstuff
105 global oldstuff, platformstuff
106 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
106 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
107
107
108 if os.name == 'nt':
108 if os.name == 'nt':
109 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
109 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
110
110
111
111
112 def teardown_environment():
112 def teardown_environment():
113 """Restore things that were remembered by the setup_environment function
113 """Restore things that were remembered by the setup_environment function
114 """
114 """
115 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
115 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
116 os.chdir(old_wd)
116 os.chdir(old_wd)
117 reload(path)
117 reload(path)
118
118
119 for key in env.keys():
119 for key in list(env):
120 if key not in oldenv:
120 if key not in oldenv:
121 del env[key]
121 del env[key]
122 env.update(oldenv)
122 env.update(oldenv)
123 if hasattr(sys, 'frozen'):
123 if hasattr(sys, 'frozen'):
124 del sys.frozen
124 del sys.frozen
125 if os.name == 'nt':
125 if os.name == 'nt':
126 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
126 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
127
127
128 # Build decorator that uses the setup_environment/setup_environment
128 # Build decorator that uses the setup_environment/setup_environment
129 with_environment = with_setup(setup_environment, teardown_environment)
129 with_environment = with_setup(setup_environment, teardown_environment)
130
130
131 @skip_if_not_win32
131 @skip_if_not_win32
132 @with_environment
132 @with_environment
133 def test_get_home_dir_1():
133 def test_get_home_dir_1():
134 """Testcase for py2exe logic, un-compressed lib
134 """Testcase for py2exe logic, un-compressed lib
135 """
135 """
136 unfrozen = path.get_home_dir()
136 unfrozen = path.get_home_dir()
137 sys.frozen = True
137 sys.frozen = True
138
138
139 #fake filename for IPython.__init__
139 #fake filename for IPython.__init__
140 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
140 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
141
141
142 home_dir = path.get_home_dir()
142 home_dir = path.get_home_dir()
143 nt.assert_equal(home_dir, unfrozen)
143 nt.assert_equal(home_dir, unfrozen)
144
144
145
145
146 @skip_if_not_win32
146 @skip_if_not_win32
147 @with_environment
147 @with_environment
148 def test_get_home_dir_2():
148 def test_get_home_dir_2():
149 """Testcase for py2exe logic, compressed lib
149 """Testcase for py2exe logic, compressed lib
150 """
150 """
151 unfrozen = path.get_home_dir()
151 unfrozen = path.get_home_dir()
152 sys.frozen = True
152 sys.frozen = True
153 #fake filename for IPython.__init__
153 #fake filename for IPython.__init__
154 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
154 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
155
155
156 home_dir = path.get_home_dir(True)
156 home_dir = path.get_home_dir(True)
157 nt.assert_equal(home_dir, unfrozen)
157 nt.assert_equal(home_dir, unfrozen)
158
158
159
159
160 @with_environment
160 @with_environment
161 def test_get_home_dir_3():
161 def test_get_home_dir_3():
162 """get_home_dir() uses $HOME if set"""
162 """get_home_dir() uses $HOME if set"""
163 env["HOME"] = HOME_TEST_DIR
163 env["HOME"] = HOME_TEST_DIR
164 home_dir = path.get_home_dir(True)
164 home_dir = path.get_home_dir(True)
165 # get_home_dir expands symlinks
165 # get_home_dir expands symlinks
166 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
166 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
167
167
168
168
169 @with_environment
169 @with_environment
170 def test_get_home_dir_4():
170 def test_get_home_dir_4():
171 """get_home_dir() still works if $HOME is not set"""
171 """get_home_dir() still works if $HOME is not set"""
172
172
173 if 'HOME' in env: del env['HOME']
173 if 'HOME' in env: del env['HOME']
174 # this should still succeed, but we don't care what the answer is
174 # this should still succeed, but we don't care what the answer is
175 home = path.get_home_dir(False)
175 home = path.get_home_dir(False)
176
176
177 @with_environment
177 @with_environment
178 def test_get_home_dir_5():
178 def test_get_home_dir_5():
179 """raise HomeDirError if $HOME is specified, but not a writable dir"""
179 """raise HomeDirError if $HOME is specified, but not a writable dir"""
180 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
180 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
181 # set os.name = posix, to prevent My Documents fallback on Windows
181 # set os.name = posix, to prevent My Documents fallback on Windows
182 os.name = 'posix'
182 os.name = 'posix'
183 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
183 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
184
184
185
185
186 # Should we stub wreg fully so we can run the test on all platforms?
186 # Should we stub wreg fully so we can run the test on all platforms?
187 @skip_if_not_win32
187 @skip_if_not_win32
188 @with_environment
188 @with_environment
189 def test_get_home_dir_8():
189 def test_get_home_dir_8():
190 """Using registry hack for 'My Documents', os=='nt'
190 """Using registry hack for 'My Documents', os=='nt'
191
191
192 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
192 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
193 """
193 """
194 os.name = 'nt'
194 os.name = 'nt'
195 # Remove from stub environment all keys that may be set
195 # Remove from stub environment all keys that may be set
196 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
196 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
197 env.pop(key, None)
197 env.pop(key, None)
198
198
199 #Stub windows registry functions
199 #Stub windows registry functions
200 def OpenKey(x, y):
200 def OpenKey(x, y):
201 class key:
201 class key:
202 def Close(self):
202 def Close(self):
203 pass
203 pass
204 return key()
204 return key()
205 def QueryValueEx(x, y):
205 def QueryValueEx(x, y):
206 return [abspath(HOME_TEST_DIR)]
206 return [abspath(HOME_TEST_DIR)]
207
207
208 wreg.OpenKey = OpenKey
208 wreg.OpenKey = OpenKey
209 wreg.QueryValueEx = QueryValueEx
209 wreg.QueryValueEx = QueryValueEx
210
210
211 home_dir = path.get_home_dir()
211 home_dir = path.get_home_dir()
212 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
212 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
213
213
214
214
215 @with_environment
215 @with_environment
216 def test_get_ipython_dir_1():
216 def test_get_ipython_dir_1():
217 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
217 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
218 env_ipdir = os.path.join("someplace", ".ipython")
218 env_ipdir = os.path.join("someplace", ".ipython")
219 path._writable_dir = lambda path: True
219 path._writable_dir = lambda path: True
220 env['IPYTHONDIR'] = env_ipdir
220 env['IPYTHONDIR'] = env_ipdir
221 ipdir = path.get_ipython_dir()
221 ipdir = path.get_ipython_dir()
222 nt.assert_equal(ipdir, env_ipdir)
222 nt.assert_equal(ipdir, env_ipdir)
223
223
224
224
225 @with_environment
225 @with_environment
226 def test_get_ipython_dir_2():
226 def test_get_ipython_dir_2():
227 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
227 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
228 path.get_home_dir = lambda : "someplace"
228 path.get_home_dir = lambda : "someplace"
229 path.get_xdg_dir = lambda : None
229 path.get_xdg_dir = lambda : None
230 path._writable_dir = lambda path: True
230 path._writable_dir = lambda path: True
231 os.name = "posix"
231 os.name = "posix"
232 env.pop('IPYTHON_DIR', None)
232 env.pop('IPYTHON_DIR', None)
233 env.pop('IPYTHONDIR', None)
233 env.pop('IPYTHONDIR', None)
234 env.pop('XDG_CONFIG_HOME', None)
234 env.pop('XDG_CONFIG_HOME', None)
235 ipdir = path.get_ipython_dir()
235 ipdir = path.get_ipython_dir()
236 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
236 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
237
237
238 @with_environment
238 @with_environment
239 def test_get_ipython_dir_3():
239 def test_get_ipython_dir_3():
240 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
240 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
241 path.get_home_dir = lambda : "someplace"
241 path.get_home_dir = lambda : "someplace"
242 path._writable_dir = lambda path: True
242 path._writable_dir = lambda path: True
243 os.name = "posix"
243 os.name = "posix"
244 env.pop('IPYTHON_DIR', None)
244 env.pop('IPYTHON_DIR', None)
245 env.pop('IPYTHONDIR', None)
245 env.pop('IPYTHONDIR', None)
246 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
246 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
247 ipdir = path.get_ipython_dir()
247 ipdir = path.get_ipython_dir()
248 if sys.platform == "darwin":
248 if sys.platform == "darwin":
249 expected = os.path.join("someplace", ".ipython")
249 expected = os.path.join("someplace", ".ipython")
250 else:
250 else:
251 expected = os.path.join(XDG_TEST_DIR, "ipython")
251 expected = os.path.join(XDG_TEST_DIR, "ipython")
252 nt.assert_equal(ipdir, expected)
252 nt.assert_equal(ipdir, expected)
253
253
254 @with_environment
254 @with_environment
255 def test_get_ipython_dir_4():
255 def test_get_ipython_dir_4():
256 """test_get_ipython_dir_4, use XDG if both exist."""
256 """test_get_ipython_dir_4, use XDG if both exist."""
257 path.get_home_dir = lambda : HOME_TEST_DIR
257 path.get_home_dir = lambda : HOME_TEST_DIR
258 os.name = "posix"
258 os.name = "posix"
259 env.pop('IPYTHON_DIR', None)
259 env.pop('IPYTHON_DIR', None)
260 env.pop('IPYTHONDIR', None)
260 env.pop('IPYTHONDIR', None)
261 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
261 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
262 ipdir = path.get_ipython_dir()
262 ipdir = path.get_ipython_dir()
263 if sys.platform == "darwin":
263 if sys.platform == "darwin":
264 expected = os.path.join(HOME_TEST_DIR, ".ipython")
264 expected = os.path.join(HOME_TEST_DIR, ".ipython")
265 else:
265 else:
266 expected = os.path.join(XDG_TEST_DIR, "ipython")
266 expected = os.path.join(XDG_TEST_DIR, "ipython")
267 nt.assert_equal(ipdir, expected)
267 nt.assert_equal(ipdir, expected)
268
268
269 @with_environment
269 @with_environment
270 def test_get_ipython_dir_5():
270 def test_get_ipython_dir_5():
271 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
271 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
272 path.get_home_dir = lambda : HOME_TEST_DIR
272 path.get_home_dir = lambda : HOME_TEST_DIR
273 os.name = "posix"
273 os.name = "posix"
274 env.pop('IPYTHON_DIR', None)
274 env.pop('IPYTHON_DIR', None)
275 env.pop('IPYTHONDIR', None)
275 env.pop('IPYTHONDIR', None)
276 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
276 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
277 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
277 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
278 ipdir = path.get_ipython_dir()
278 ipdir = path.get_ipython_dir()
279 nt.assert_equal(ipdir, IP_TEST_DIR)
279 nt.assert_equal(ipdir, IP_TEST_DIR)
280
280
281 @with_environment
281 @with_environment
282 def test_get_ipython_dir_6():
282 def test_get_ipython_dir_6():
283 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
283 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
284 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
284 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
285 os.mkdir(xdg)
285 os.mkdir(xdg)
286 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
286 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
287 path.get_home_dir = lambda : HOME_TEST_DIR
287 path.get_home_dir = lambda : HOME_TEST_DIR
288 path.get_xdg_dir = lambda : xdg
288 path.get_xdg_dir = lambda : xdg
289 os.name = "posix"
289 os.name = "posix"
290 env.pop('IPYTHON_DIR', None)
290 env.pop('IPYTHON_DIR', None)
291 env.pop('IPYTHONDIR', None)
291 env.pop('IPYTHONDIR', None)
292 env.pop('XDG_CONFIG_HOME', None)
292 env.pop('XDG_CONFIG_HOME', None)
293 xdg_ipdir = os.path.join(xdg, "ipython")
293 xdg_ipdir = os.path.join(xdg, "ipython")
294 ipdir = path.get_ipython_dir()
294 ipdir = path.get_ipython_dir()
295 nt.assert_equal(ipdir, xdg_ipdir)
295 nt.assert_equal(ipdir, xdg_ipdir)
296
296
297 @with_environment
297 @with_environment
298 def test_get_ipython_dir_7():
298 def test_get_ipython_dir_7():
299 """test_get_ipython_dir_7, test home directory expansion on IPYTHONDIR"""
299 """test_get_ipython_dir_7, test home directory expansion on IPYTHONDIR"""
300 path._writable_dir = lambda path: True
300 path._writable_dir = lambda path: True
301 home_dir = os.path.normpath(os.path.expanduser('~'))
301 home_dir = os.path.normpath(os.path.expanduser('~'))
302 env['IPYTHONDIR'] = os.path.join('~', 'somewhere')
302 env['IPYTHONDIR'] = os.path.join('~', 'somewhere')
303 ipdir = path.get_ipython_dir()
303 ipdir = path.get_ipython_dir()
304 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
304 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
305
305
306 @skip_win32
306 @skip_win32
307 @with_environment
307 @with_environment
308 def test_get_ipython_dir_8():
308 def test_get_ipython_dir_8():
309 """test_get_ipython_dir_8, test / home directory"""
309 """test_get_ipython_dir_8, test / home directory"""
310 old = path._writable_dir, path.get_xdg_dir
310 old = path._writable_dir, path.get_xdg_dir
311 try:
311 try:
312 path._writable_dir = lambda path: bool(path)
312 path._writable_dir = lambda path: bool(path)
313 path.get_xdg_dir = lambda: None
313 path.get_xdg_dir = lambda: None
314 env.pop('IPYTHON_DIR', None)
314 env.pop('IPYTHON_DIR', None)
315 env.pop('IPYTHONDIR', None)
315 env.pop('IPYTHONDIR', None)
316 env['HOME'] = '/'
316 env['HOME'] = '/'
317 nt.assert_equal(path.get_ipython_dir(), '/.ipython')
317 nt.assert_equal(path.get_ipython_dir(), '/.ipython')
318 finally:
318 finally:
319 path._writable_dir, path.get_xdg_dir = old
319 path._writable_dir, path.get_xdg_dir = old
320
320
321 @with_environment
321 @with_environment
322 def test_get_xdg_dir_0():
322 def test_get_xdg_dir_0():
323 """test_get_xdg_dir_0, check xdg_dir"""
323 """test_get_xdg_dir_0, check xdg_dir"""
324 reload(path)
324 reload(path)
325 path._writable_dir = lambda path: True
325 path._writable_dir = lambda path: True
326 path.get_home_dir = lambda : 'somewhere'
326 path.get_home_dir = lambda : 'somewhere'
327 os.name = "posix"
327 os.name = "posix"
328 sys.platform = "linux2"
328 sys.platform = "linux2"
329 env.pop('IPYTHON_DIR', None)
329 env.pop('IPYTHON_DIR', None)
330 env.pop('IPYTHONDIR', None)
330 env.pop('IPYTHONDIR', None)
331 env.pop('XDG_CONFIG_HOME', None)
331 env.pop('XDG_CONFIG_HOME', None)
332
332
333 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
333 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
334
334
335
335
336 @with_environment
336 @with_environment
337 def test_get_xdg_dir_1():
337 def test_get_xdg_dir_1():
338 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
338 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
339 reload(path)
339 reload(path)
340 path.get_home_dir = lambda : HOME_TEST_DIR
340 path.get_home_dir = lambda : HOME_TEST_DIR
341 os.name = "posix"
341 os.name = "posix"
342 sys.platform = "linux2"
342 sys.platform = "linux2"
343 env.pop('IPYTHON_DIR', None)
343 env.pop('IPYTHON_DIR', None)
344 env.pop('IPYTHONDIR', None)
344 env.pop('IPYTHONDIR', None)
345 env.pop('XDG_CONFIG_HOME', None)
345 env.pop('XDG_CONFIG_HOME', None)
346 nt.assert_equal(path.get_xdg_dir(), None)
346 nt.assert_equal(path.get_xdg_dir(), None)
347
347
348 @with_environment
348 @with_environment
349 def test_get_xdg_dir_2():
349 def test_get_xdg_dir_2():
350 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
350 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
351 reload(path)
351 reload(path)
352 path.get_home_dir = lambda : HOME_TEST_DIR
352 path.get_home_dir = lambda : HOME_TEST_DIR
353 os.name = "posix"
353 os.name = "posix"
354 sys.platform = "linux2"
354 sys.platform = "linux2"
355 env.pop('IPYTHON_DIR', None)
355 env.pop('IPYTHON_DIR', None)
356 env.pop('IPYTHONDIR', None)
356 env.pop('IPYTHONDIR', None)
357 env.pop('XDG_CONFIG_HOME', None)
357 env.pop('XDG_CONFIG_HOME', None)
358 cfgdir=os.path.join(path.get_home_dir(), '.config')
358 cfgdir=os.path.join(path.get_home_dir(), '.config')
359 if not os.path.exists(cfgdir):
359 if not os.path.exists(cfgdir):
360 os.makedirs(cfgdir)
360 os.makedirs(cfgdir)
361
361
362 nt.assert_equal(path.get_xdg_dir(), cfgdir)
362 nt.assert_equal(path.get_xdg_dir(), cfgdir)
363
363
364 @with_environment
364 @with_environment
365 def test_get_xdg_dir_3():
365 def test_get_xdg_dir_3():
366 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
366 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
367 reload(path)
367 reload(path)
368 path.get_home_dir = lambda : HOME_TEST_DIR
368 path.get_home_dir = lambda : HOME_TEST_DIR
369 os.name = "posix"
369 os.name = "posix"
370 sys.platform = "darwin"
370 sys.platform = "darwin"
371 env.pop('IPYTHON_DIR', None)
371 env.pop('IPYTHON_DIR', None)
372 env.pop('IPYTHONDIR', None)
372 env.pop('IPYTHONDIR', None)
373 env.pop('XDG_CONFIG_HOME', None)
373 env.pop('XDG_CONFIG_HOME', None)
374 cfgdir=os.path.join(path.get_home_dir(), '.config')
374 cfgdir=os.path.join(path.get_home_dir(), '.config')
375 if not os.path.exists(cfgdir):
375 if not os.path.exists(cfgdir):
376 os.makedirs(cfgdir)
376 os.makedirs(cfgdir)
377
377
378 nt.assert_equal(path.get_xdg_dir(), None)
378 nt.assert_equal(path.get_xdg_dir(), None)
379
379
380 def test_filefind():
380 def test_filefind():
381 """Various tests for filefind"""
381 """Various tests for filefind"""
382 f = tempfile.NamedTemporaryFile()
382 f = tempfile.NamedTemporaryFile()
383 # print 'fname:',f.name
383 # print 'fname:',f.name
384 alt_dirs = path.get_ipython_dir()
384 alt_dirs = path.get_ipython_dir()
385 t = path.filefind(f.name, alt_dirs)
385 t = path.filefind(f.name, alt_dirs)
386 # print 'found:',t
386 # print 'found:',t
387
387
388 @with_environment
388 @with_environment
389 def test_get_ipython_cache_dir():
389 def test_get_ipython_cache_dir():
390 os.environ["HOME"] = HOME_TEST_DIR
390 os.environ["HOME"] = HOME_TEST_DIR
391 if os.name == 'posix' and sys.platform != 'darwin':
391 if os.name == 'posix' and sys.platform != 'darwin':
392 # test default
392 # test default
393 os.makedirs(os.path.join(HOME_TEST_DIR, ".cache"))
393 os.makedirs(os.path.join(HOME_TEST_DIR, ".cache"))
394 os.environ.pop("XDG_CACHE_HOME", None)
394 os.environ.pop("XDG_CACHE_HOME", None)
395 ipdir = path.get_ipython_cache_dir()
395 ipdir = path.get_ipython_cache_dir()
396 nt.assert_equal(os.path.join(HOME_TEST_DIR, ".cache", "ipython"),
396 nt.assert_equal(os.path.join(HOME_TEST_DIR, ".cache", "ipython"),
397 ipdir)
397 ipdir)
398 nt.assert_true(os.path.isdir(ipdir))
398 nt.assert_true(os.path.isdir(ipdir))
399
399
400 # test env override
400 # test env override
401 os.environ["XDG_CACHE_HOME"] = XDG_CACHE_DIR
401 os.environ["XDG_CACHE_HOME"] = XDG_CACHE_DIR
402 ipdir = path.get_ipython_cache_dir()
402 ipdir = path.get_ipython_cache_dir()
403 nt.assert_true(os.path.isdir(ipdir))
403 nt.assert_true(os.path.isdir(ipdir))
404 nt.assert_equal(ipdir, os.path.join(XDG_CACHE_DIR, "ipython"))
404 nt.assert_equal(ipdir, os.path.join(XDG_CACHE_DIR, "ipython"))
405 else:
405 else:
406 nt.assert_equal(path.get_ipython_cache_dir(),
406 nt.assert_equal(path.get_ipython_cache_dir(),
407 path.get_ipython_dir())
407 path.get_ipython_dir())
408
408
409 def test_get_ipython_package_dir():
409 def test_get_ipython_package_dir():
410 ipdir = path.get_ipython_package_dir()
410 ipdir = path.get_ipython_package_dir()
411 nt.assert_true(os.path.isdir(ipdir))
411 nt.assert_true(os.path.isdir(ipdir))
412
412
413
413
414 def test_get_ipython_module_path():
414 def test_get_ipython_module_path():
415 ipapp_path = path.get_ipython_module_path('IPython.terminal.ipapp')
415 ipapp_path = path.get_ipython_module_path('IPython.terminal.ipapp')
416 nt.assert_true(os.path.isfile(ipapp_path))
416 nt.assert_true(os.path.isfile(ipapp_path))
417
417
418
418
419 @dec.skip_if_not_win32
419 @dec.skip_if_not_win32
420 def test_get_long_path_name_win32():
420 def test_get_long_path_name_win32():
421 with TemporaryDirectory() as tmpdir:
421 with TemporaryDirectory() as tmpdir:
422
422
423 # Make a long path.
423 # Make a long path.
424 long_path = os.path.join(tmpdir, u'this is my long path name')
424 long_path = os.path.join(tmpdir, u'this is my long path name')
425 os.makedirs(long_path)
425 os.makedirs(long_path)
426
426
427 # Test to see if the short path evaluates correctly.
427 # Test to see if the short path evaluates correctly.
428 short_path = os.path.join(tmpdir, u'THISIS~1')
428 short_path = os.path.join(tmpdir, u'THISIS~1')
429 evaluated_path = path.get_long_path_name(short_path)
429 evaluated_path = path.get_long_path_name(short_path)
430 nt.assert_equal(evaluated_path.lower(), long_path.lower())
430 nt.assert_equal(evaluated_path.lower(), long_path.lower())
431
431
432
432
433 @dec.skip_win32
433 @dec.skip_win32
434 def test_get_long_path_name():
434 def test_get_long_path_name():
435 p = path.get_long_path_name('/usr/local')
435 p = path.get_long_path_name('/usr/local')
436 nt.assert_equal(p,'/usr/local')
436 nt.assert_equal(p,'/usr/local')
437
437
438 @dec.skip_win32 # can't create not-user-writable dir on win
438 @dec.skip_win32 # can't create not-user-writable dir on win
439 @with_environment
439 @with_environment
440 def test_not_writable_ipdir():
440 def test_not_writable_ipdir():
441 tmpdir = tempfile.mkdtemp()
441 tmpdir = tempfile.mkdtemp()
442 os.name = "posix"
442 os.name = "posix"
443 env.pop('IPYTHON_DIR', None)
443 env.pop('IPYTHON_DIR', None)
444 env.pop('IPYTHONDIR', None)
444 env.pop('IPYTHONDIR', None)
445 env.pop('XDG_CONFIG_HOME', None)
445 env.pop('XDG_CONFIG_HOME', None)
446 env['HOME'] = tmpdir
446 env['HOME'] = tmpdir
447 ipdir = os.path.join(tmpdir, '.ipython')
447 ipdir = os.path.join(tmpdir, '.ipython')
448 os.mkdir(ipdir)
448 os.mkdir(ipdir)
449 os.chmod(ipdir, 600)
449 os.chmod(ipdir, 600)
450 with AssertPrints('is not a writable location', channel='stderr'):
450 with AssertPrints('is not a writable location', channel='stderr'):
451 ipdir = path.get_ipython_dir()
451 ipdir = path.get_ipython_dir()
452 env.pop('IPYTHON_DIR', None)
452 env.pop('IPYTHON_DIR', None)
453
453
454 def test_unquote_filename():
454 def test_unquote_filename():
455 for win32 in (True, False):
455 for win32 in (True, False):
456 nt.assert_equal(path.unquote_filename('foo.py', win32=win32), 'foo.py')
456 nt.assert_equal(path.unquote_filename('foo.py', win32=win32), 'foo.py')
457 nt.assert_equal(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
457 nt.assert_equal(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
458 nt.assert_equal(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
458 nt.assert_equal(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
459 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
459 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
460 nt.assert_equal(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
460 nt.assert_equal(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
461 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
461 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
462 nt.assert_equal(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
462 nt.assert_equal(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
463 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
463 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
464 nt.assert_equal(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
464 nt.assert_equal(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
465 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
465 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
466
466
467 @with_environment
467 @with_environment
468 def test_get_py_filename():
468 def test_get_py_filename():
469 os.chdir(TMP_TEST_DIR)
469 os.chdir(TMP_TEST_DIR)
470 for win32 in (True, False):
470 for win32 in (True, False):
471 with make_tempfile('foo.py'):
471 with make_tempfile('foo.py'):
472 nt.assert_equal(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
472 nt.assert_equal(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
473 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo.py')
473 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo.py')
474 with make_tempfile('foo'):
474 with make_tempfile('foo'):
475 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo')
475 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo')
476 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
476 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
477 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
477 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
478 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
478 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
479 true_fn = 'foo with spaces.py'
479 true_fn = 'foo with spaces.py'
480 with make_tempfile(true_fn):
480 with make_tempfile(true_fn):
481 nt.assert_equal(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
481 nt.assert_equal(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
482 nt.assert_equal(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
482 nt.assert_equal(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
483 if win32:
483 if win32:
484 nt.assert_equal(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
484 nt.assert_equal(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
485 nt.assert_equal(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
485 nt.assert_equal(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
486 else:
486 else:
487 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
487 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
488 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
488 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
489
489
490 @onlyif_unicode_paths
490 @onlyif_unicode_paths
491 def test_unicode_in_filename():
491 def test_unicode_in_filename():
492 """When a file doesn't exist, the exception raised should be safe to call
492 """When a file doesn't exist, the exception raised should be safe to call
493 str() on - i.e. in Python 2 it must only have ASCII characters.
493 str() on - i.e. in Python 2 it must only have ASCII characters.
494
494
495 https://github.com/ipython/ipython/issues/875
495 https://github.com/ipython/ipython/issues/875
496 """
496 """
497 try:
497 try:
498 # these calls should not throw unicode encode exceptions
498 # these calls should not throw unicode encode exceptions
499 path.get_py_filename(u'fooéè.py', force_win32=False)
499 path.get_py_filename(u'fooéè.py', force_win32=False)
500 except IOError as ex:
500 except IOError as ex:
501 str(ex)
501 str(ex)
502
502
503
503
504 class TestShellGlob(object):
504 class TestShellGlob(object):
505
505
506 @classmethod
506 @classmethod
507 def setUpClass(cls):
507 def setUpClass(cls):
508 cls.filenames_start_with_a = map('a{0}'.format, range(3))
508 cls.filenames_start_with_a = ['a0', 'a1', 'a2']
509 cls.filenames_end_with_b = map('{0}b'.format, range(3))
509 cls.filenames_end_with_b = ['0b', '1b', '2b']
510 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
510 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
511 cls.tempdir = TemporaryDirectory()
511 cls.tempdir = TemporaryDirectory()
512 td = cls.tempdir.name
512 td = cls.tempdir.name
513
513
514 with cls.in_tempdir():
514 with cls.in_tempdir():
515 # Create empty files
515 # Create empty files
516 for fname in cls.filenames:
516 for fname in cls.filenames:
517 open(os.path.join(td, fname), 'w').close()
517 open(os.path.join(td, fname), 'w').close()
518
518
519 @classmethod
519 @classmethod
520 def tearDownClass(cls):
520 def tearDownClass(cls):
521 cls.tempdir.cleanup()
521 cls.tempdir.cleanup()
522
522
523 @classmethod
523 @classmethod
524 @contextmanager
524 @contextmanager
525 def in_tempdir(cls):
525 def in_tempdir(cls):
526 save = os.getcwdu()
526 save = os.getcwdu()
527 try:
527 try:
528 os.chdir(cls.tempdir.name)
528 os.chdir(cls.tempdir.name)
529 yield
529 yield
530 finally:
530 finally:
531 os.chdir(save)
531 os.chdir(save)
532
532
533 def check_match(self, patterns, matches):
533 def check_match(self, patterns, matches):
534 with self.in_tempdir():
534 with self.in_tempdir():
535 # glob returns unordered list. that's why sorted is required.
535 # glob returns unordered list. that's why sorted is required.
536 nt.assert_equals(sorted(path.shellglob(patterns)),
536 nt.assert_equals(sorted(path.shellglob(patterns)),
537 sorted(matches))
537 sorted(matches))
538
538
539 def common_cases(self):
539 def common_cases(self):
540 return [
540 return [
541 (['*'], self.filenames),
541 (['*'], self.filenames),
542 (['a*'], self.filenames_start_with_a),
542 (['a*'], self.filenames_start_with_a),
543 (['*c'], ['*c']),
543 (['*c'], ['*c']),
544 (['*', 'a*', '*b', '*c'], self.filenames
544 (['*', 'a*', '*b', '*c'], self.filenames
545 + self.filenames_start_with_a
545 + self.filenames_start_with_a
546 + self.filenames_end_with_b
546 + self.filenames_end_with_b
547 + ['*c']),
547 + ['*c']),
548 (['a[012]'], self.filenames_start_with_a),
548 (['a[012]'], self.filenames_start_with_a),
549 ]
549 ]
550
550
551 @skip_win32
551 @skip_win32
552 def test_match_posix(self):
552 def test_match_posix(self):
553 for (patterns, matches) in self.common_cases() + [
553 for (patterns, matches) in self.common_cases() + [
554 ([r'\*'], ['*']),
554 ([r'\*'], ['*']),
555 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
555 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
556 ([r'a\[012]'], ['a[012]']),
556 ([r'a\[012]'], ['a[012]']),
557 ]:
557 ]:
558 yield (self.check_match, patterns, matches)
558 yield (self.check_match, patterns, matches)
559
559
560 @skip_if_not_win32
560 @skip_if_not_win32
561 def test_match_windows(self):
561 def test_match_windows(self):
562 for (patterns, matches) in self.common_cases() + [
562 for (patterns, matches) in self.common_cases() + [
563 # In windows, backslash is interpreted as path
563 # In windows, backslash is interpreted as path
564 # separator. Therefore, you can't escape glob
564 # separator. Therefore, you can't escape glob
565 # using it.
565 # using it.
566 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
566 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
567 ([r'a\[012]'], [r'a\[012]']),
567 ([r'a\[012]'], [r'a\[012]']),
568 ]:
568 ]:
569 yield (self.check_match, patterns, matches)
569 yield (self.check_match, patterns, matches)
570
570
571
571
572 def test_unescape_glob():
572 def test_unescape_glob():
573 nt.assert_equals(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
573 nt.assert_equals(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
574 nt.assert_equals(path.unescape_glob(r'\\*'), r'\*')
574 nt.assert_equals(path.unescape_glob(r'\\*'), r'\*')
575 nt.assert_equals(path.unescape_glob(r'\\\*'), r'\*')
575 nt.assert_equals(path.unescape_glob(r'\\\*'), r'\*')
576 nt.assert_equals(path.unescape_glob(r'\\a'), r'\a')
576 nt.assert_equals(path.unescape_glob(r'\\a'), r'\a')
577 nt.assert_equals(path.unescape_glob(r'\a'), r'\a')
577 nt.assert_equals(path.unescape_glob(r'\a'), r'\a')
578
578
579
579
580 class TestLinkOrCopy(object):
580 class TestLinkOrCopy(object):
581 def setUp(self):
581 def setUp(self):
582 self.tempdir = TemporaryDirectory()
582 self.tempdir = TemporaryDirectory()
583 self.src = self.dst("src")
583 self.src = self.dst("src")
584 with open(self.src, "w") as f:
584 with open(self.src, "w") as f:
585 f.write("Hello, world!")
585 f.write("Hello, world!")
586
586
587 def tearDown(self):
587 def tearDown(self):
588 self.tempdir.cleanup()
588 self.tempdir.cleanup()
589
589
590 def dst(self, *args):
590 def dst(self, *args):
591 return os.path.join(self.tempdir.name, *args)
591 return os.path.join(self.tempdir.name, *args)
592
592
593 def assert_inode_not_equal(self, a, b):
593 def assert_inode_not_equal(self, a, b):
594 nt.assert_not_equals(os.stat(a).st_ino, os.stat(b).st_ino,
594 nt.assert_not_equals(os.stat(a).st_ino, os.stat(b).st_ino,
595 "%r and %r do reference the same indoes" %(a, b))
595 "%r and %r do reference the same indoes" %(a, b))
596
596
597 def assert_inode_equal(self, a, b):
597 def assert_inode_equal(self, a, b):
598 nt.assert_equals(os.stat(a).st_ino, os.stat(b).st_ino,
598 nt.assert_equals(os.stat(a).st_ino, os.stat(b).st_ino,
599 "%r and %r do not reference the same indoes" %(a, b))
599 "%r and %r do not reference the same indoes" %(a, b))
600
600
601 def assert_content_equal(self, a, b):
601 def assert_content_equal(self, a, b):
602 with open(a) as a_f:
602 with open(a) as a_f:
603 with open(b) as b_f:
603 with open(b) as b_f:
604 nt.assert_equals(a_f.read(), b_f.read())
604 nt.assert_equals(a_f.read(), b_f.read())
605
605
606 @skip_win32
606 @skip_win32
607 def test_link_successful(self):
607 def test_link_successful(self):
608 dst = self.dst("target")
608 dst = self.dst("target")
609 path.link_or_copy(self.src, dst)
609 path.link_or_copy(self.src, dst)
610 self.assert_inode_equal(self.src, dst)
610 self.assert_inode_equal(self.src, dst)
611
611
612 @skip_win32
612 @skip_win32
613 def test_link_into_dir(self):
613 def test_link_into_dir(self):
614 dst = self.dst("some_dir")
614 dst = self.dst("some_dir")
615 os.mkdir(dst)
615 os.mkdir(dst)
616 path.link_or_copy(self.src, dst)
616 path.link_or_copy(self.src, dst)
617 expected_dst = self.dst("some_dir", os.path.basename(self.src))
617 expected_dst = self.dst("some_dir", os.path.basename(self.src))
618 self.assert_inode_equal(self.src, expected_dst)
618 self.assert_inode_equal(self.src, expected_dst)
619
619
620 @skip_win32
620 @skip_win32
621 def test_target_exists(self):
621 def test_target_exists(self):
622 dst = self.dst("target")
622 dst = self.dst("target")
623 open(dst, "w").close()
623 open(dst, "w").close()
624 path.link_or_copy(self.src, dst)
624 path.link_or_copy(self.src, dst)
625 self.assert_inode_equal(self.src, dst)
625 self.assert_inode_equal(self.src, dst)
626
626
627 @skip_win32
627 @skip_win32
628 def test_no_link(self):
628 def test_no_link(self):
629 real_link = os.link
629 real_link = os.link
630 try:
630 try:
631 del os.link
631 del os.link
632 dst = self.dst("target")
632 dst = self.dst("target")
633 path.link_or_copy(self.src, dst)
633 path.link_or_copy(self.src, dst)
634 self.assert_content_equal(self.src, dst)
634 self.assert_content_equal(self.src, dst)
635 self.assert_inode_not_equal(self.src, dst)
635 self.assert_inode_not_equal(self.src, dst)
636 finally:
636 finally:
637 os.link = real_link
637 os.link = real_link
638
638
639 @skip_if_not_win32
639 @skip_if_not_win32
640 def test_windows(self):
640 def test_windows(self):
641 dst = self.dst("target")
641 dst = self.dst("target")
642 path.link_or_copy(self.src, dst)
642 path.link_or_copy(self.src, dst)
643 self.assert_content_equal(self.src, dst)
643 self.assert_content_equal(self.src, dst)
@@ -1,975 +1,975 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.utils.traitlets.
3 Tests for IPython.utils.traitlets.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 and is licensed under the BSD license. Also, many of the ideas also come
9 and is licensed under the BSD license. Also, many of the ideas also come
10 from enthought.traits even though our implementation is very different.
10 from enthought.traits even though our implementation is very different.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import re
24 import re
25 import sys
25 import sys
26 from unittest import TestCase
26 from unittest import TestCase
27
27
28 import nose.tools as nt
28 import nose.tools as nt
29 from nose import SkipTest
29 from nose import SkipTest
30
30
31 from IPython.utils.traitlets import (
31 from IPython.utils.traitlets import (
32 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
32 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
33 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
33 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
34 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
34 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
35 ObjectName, DottedObjectName, CRegExp
35 ObjectName, DottedObjectName, CRegExp
36 )
36 )
37 from IPython.utils import py3compat
37 from IPython.utils import py3compat
38 from IPython.testing.decorators import skipif
38 from IPython.testing.decorators import skipif
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Helper classes for testing
41 # Helper classes for testing
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44
44
45 class HasTraitsStub(HasTraits):
45 class HasTraitsStub(HasTraits):
46
46
47 def _notify_trait(self, name, old, new):
47 def _notify_trait(self, name, old, new):
48 self._notify_name = name
48 self._notify_name = name
49 self._notify_old = old
49 self._notify_old = old
50 self._notify_new = new
50 self._notify_new = new
51
51
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Test classes
54 # Test classes
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57
57
58 class TestTraitType(TestCase):
58 class TestTraitType(TestCase):
59
59
60 def test_get_undefined(self):
60 def test_get_undefined(self):
61 class A(HasTraits):
61 class A(HasTraits):
62 a = TraitType
62 a = TraitType
63 a = A()
63 a = A()
64 self.assertEqual(a.a, Undefined)
64 self.assertEqual(a.a, Undefined)
65
65
66 def test_set(self):
66 def test_set(self):
67 class A(HasTraitsStub):
67 class A(HasTraitsStub):
68 a = TraitType
68 a = TraitType
69
69
70 a = A()
70 a = A()
71 a.a = 10
71 a.a = 10
72 self.assertEqual(a.a, 10)
72 self.assertEqual(a.a, 10)
73 self.assertEqual(a._notify_name, 'a')
73 self.assertEqual(a._notify_name, 'a')
74 self.assertEqual(a._notify_old, Undefined)
74 self.assertEqual(a._notify_old, Undefined)
75 self.assertEqual(a._notify_new, 10)
75 self.assertEqual(a._notify_new, 10)
76
76
77 def test_validate(self):
77 def test_validate(self):
78 class MyTT(TraitType):
78 class MyTT(TraitType):
79 def validate(self, inst, value):
79 def validate(self, inst, value):
80 return -1
80 return -1
81 class A(HasTraitsStub):
81 class A(HasTraitsStub):
82 tt = MyTT
82 tt = MyTT
83
83
84 a = A()
84 a = A()
85 a.tt = 10
85 a.tt = 10
86 self.assertEqual(a.tt, -1)
86 self.assertEqual(a.tt, -1)
87
87
88 def test_default_validate(self):
88 def test_default_validate(self):
89 class MyIntTT(TraitType):
89 class MyIntTT(TraitType):
90 def validate(self, obj, value):
90 def validate(self, obj, value):
91 if isinstance(value, int):
91 if isinstance(value, int):
92 return value
92 return value
93 self.error(obj, value)
93 self.error(obj, value)
94 class A(HasTraits):
94 class A(HasTraits):
95 tt = MyIntTT(10)
95 tt = MyIntTT(10)
96 a = A()
96 a = A()
97 self.assertEqual(a.tt, 10)
97 self.assertEqual(a.tt, 10)
98
98
99 # Defaults are validated when the HasTraits is instantiated
99 # Defaults are validated when the HasTraits is instantiated
100 class B(HasTraits):
100 class B(HasTraits):
101 tt = MyIntTT('bad default')
101 tt = MyIntTT('bad default')
102 self.assertRaises(TraitError, B)
102 self.assertRaises(TraitError, B)
103
103
104 def test_is_valid_for(self):
104 def test_is_valid_for(self):
105 class MyTT(TraitType):
105 class MyTT(TraitType):
106 def is_valid_for(self, value):
106 def is_valid_for(self, value):
107 return True
107 return True
108 class A(HasTraits):
108 class A(HasTraits):
109 tt = MyTT
109 tt = MyTT
110
110
111 a = A()
111 a = A()
112 a.tt = 10
112 a.tt = 10
113 self.assertEqual(a.tt, 10)
113 self.assertEqual(a.tt, 10)
114
114
115 def test_value_for(self):
115 def test_value_for(self):
116 class MyTT(TraitType):
116 class MyTT(TraitType):
117 def value_for(self, value):
117 def value_for(self, value):
118 return 20
118 return 20
119 class A(HasTraits):
119 class A(HasTraits):
120 tt = MyTT
120 tt = MyTT
121
121
122 a = A()
122 a = A()
123 a.tt = 10
123 a.tt = 10
124 self.assertEqual(a.tt, 20)
124 self.assertEqual(a.tt, 20)
125
125
126 def test_info(self):
126 def test_info(self):
127 class A(HasTraits):
127 class A(HasTraits):
128 tt = TraitType
128 tt = TraitType
129 a = A()
129 a = A()
130 self.assertEqual(A.tt.info(), 'any value')
130 self.assertEqual(A.tt.info(), 'any value')
131
131
132 def test_error(self):
132 def test_error(self):
133 class A(HasTraits):
133 class A(HasTraits):
134 tt = TraitType
134 tt = TraitType
135 a = A()
135 a = A()
136 self.assertRaises(TraitError, A.tt.error, a, 10)
136 self.assertRaises(TraitError, A.tt.error, a, 10)
137
137
138 def test_dynamic_initializer(self):
138 def test_dynamic_initializer(self):
139 class A(HasTraits):
139 class A(HasTraits):
140 x = Int(10)
140 x = Int(10)
141 def _x_default(self):
141 def _x_default(self):
142 return 11
142 return 11
143 class B(A):
143 class B(A):
144 x = Int(20)
144 x = Int(20)
145 class C(A):
145 class C(A):
146 def _x_default(self):
146 def _x_default(self):
147 return 21
147 return 21
148
148
149 a = A()
149 a = A()
150 self.assertEqual(a._trait_values, {})
150 self.assertEqual(a._trait_values, {})
151 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
152 self.assertEqual(a.x, 11)
152 self.assertEqual(a.x, 11)
153 self.assertEqual(a._trait_values, {'x': 11})
153 self.assertEqual(a._trait_values, {'x': 11})
154 b = B()
154 b = B()
155 self.assertEqual(b._trait_values, {'x': 20})
155 self.assertEqual(b._trait_values, {'x': 20})
156 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
157 self.assertEqual(b.x, 20)
157 self.assertEqual(b.x, 20)
158 c = C()
158 c = C()
159 self.assertEqual(c._trait_values, {})
159 self.assertEqual(c._trait_values, {})
160 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
161 self.assertEqual(c.x, 21)
161 self.assertEqual(c.x, 21)
162 self.assertEqual(c._trait_values, {'x': 21})
162 self.assertEqual(c._trait_values, {'x': 21})
163 # Ensure that the base class remains unmolested when the _default
163 # Ensure that the base class remains unmolested when the _default
164 # initializer gets overridden in a subclass.
164 # initializer gets overridden in a subclass.
165 a = A()
165 a = A()
166 c = C()
166 c = C()
167 self.assertEqual(a._trait_values, {})
167 self.assertEqual(a._trait_values, {})
168 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
169 self.assertEqual(a.x, 11)
169 self.assertEqual(a.x, 11)
170 self.assertEqual(a._trait_values, {'x': 11})
170 self.assertEqual(a._trait_values, {'x': 11})
171
171
172
172
173
173
174 class TestHasTraitsMeta(TestCase):
174 class TestHasTraitsMeta(TestCase):
175
175
176 def test_metaclass(self):
176 def test_metaclass(self):
177 self.assertEqual(type(HasTraits), MetaHasTraits)
177 self.assertEqual(type(HasTraits), MetaHasTraits)
178
178
179 class A(HasTraits):
179 class A(HasTraits):
180 a = Int
180 a = Int
181
181
182 a = A()
182 a = A()
183 self.assertEqual(type(a.__class__), MetaHasTraits)
183 self.assertEqual(type(a.__class__), MetaHasTraits)
184 self.assertEqual(a.a,0)
184 self.assertEqual(a.a,0)
185 a.a = 10
185 a.a = 10
186 self.assertEqual(a.a,10)
186 self.assertEqual(a.a,10)
187
187
188 class B(HasTraits):
188 class B(HasTraits):
189 b = Int()
189 b = Int()
190
190
191 b = B()
191 b = B()
192 self.assertEqual(b.b,0)
192 self.assertEqual(b.b,0)
193 b.b = 10
193 b.b = 10
194 self.assertEqual(b.b,10)
194 self.assertEqual(b.b,10)
195
195
196 class C(HasTraits):
196 class C(HasTraits):
197 c = Int(30)
197 c = Int(30)
198
198
199 c = C()
199 c = C()
200 self.assertEqual(c.c,30)
200 self.assertEqual(c.c,30)
201 c.c = 10
201 c.c = 10
202 self.assertEqual(c.c,10)
202 self.assertEqual(c.c,10)
203
203
204 def test_this_class(self):
204 def test_this_class(self):
205 class A(HasTraits):
205 class A(HasTraits):
206 t = This()
206 t = This()
207 tt = This()
207 tt = This()
208 class B(A):
208 class B(A):
209 tt = This()
209 tt = This()
210 ttt = This()
210 ttt = This()
211 self.assertEqual(A.t.this_class, A)
211 self.assertEqual(A.t.this_class, A)
212 self.assertEqual(B.t.this_class, A)
212 self.assertEqual(B.t.this_class, A)
213 self.assertEqual(B.tt.this_class, B)
213 self.assertEqual(B.tt.this_class, B)
214 self.assertEqual(B.ttt.this_class, B)
214 self.assertEqual(B.ttt.this_class, B)
215
215
216 class TestHasTraitsNotify(TestCase):
216 class TestHasTraitsNotify(TestCase):
217
217
218 def setUp(self):
218 def setUp(self):
219 self._notify1 = []
219 self._notify1 = []
220 self._notify2 = []
220 self._notify2 = []
221
221
222 def notify1(self, name, old, new):
222 def notify1(self, name, old, new):
223 self._notify1.append((name, old, new))
223 self._notify1.append((name, old, new))
224
224
225 def notify2(self, name, old, new):
225 def notify2(self, name, old, new):
226 self._notify2.append((name, old, new))
226 self._notify2.append((name, old, new))
227
227
228 def test_notify_all(self):
228 def test_notify_all(self):
229
229
230 class A(HasTraits):
230 class A(HasTraits):
231 a = Int
231 a = Int
232 b = Float
232 b = Float
233
233
234 a = A()
234 a = A()
235 a.on_trait_change(self.notify1)
235 a.on_trait_change(self.notify1)
236 a.a = 0
236 a.a = 0
237 self.assertEqual(len(self._notify1),0)
237 self.assertEqual(len(self._notify1),0)
238 a.b = 0.0
238 a.b = 0.0
239 self.assertEqual(len(self._notify1),0)
239 self.assertEqual(len(self._notify1),0)
240 a.a = 10
240 a.a = 10
241 self.assertTrue(('a',0,10) in self._notify1)
241 self.assertTrue(('a',0,10) in self._notify1)
242 a.b = 10.0
242 a.b = 10.0
243 self.assertTrue(('b',0.0,10.0) in self._notify1)
243 self.assertTrue(('b',0.0,10.0) in self._notify1)
244 self.assertRaises(TraitError,setattr,a,'a','bad string')
244 self.assertRaises(TraitError,setattr,a,'a','bad string')
245 self.assertRaises(TraitError,setattr,a,'b','bad string')
245 self.assertRaises(TraitError,setattr,a,'b','bad string')
246 self._notify1 = []
246 self._notify1 = []
247 a.on_trait_change(self.notify1,remove=True)
247 a.on_trait_change(self.notify1,remove=True)
248 a.a = 20
248 a.a = 20
249 a.b = 20.0
249 a.b = 20.0
250 self.assertEqual(len(self._notify1),0)
250 self.assertEqual(len(self._notify1),0)
251
251
252 def test_notify_one(self):
252 def test_notify_one(self):
253
253
254 class A(HasTraits):
254 class A(HasTraits):
255 a = Int
255 a = Int
256 b = Float
256 b = Float
257
257
258 a = A()
258 a = A()
259 a.on_trait_change(self.notify1, 'a')
259 a.on_trait_change(self.notify1, 'a')
260 a.a = 0
260 a.a = 0
261 self.assertEqual(len(self._notify1),0)
261 self.assertEqual(len(self._notify1),0)
262 a.a = 10
262 a.a = 10
263 self.assertTrue(('a',0,10) in self._notify1)
263 self.assertTrue(('a',0,10) in self._notify1)
264 self.assertRaises(TraitError,setattr,a,'a','bad string')
264 self.assertRaises(TraitError,setattr,a,'a','bad string')
265
265
266 def test_subclass(self):
266 def test_subclass(self):
267
267
268 class A(HasTraits):
268 class A(HasTraits):
269 a = Int
269 a = Int
270
270
271 class B(A):
271 class B(A):
272 b = Float
272 b = Float
273
273
274 b = B()
274 b = B()
275 self.assertEqual(b.a,0)
275 self.assertEqual(b.a,0)
276 self.assertEqual(b.b,0.0)
276 self.assertEqual(b.b,0.0)
277 b.a = 100
277 b.a = 100
278 b.b = 100.0
278 b.b = 100.0
279 self.assertEqual(b.a,100)
279 self.assertEqual(b.a,100)
280 self.assertEqual(b.b,100.0)
280 self.assertEqual(b.b,100.0)
281
281
282 def test_notify_subclass(self):
282 def test_notify_subclass(self):
283
283
284 class A(HasTraits):
284 class A(HasTraits):
285 a = Int
285 a = Int
286
286
287 class B(A):
287 class B(A):
288 b = Float
288 b = Float
289
289
290 b = B()
290 b = B()
291 b.on_trait_change(self.notify1, 'a')
291 b.on_trait_change(self.notify1, 'a')
292 b.on_trait_change(self.notify2, 'b')
292 b.on_trait_change(self.notify2, 'b')
293 b.a = 0
293 b.a = 0
294 b.b = 0.0
294 b.b = 0.0
295 self.assertEqual(len(self._notify1),0)
295 self.assertEqual(len(self._notify1),0)
296 self.assertEqual(len(self._notify2),0)
296 self.assertEqual(len(self._notify2),0)
297 b.a = 10
297 b.a = 10
298 b.b = 10.0
298 b.b = 10.0
299 self.assertTrue(('a',0,10) in self._notify1)
299 self.assertTrue(('a',0,10) in self._notify1)
300 self.assertTrue(('b',0.0,10.0) in self._notify2)
300 self.assertTrue(('b',0.0,10.0) in self._notify2)
301
301
302 def test_static_notify(self):
302 def test_static_notify(self):
303
303
304 class A(HasTraits):
304 class A(HasTraits):
305 a = Int
305 a = Int
306 _notify1 = []
306 _notify1 = []
307 def _a_changed(self, name, old, new):
307 def _a_changed(self, name, old, new):
308 self._notify1.append((name, old, new))
308 self._notify1.append((name, old, new))
309
309
310 a = A()
310 a = A()
311 a.a = 0
311 a.a = 0
312 # This is broken!!!
312 # This is broken!!!
313 self.assertEqual(len(a._notify1),0)
313 self.assertEqual(len(a._notify1),0)
314 a.a = 10
314 a.a = 10
315 self.assertTrue(('a',0,10) in a._notify1)
315 self.assertTrue(('a',0,10) in a._notify1)
316
316
317 class B(A):
317 class B(A):
318 b = Float
318 b = Float
319 _notify2 = []
319 _notify2 = []
320 def _b_changed(self, name, old, new):
320 def _b_changed(self, name, old, new):
321 self._notify2.append((name, old, new))
321 self._notify2.append((name, old, new))
322
322
323 b = B()
323 b = B()
324 b.a = 10
324 b.a = 10
325 b.b = 10.0
325 b.b = 10.0
326 self.assertTrue(('a',0,10) in b._notify1)
326 self.assertTrue(('a',0,10) in b._notify1)
327 self.assertTrue(('b',0.0,10.0) in b._notify2)
327 self.assertTrue(('b',0.0,10.0) in b._notify2)
328
328
329 def test_notify_args(self):
329 def test_notify_args(self):
330
330
331 def callback0():
331 def callback0():
332 self.cb = ()
332 self.cb = ()
333 def callback1(name):
333 def callback1(name):
334 self.cb = (name,)
334 self.cb = (name,)
335 def callback2(name, new):
335 def callback2(name, new):
336 self.cb = (name, new)
336 self.cb = (name, new)
337 def callback3(name, old, new):
337 def callback3(name, old, new):
338 self.cb = (name, old, new)
338 self.cb = (name, old, new)
339
339
340 class A(HasTraits):
340 class A(HasTraits):
341 a = Int
341 a = Int
342
342
343 a = A()
343 a = A()
344 a.on_trait_change(callback0, 'a')
344 a.on_trait_change(callback0, 'a')
345 a.a = 10
345 a.a = 10
346 self.assertEqual(self.cb,())
346 self.assertEqual(self.cb,())
347 a.on_trait_change(callback0, 'a', remove=True)
347 a.on_trait_change(callback0, 'a', remove=True)
348
348
349 a.on_trait_change(callback1, 'a')
349 a.on_trait_change(callback1, 'a')
350 a.a = 100
350 a.a = 100
351 self.assertEqual(self.cb,('a',))
351 self.assertEqual(self.cb,('a',))
352 a.on_trait_change(callback1, 'a', remove=True)
352 a.on_trait_change(callback1, 'a', remove=True)
353
353
354 a.on_trait_change(callback2, 'a')
354 a.on_trait_change(callback2, 'a')
355 a.a = 1000
355 a.a = 1000
356 self.assertEqual(self.cb,('a',1000))
356 self.assertEqual(self.cb,('a',1000))
357 a.on_trait_change(callback2, 'a', remove=True)
357 a.on_trait_change(callback2, 'a', remove=True)
358
358
359 a.on_trait_change(callback3, 'a')
359 a.on_trait_change(callback3, 'a')
360 a.a = 10000
360 a.a = 10000
361 self.assertEqual(self.cb,('a',1000,10000))
361 self.assertEqual(self.cb,('a',1000,10000))
362 a.on_trait_change(callback3, 'a', remove=True)
362 a.on_trait_change(callback3, 'a', remove=True)
363
363
364 self.assertEqual(len(a._trait_notifiers['a']),0)
364 self.assertEqual(len(a._trait_notifiers['a']),0)
365
365
366 def test_notify_only_once(self):
366 def test_notify_only_once(self):
367
367
368 class A(HasTraits):
368 class A(HasTraits):
369 listen_to = ['a']
369 listen_to = ['a']
370
370
371 a = Int(0)
371 a = Int(0)
372 b = 0
372 b = 0
373
373
374 def __init__(self, **kwargs):
374 def __init__(self, **kwargs):
375 super(A, self).__init__(**kwargs)
375 super(A, self).__init__(**kwargs)
376 self.on_trait_change(self.listener1, ['a'])
376 self.on_trait_change(self.listener1, ['a'])
377
377
378 def listener1(self, name, old, new):
378 def listener1(self, name, old, new):
379 self.b += 1
379 self.b += 1
380
380
381 class B(A):
381 class B(A):
382
382
383 c = 0
383 c = 0
384 d = 0
384 d = 0
385
385
386 def __init__(self, **kwargs):
386 def __init__(self, **kwargs):
387 super(B, self).__init__(**kwargs)
387 super(B, self).__init__(**kwargs)
388 self.on_trait_change(self.listener2)
388 self.on_trait_change(self.listener2)
389
389
390 def listener2(self, name, old, new):
390 def listener2(self, name, old, new):
391 self.c += 1
391 self.c += 1
392
392
393 def _a_changed(self, name, old, new):
393 def _a_changed(self, name, old, new):
394 self.d += 1
394 self.d += 1
395
395
396 b = B()
396 b = B()
397 b.a += 1
397 b.a += 1
398 self.assertEqual(b.b, b.c)
398 self.assertEqual(b.b, b.c)
399 self.assertEqual(b.b, b.d)
399 self.assertEqual(b.b, b.d)
400 b.a += 1
400 b.a += 1
401 self.assertEqual(b.b, b.c)
401 self.assertEqual(b.b, b.c)
402 self.assertEqual(b.b, b.d)
402 self.assertEqual(b.b, b.d)
403
403
404
404
405 class TestHasTraits(TestCase):
405 class TestHasTraits(TestCase):
406
406
407 def test_trait_names(self):
407 def test_trait_names(self):
408 class A(HasTraits):
408 class A(HasTraits):
409 i = Int
409 i = Int
410 f = Float
410 f = Float
411 a = A()
411 a = A()
412 self.assertEqual(sorted(a.trait_names()),['f','i'])
412 self.assertEqual(sorted(a.trait_names()),['f','i'])
413 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
413 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
414
414
415 def test_trait_metadata(self):
415 def test_trait_metadata(self):
416 class A(HasTraits):
416 class A(HasTraits):
417 i = Int(config_key='MY_VALUE')
417 i = Int(config_key='MY_VALUE')
418 a = A()
418 a = A()
419 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
419 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
420
420
421 def test_traits(self):
421 def test_traits(self):
422 class A(HasTraits):
422 class A(HasTraits):
423 i = Int
423 i = Int
424 f = Float
424 f = Float
425 a = A()
425 a = A()
426 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
426 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
427 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
427 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
428
428
429 def test_traits_metadata(self):
429 def test_traits_metadata(self):
430 class A(HasTraits):
430 class A(HasTraits):
431 i = Int(config_key='VALUE1', other_thing='VALUE2')
431 i = Int(config_key='VALUE1', other_thing='VALUE2')
432 f = Float(config_key='VALUE3', other_thing='VALUE2')
432 f = Float(config_key='VALUE3', other_thing='VALUE2')
433 j = Int(0)
433 j = Int(0)
434 a = A()
434 a = A()
435 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
435 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
436 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
436 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
437 self.assertEqual(traits, dict(i=A.i))
437 self.assertEqual(traits, dict(i=A.i))
438
438
439 # This passes, but it shouldn't because I am replicating a bug in
439 # This passes, but it shouldn't because I am replicating a bug in
440 # traits.
440 # traits.
441 traits = a.traits(config_key=lambda v: True)
441 traits = a.traits(config_key=lambda v: True)
442 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
442 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
443
443
444 def test_init(self):
444 def test_init(self):
445 class A(HasTraits):
445 class A(HasTraits):
446 i = Int()
446 i = Int()
447 x = Float()
447 x = Float()
448 a = A(i=1, x=10.0)
448 a = A(i=1, x=10.0)
449 self.assertEqual(a.i, 1)
449 self.assertEqual(a.i, 1)
450 self.assertEqual(a.x, 10.0)
450 self.assertEqual(a.x, 10.0)
451
451
452 def test_positional_args(self):
452 def test_positional_args(self):
453 class A(HasTraits):
453 class A(HasTraits):
454 i = Int(0)
454 i = Int(0)
455 def __init__(self, i):
455 def __init__(self, i):
456 super(A, self).__init__()
456 super(A, self).__init__()
457 self.i = i
457 self.i = i
458
458
459 a = A(5)
459 a = A(5)
460 self.assertEqual(a.i, 5)
460 self.assertEqual(a.i, 5)
461 # should raise TypeError if no positional arg given
461 # should raise TypeError if no positional arg given
462 self.assertRaises(TypeError, A)
462 self.assertRaises(TypeError, A)
463
463
464 #-----------------------------------------------------------------------------
464 #-----------------------------------------------------------------------------
465 # Tests for specific trait types
465 # Tests for specific trait types
466 #-----------------------------------------------------------------------------
466 #-----------------------------------------------------------------------------
467
467
468
468
469 class TestType(TestCase):
469 class TestType(TestCase):
470
470
471 def test_default(self):
471 def test_default(self):
472
472
473 class B(object): pass
473 class B(object): pass
474 class A(HasTraits):
474 class A(HasTraits):
475 klass = Type
475 klass = Type
476
476
477 a = A()
477 a = A()
478 self.assertEqual(a.klass, None)
478 self.assertEqual(a.klass, None)
479
479
480 a.klass = B
480 a.klass = B
481 self.assertEqual(a.klass, B)
481 self.assertEqual(a.klass, B)
482 self.assertRaises(TraitError, setattr, a, 'klass', 10)
482 self.assertRaises(TraitError, setattr, a, 'klass', 10)
483
483
484 def test_value(self):
484 def test_value(self):
485
485
486 class B(object): pass
486 class B(object): pass
487 class C(object): pass
487 class C(object): pass
488 class A(HasTraits):
488 class A(HasTraits):
489 klass = Type(B)
489 klass = Type(B)
490
490
491 a = A()
491 a = A()
492 self.assertEqual(a.klass, B)
492 self.assertEqual(a.klass, B)
493 self.assertRaises(TraitError, setattr, a, 'klass', C)
493 self.assertRaises(TraitError, setattr, a, 'klass', C)
494 self.assertRaises(TraitError, setattr, a, 'klass', object)
494 self.assertRaises(TraitError, setattr, a, 'klass', object)
495 a.klass = B
495 a.klass = B
496
496
497 def test_allow_none(self):
497 def test_allow_none(self):
498
498
499 class B(object): pass
499 class B(object): pass
500 class C(B): pass
500 class C(B): pass
501 class A(HasTraits):
501 class A(HasTraits):
502 klass = Type(B, allow_none=False)
502 klass = Type(B, allow_none=False)
503
503
504 a = A()
504 a = A()
505 self.assertEqual(a.klass, B)
505 self.assertEqual(a.klass, B)
506 self.assertRaises(TraitError, setattr, a, 'klass', None)
506 self.assertRaises(TraitError, setattr, a, 'klass', None)
507 a.klass = C
507 a.klass = C
508 self.assertEqual(a.klass, C)
508 self.assertEqual(a.klass, C)
509
509
510 def test_validate_klass(self):
510 def test_validate_klass(self):
511
511
512 class A(HasTraits):
512 class A(HasTraits):
513 klass = Type('no strings allowed')
513 klass = Type('no strings allowed')
514
514
515 self.assertRaises(ImportError, A)
515 self.assertRaises(ImportError, A)
516
516
517 class A(HasTraits):
517 class A(HasTraits):
518 klass = Type('rub.adub.Duck')
518 klass = Type('rub.adub.Duck')
519
519
520 self.assertRaises(ImportError, A)
520 self.assertRaises(ImportError, A)
521
521
522 def test_validate_default(self):
522 def test_validate_default(self):
523
523
524 class B(object): pass
524 class B(object): pass
525 class A(HasTraits):
525 class A(HasTraits):
526 klass = Type('bad default', B)
526 klass = Type('bad default', B)
527
527
528 self.assertRaises(ImportError, A)
528 self.assertRaises(ImportError, A)
529
529
530 class C(HasTraits):
530 class C(HasTraits):
531 klass = Type(None, B, allow_none=False)
531 klass = Type(None, B, allow_none=False)
532
532
533 self.assertRaises(TraitError, C)
533 self.assertRaises(TraitError, C)
534
534
535 def test_str_klass(self):
535 def test_str_klass(self):
536
536
537 class A(HasTraits):
537 class A(HasTraits):
538 klass = Type('IPython.utils.ipstruct.Struct')
538 klass = Type('IPython.utils.ipstruct.Struct')
539
539
540 from IPython.utils.ipstruct import Struct
540 from IPython.utils.ipstruct import Struct
541 a = A()
541 a = A()
542 a.klass = Struct
542 a.klass = Struct
543 self.assertEqual(a.klass, Struct)
543 self.assertEqual(a.klass, Struct)
544
544
545 self.assertRaises(TraitError, setattr, a, 'klass', 10)
545 self.assertRaises(TraitError, setattr, a, 'klass', 10)
546
546
547 class TestInstance(TestCase):
547 class TestInstance(TestCase):
548
548
549 def test_basic(self):
549 def test_basic(self):
550 class Foo(object): pass
550 class Foo(object): pass
551 class Bar(Foo): pass
551 class Bar(Foo): pass
552 class Bah(object): pass
552 class Bah(object): pass
553
553
554 class A(HasTraits):
554 class A(HasTraits):
555 inst = Instance(Foo)
555 inst = Instance(Foo)
556
556
557 a = A()
557 a = A()
558 self.assertTrue(a.inst is None)
558 self.assertTrue(a.inst is None)
559 a.inst = Foo()
559 a.inst = Foo()
560 self.assertTrue(isinstance(a.inst, Foo))
560 self.assertTrue(isinstance(a.inst, Foo))
561 a.inst = Bar()
561 a.inst = Bar()
562 self.assertTrue(isinstance(a.inst, Foo))
562 self.assertTrue(isinstance(a.inst, Foo))
563 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
563 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
564 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
564 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
565 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
565 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
566
566
567 def test_unique_default_value(self):
567 def test_unique_default_value(self):
568 class Foo(object): pass
568 class Foo(object): pass
569 class A(HasTraits):
569 class A(HasTraits):
570 inst = Instance(Foo,(),{})
570 inst = Instance(Foo,(),{})
571
571
572 a = A()
572 a = A()
573 b = A()
573 b = A()
574 self.assertTrue(a.inst is not b.inst)
574 self.assertTrue(a.inst is not b.inst)
575
575
576 def test_args_kw(self):
576 def test_args_kw(self):
577 class Foo(object):
577 class Foo(object):
578 def __init__(self, c): self.c = c
578 def __init__(self, c): self.c = c
579 class Bar(object): pass
579 class Bar(object): pass
580 class Bah(object):
580 class Bah(object):
581 def __init__(self, c, d):
581 def __init__(self, c, d):
582 self.c = c; self.d = d
582 self.c = c; self.d = d
583
583
584 class A(HasTraits):
584 class A(HasTraits):
585 inst = Instance(Foo, (10,))
585 inst = Instance(Foo, (10,))
586 a = A()
586 a = A()
587 self.assertEqual(a.inst.c, 10)
587 self.assertEqual(a.inst.c, 10)
588
588
589 class B(HasTraits):
589 class B(HasTraits):
590 inst = Instance(Bah, args=(10,), kw=dict(d=20))
590 inst = Instance(Bah, args=(10,), kw=dict(d=20))
591 b = B()
591 b = B()
592 self.assertEqual(b.inst.c, 10)
592 self.assertEqual(b.inst.c, 10)
593 self.assertEqual(b.inst.d, 20)
593 self.assertEqual(b.inst.d, 20)
594
594
595 class C(HasTraits):
595 class C(HasTraits):
596 inst = Instance(Foo)
596 inst = Instance(Foo)
597 c = C()
597 c = C()
598 self.assertTrue(c.inst is None)
598 self.assertTrue(c.inst is None)
599
599
600 def test_bad_default(self):
600 def test_bad_default(self):
601 class Foo(object): pass
601 class Foo(object): pass
602
602
603 class A(HasTraits):
603 class A(HasTraits):
604 inst = Instance(Foo, allow_none=False)
604 inst = Instance(Foo, allow_none=False)
605
605
606 self.assertRaises(TraitError, A)
606 self.assertRaises(TraitError, A)
607
607
608 def test_instance(self):
608 def test_instance(self):
609 class Foo(object): pass
609 class Foo(object): pass
610
610
611 def inner():
611 def inner():
612 class A(HasTraits):
612 class A(HasTraits):
613 inst = Instance(Foo())
613 inst = Instance(Foo())
614
614
615 self.assertRaises(TraitError, inner)
615 self.assertRaises(TraitError, inner)
616
616
617
617
618 class TestThis(TestCase):
618 class TestThis(TestCase):
619
619
620 def test_this_class(self):
620 def test_this_class(self):
621 class Foo(HasTraits):
621 class Foo(HasTraits):
622 this = This
622 this = This
623
623
624 f = Foo()
624 f = Foo()
625 self.assertEqual(f.this, None)
625 self.assertEqual(f.this, None)
626 g = Foo()
626 g = Foo()
627 f.this = g
627 f.this = g
628 self.assertEqual(f.this, g)
628 self.assertEqual(f.this, g)
629 self.assertRaises(TraitError, setattr, f, 'this', 10)
629 self.assertRaises(TraitError, setattr, f, 'this', 10)
630
630
631 def test_this_inst(self):
631 def test_this_inst(self):
632 class Foo(HasTraits):
632 class Foo(HasTraits):
633 this = This()
633 this = This()
634
634
635 f = Foo()
635 f = Foo()
636 f.this = Foo()
636 f.this = Foo()
637 self.assertTrue(isinstance(f.this, Foo))
637 self.assertTrue(isinstance(f.this, Foo))
638
638
639 def test_subclass(self):
639 def test_subclass(self):
640 class Foo(HasTraits):
640 class Foo(HasTraits):
641 t = This()
641 t = This()
642 class Bar(Foo):
642 class Bar(Foo):
643 pass
643 pass
644 f = Foo()
644 f = Foo()
645 b = Bar()
645 b = Bar()
646 f.t = b
646 f.t = b
647 b.t = f
647 b.t = f
648 self.assertEqual(f.t, b)
648 self.assertEqual(f.t, b)
649 self.assertEqual(b.t, f)
649 self.assertEqual(b.t, f)
650
650
651 def test_subclass_override(self):
651 def test_subclass_override(self):
652 class Foo(HasTraits):
652 class Foo(HasTraits):
653 t = This()
653 t = This()
654 class Bar(Foo):
654 class Bar(Foo):
655 t = This()
655 t = This()
656 f = Foo()
656 f = Foo()
657 b = Bar()
657 b = Bar()
658 f.t = b
658 f.t = b
659 self.assertEqual(f.t, b)
659 self.assertEqual(f.t, b)
660 self.assertRaises(TraitError, setattr, b, 't', f)
660 self.assertRaises(TraitError, setattr, b, 't', f)
661
661
662 class TraitTestBase(TestCase):
662 class TraitTestBase(TestCase):
663 """A best testing class for basic trait types."""
663 """A best testing class for basic trait types."""
664
664
665 def assign(self, value):
665 def assign(self, value):
666 self.obj.value = value
666 self.obj.value = value
667
667
668 def coerce(self, value):
668 def coerce(self, value):
669 return value
669 return value
670
670
671 def test_good_values(self):
671 def test_good_values(self):
672 if hasattr(self, '_good_values'):
672 if hasattr(self, '_good_values'):
673 for value in self._good_values:
673 for value in self._good_values:
674 self.assign(value)
674 self.assign(value)
675 self.assertEqual(self.obj.value, self.coerce(value))
675 self.assertEqual(self.obj.value, self.coerce(value))
676
676
677 def test_bad_values(self):
677 def test_bad_values(self):
678 if hasattr(self, '_bad_values'):
678 if hasattr(self, '_bad_values'):
679 for value in self._bad_values:
679 for value in self._bad_values:
680 try:
680 try:
681 self.assertRaises(TraitError, self.assign, value)
681 self.assertRaises(TraitError, self.assign, value)
682 except AssertionError:
682 except AssertionError:
683 assert False, value
683 assert False, value
684
684
685 def test_default_value(self):
685 def test_default_value(self):
686 if hasattr(self, '_default_value'):
686 if hasattr(self, '_default_value'):
687 self.assertEqual(self._default_value, self.obj.value)
687 self.assertEqual(self._default_value, self.obj.value)
688
688
689 def tearDown(self):
689 def tearDown(self):
690 # restore default value after tests, if set
690 # restore default value after tests, if set
691 if hasattr(self, '_default_value'):
691 if hasattr(self, '_default_value'):
692 self.obj.value = self._default_value
692 self.obj.value = self._default_value
693
693
694
694
695 class AnyTrait(HasTraits):
695 class AnyTrait(HasTraits):
696
696
697 value = Any
697 value = Any
698
698
699 class AnyTraitTest(TraitTestBase):
699 class AnyTraitTest(TraitTestBase):
700
700
701 obj = AnyTrait()
701 obj = AnyTrait()
702
702
703 _default_value = None
703 _default_value = None
704 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
704 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
705 _bad_values = []
705 _bad_values = []
706
706
707
707
708 class IntTrait(HasTraits):
708 class IntTrait(HasTraits):
709
709
710 value = Int(99)
710 value = Int(99)
711
711
712 class TestInt(TraitTestBase):
712 class TestInt(TraitTestBase):
713
713
714 obj = IntTrait()
714 obj = IntTrait()
715 _default_value = 99
715 _default_value = 99
716 _good_values = [10, -10]
716 _good_values = [10, -10]
717 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
717 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
718 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
718 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
719 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
719 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
720 if not py3compat.PY3:
720 if not py3compat.PY3:
721 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
721 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
722
722
723
723
724 class LongTrait(HasTraits):
724 class LongTrait(HasTraits):
725
725
726 value = Long(99 if py3compat.PY3 else long(99))
726 value = Long(99 if py3compat.PY3 else long(99))
727
727
728 class TestLong(TraitTestBase):
728 class TestLong(TraitTestBase):
729
729
730 obj = LongTrait()
730 obj = LongTrait()
731
731
732 _default_value = 99 if py3compat.PY3 else long(99)
732 _default_value = 99 if py3compat.PY3 else long(99)
733 _good_values = [10, -10]
733 _good_values = [10, -10]
734 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
734 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
735 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
735 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
736 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
736 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
737 u'-10.1']
737 u'-10.1']
738 if not py3compat.PY3:
738 if not py3compat.PY3:
739 # maxint undefined on py3, because int == long
739 # maxint undefined on py3, because int == long
740 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
740 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
741 _bad_values.extend([[long(10)], (long(10),)])
741 _bad_values.extend([[long(10)], (long(10),)])
742
742
743 @skipif(py3compat.PY3, "not relevant on py3")
743 @skipif(py3compat.PY3, "not relevant on py3")
744 def test_cast_small(self):
744 def test_cast_small(self):
745 """Long casts ints to long"""
745 """Long casts ints to long"""
746 self.obj.value = 10
746 self.obj.value = 10
747 self.assertEqual(type(self.obj.value), long)
747 self.assertEqual(type(self.obj.value), long)
748
748
749
749
750 class IntegerTrait(HasTraits):
750 class IntegerTrait(HasTraits):
751 value = Integer(1)
751 value = Integer(1)
752
752
753 class TestInteger(TestLong):
753 class TestInteger(TestLong):
754 obj = IntegerTrait()
754 obj = IntegerTrait()
755 _default_value = 1
755 _default_value = 1
756
756
757 def coerce(self, n):
757 def coerce(self, n):
758 return int(n)
758 return int(n)
759
759
760 @skipif(py3compat.PY3, "not relevant on py3")
760 @skipif(py3compat.PY3, "not relevant on py3")
761 def test_cast_small(self):
761 def test_cast_small(self):
762 """Integer casts small longs to int"""
762 """Integer casts small longs to int"""
763 if py3compat.PY3:
763 if py3compat.PY3:
764 raise SkipTest("not relevant on py3")
764 raise SkipTest("not relevant on py3")
765
765
766 self.obj.value = long(100)
766 self.obj.value = long(100)
767 self.assertEqual(type(self.obj.value), int)
767 self.assertEqual(type(self.obj.value), int)
768
768
769
769
770 class FloatTrait(HasTraits):
770 class FloatTrait(HasTraits):
771
771
772 value = Float(99.0)
772 value = Float(99.0)
773
773
774 class TestFloat(TraitTestBase):
774 class TestFloat(TraitTestBase):
775
775
776 obj = FloatTrait()
776 obj = FloatTrait()
777
777
778 _default_value = 99.0
778 _default_value = 99.0
779 _good_values = [10, -10, 10.1, -10.1]
779 _good_values = [10, -10, 10.1, -10.1]
780 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
780 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
781 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
781 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
782 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
782 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
783 if not py3compat.PY3:
783 if not py3compat.PY3:
784 _bad_values.extend([long(10), long(-10)])
784 _bad_values.extend([long(10), long(-10)])
785
785
786
786
787 class ComplexTrait(HasTraits):
787 class ComplexTrait(HasTraits):
788
788
789 value = Complex(99.0-99.0j)
789 value = Complex(99.0-99.0j)
790
790
791 class TestComplex(TraitTestBase):
791 class TestComplex(TraitTestBase):
792
792
793 obj = ComplexTrait()
793 obj = ComplexTrait()
794
794
795 _default_value = 99.0-99.0j
795 _default_value = 99.0-99.0j
796 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
796 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
797 10.1j, 10.1+10.1j, 10.1-10.1j]
797 10.1j, 10.1+10.1j, 10.1-10.1j]
798 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
798 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
799 if not py3compat.PY3:
799 if not py3compat.PY3:
800 _bad_values.extend([long(10), long(-10)])
800 _bad_values.extend([long(10), long(-10)])
801
801
802
802
803 class BytesTrait(HasTraits):
803 class BytesTrait(HasTraits):
804
804
805 value = Bytes(b'string')
805 value = Bytes(b'string')
806
806
807 class TestBytes(TraitTestBase):
807 class TestBytes(TraitTestBase):
808
808
809 obj = BytesTrait()
809 obj = BytesTrait()
810
810
811 _default_value = b'string'
811 _default_value = b'string'
812 _good_values = [b'10', b'-10', b'10L',
812 _good_values = [b'10', b'-10', b'10L',
813 b'-10L', b'10.1', b'-10.1', b'string']
813 b'-10L', b'10.1', b'-10.1', b'string']
814 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
814 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
815 ['ten'],{'ten': 10},(10,), None, u'string']
815 ['ten'],{'ten': 10},(10,), None, u'string']
816 if not py3compat.PY3:
816 if not py3compat.PY3:
817 _bad_values.extend([long(10), long(-10)])
817 _bad_values.extend([long(10), long(-10)])
818
818
819
819
820 class UnicodeTrait(HasTraits):
820 class UnicodeTrait(HasTraits):
821
821
822 value = Unicode(u'unicode')
822 value = Unicode(u'unicode')
823
823
824 class TestUnicode(TraitTestBase):
824 class TestUnicode(TraitTestBase):
825
825
826 obj = UnicodeTrait()
826 obj = UnicodeTrait()
827
827
828 _default_value = u'unicode'
828 _default_value = u'unicode'
829 _good_values = ['10', '-10', '10L', '-10L', '10.1',
829 _good_values = ['10', '-10', '10L', '-10L', '10.1',
830 '-10.1', '', u'', 'string', u'string', u"€"]
830 '-10.1', '', u'', 'string', u'string', u"€"]
831 _bad_values = [10, -10, 10.1, -10.1, 1j,
831 _bad_values = [10, -10, 10.1, -10.1, 1j,
832 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
832 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
833 if not py3compat.PY3:
833 if not py3compat.PY3:
834 _bad_values.extend([long(10), long(-10)])
834 _bad_values.extend([long(10), long(-10)])
835
835
836
836
837 class ObjectNameTrait(HasTraits):
837 class ObjectNameTrait(HasTraits):
838 value = ObjectName("abc")
838 value = ObjectName("abc")
839
839
840 class TestObjectName(TraitTestBase):
840 class TestObjectName(TraitTestBase):
841 obj = ObjectNameTrait()
841 obj = ObjectNameTrait()
842
842
843 _default_value = "abc"
843 _default_value = "abc"
844 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
844 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
845 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
845 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
846 object(), object]
846 object(), object]
847 if sys.version_info[0] < 3:
847 if sys.version_info[0] < 3:
848 _bad_values.append(u"ΓΎ")
848 _bad_values.append(u"ΓΎ")
849 else:
849 else:
850 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
850 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
851
851
852
852
853 class DottedObjectNameTrait(HasTraits):
853 class DottedObjectNameTrait(HasTraits):
854 value = DottedObjectName("a.b")
854 value = DottedObjectName("a.b")
855
855
856 class TestDottedObjectName(TraitTestBase):
856 class TestDottedObjectName(TraitTestBase):
857 obj = DottedObjectNameTrait()
857 obj = DottedObjectNameTrait()
858
858
859 _default_value = "a.b"
859 _default_value = "a.b"
860 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
860 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
861 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
861 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
862 if sys.version_info[0] < 3:
862 if sys.version_info[0] < 3:
863 _bad_values.append(u"t.ΓΎ")
863 _bad_values.append(u"t.ΓΎ")
864 else:
864 else:
865 _good_values.append(u"t.ΓΎ")
865 _good_values.append(u"t.ΓΎ")
866
866
867
867
868 class TCPAddressTrait(HasTraits):
868 class TCPAddressTrait(HasTraits):
869
869
870 value = TCPAddress()
870 value = TCPAddress()
871
871
872 class TestTCPAddress(TraitTestBase):
872 class TestTCPAddress(TraitTestBase):
873
873
874 obj = TCPAddressTrait()
874 obj = TCPAddressTrait()
875
875
876 _default_value = ('127.0.0.1',0)
876 _default_value = ('127.0.0.1',0)
877 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
877 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
878 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
878 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
879
879
880 class ListTrait(HasTraits):
880 class ListTrait(HasTraits):
881
881
882 value = List(Int)
882 value = List(Int)
883
883
884 class TestList(TraitTestBase):
884 class TestList(TraitTestBase):
885
885
886 obj = ListTrait()
886 obj = ListTrait()
887
887
888 _default_value = []
888 _default_value = []
889 _good_values = [[], [1], range(10)]
889 _good_values = [[], [1], list(range(10))]
890 _bad_values = [10, [1,'a'], 'a', (1,2)]
890 _bad_values = [10, [1,'a'], 'a', (1,2)]
891
891
892 class LenListTrait(HasTraits):
892 class LenListTrait(HasTraits):
893
893
894 value = List(Int, [0], minlen=1, maxlen=2)
894 value = List(Int, [0], minlen=1, maxlen=2)
895
895
896 class TestLenList(TraitTestBase):
896 class TestLenList(TraitTestBase):
897
897
898 obj = LenListTrait()
898 obj = LenListTrait()
899
899
900 _default_value = [0]
900 _default_value = [0]
901 _good_values = [[1], range(2)]
901 _good_values = [[1], list(range(2))]
902 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
902 _bad_values = [10, [1,'a'], 'a', (1,2), [], list(range(3))]
903
903
904 class TupleTrait(HasTraits):
904 class TupleTrait(HasTraits):
905
905
906 value = Tuple(Int)
906 value = Tuple(Int)
907
907
908 class TestTupleTrait(TraitTestBase):
908 class TestTupleTrait(TraitTestBase):
909
909
910 obj = TupleTrait()
910 obj = TupleTrait()
911
911
912 _default_value = None
912 _default_value = None
913 _good_values = [(1,), None,(0,)]
913 _good_values = [(1,), None,(0,)]
914 _bad_values = [10, (1,2), [1],('a'), ()]
914 _bad_values = [10, (1,2), [1],('a'), ()]
915
915
916 def test_invalid_args(self):
916 def test_invalid_args(self):
917 self.assertRaises(TypeError, Tuple, 5)
917 self.assertRaises(TypeError, Tuple, 5)
918 self.assertRaises(TypeError, Tuple, default_value='hello')
918 self.assertRaises(TypeError, Tuple, default_value='hello')
919 t = Tuple(Int, CBytes, default_value=(1,5))
919 t = Tuple(Int, CBytes, default_value=(1,5))
920
920
921 class LooseTupleTrait(HasTraits):
921 class LooseTupleTrait(HasTraits):
922
922
923 value = Tuple((1,2,3))
923 value = Tuple((1,2,3))
924
924
925 class TestLooseTupleTrait(TraitTestBase):
925 class TestLooseTupleTrait(TraitTestBase):
926
926
927 obj = LooseTupleTrait()
927 obj = LooseTupleTrait()
928
928
929 _default_value = (1,2,3)
929 _default_value = (1,2,3)
930 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
930 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
931 _bad_values = [10, 'hello', [1], []]
931 _bad_values = [10, 'hello', [1], []]
932
932
933 def test_invalid_args(self):
933 def test_invalid_args(self):
934 self.assertRaises(TypeError, Tuple, 5)
934 self.assertRaises(TypeError, Tuple, 5)
935 self.assertRaises(TypeError, Tuple, default_value='hello')
935 self.assertRaises(TypeError, Tuple, default_value='hello')
936 t = Tuple(Int, CBytes, default_value=(1,5))
936 t = Tuple(Int, CBytes, default_value=(1,5))
937
937
938
938
939 class MultiTupleTrait(HasTraits):
939 class MultiTupleTrait(HasTraits):
940
940
941 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
941 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
942
942
943 class TestMultiTuple(TraitTestBase):
943 class TestMultiTuple(TraitTestBase):
944
944
945 obj = MultiTupleTrait()
945 obj = MultiTupleTrait()
946
946
947 _default_value = (99,b'bottles')
947 _default_value = (99,b'bottles')
948 _good_values = [(1,b'a'), (2,b'b')]
948 _good_values = [(1,b'a'), (2,b'b')]
949 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
949 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
950
950
951 class CRegExpTrait(HasTraits):
951 class CRegExpTrait(HasTraits):
952
952
953 value = CRegExp(r'')
953 value = CRegExp(r'')
954
954
955 class TestCRegExp(TraitTestBase):
955 class TestCRegExp(TraitTestBase):
956
956
957 def coerce(self, value):
957 def coerce(self, value):
958 return re.compile(value)
958 return re.compile(value)
959
959
960 obj = CRegExpTrait()
960 obj = CRegExpTrait()
961
961
962 _default_value = re.compile(r'')
962 _default_value = re.compile(r'')
963 _good_values = [r'\d+', re.compile(r'\d+')]
963 _good_values = [r'\d+', re.compile(r'\d+')]
964 _bad_values = [r'(', None, ()]
964 _bad_values = [r'(', None, ()]
965
965
966 class DictTrait(HasTraits):
966 class DictTrait(HasTraits):
967 value = Dict()
967 value = Dict()
968
968
969 def test_dict_assignment():
969 def test_dict_assignment():
970 d = dict()
970 d = dict()
971 c = DictTrait()
971 c = DictTrait()
972 c.value = d
972 c.value = d
973 d['a'] = 5
973 d['a'] = 5
974 nt.assert_equal(d, c.value)
974 nt.assert_equal(d, c.value)
975 nt.assert_true(c.value is d)
975 nt.assert_true(c.value is d)
@@ -1,758 +1,758 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4
4
5 Inheritance diagram:
5 Inheritance diagram:
6
6
7 .. inheritance-diagram:: IPython.utils.text
7 .. inheritance-diagram:: IPython.utils.text
8 :parts: 3
8 :parts: 3
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import os
22 import os
23 import re
23 import re
24 import sys
24 import sys
25 import textwrap
25 import textwrap
26 from string import Formatter
26 from string import Formatter
27
27
28 from IPython.external.path import path
28 from IPython.external.path import path
29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
30 from IPython.utils import py3compat
30 from IPython.utils import py3compat
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Declarations
33 # Declarations
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 # datetime.strftime date format for ipython
36 # datetime.strftime date format for ipython
37 if sys.platform == 'win32':
37 if sys.platform == 'win32':
38 date_format = "%B %d, %Y"
38 date_format = "%B %d, %Y"
39 else:
39 else:
40 date_format = "%B %-d, %Y"
40 date_format = "%B %-d, %Y"
41
41
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Code
44 # Code
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 class LSString(str):
47 class LSString(str):
48 """String derivative with a special access attributes.
48 """String derivative with a special access attributes.
49
49
50 These are normal strings, but with the special attributes:
50 These are normal strings, but with the special attributes:
51
51
52 .l (or .list) : value as list (split on newlines).
52 .l (or .list) : value as list (split on newlines).
53 .n (or .nlstr): original value (the string itself).
53 .n (or .nlstr): original value (the string itself).
54 .s (or .spstr): value as whitespace-separated string.
54 .s (or .spstr): value as whitespace-separated string.
55 .p (or .paths): list of path objects
55 .p (or .paths): list of path objects
56
56
57 Any values which require transformations are computed only once and
57 Any values which require transformations are computed only once and
58 cached.
58 cached.
59
59
60 Such strings are very useful to efficiently interact with the shell, which
60 Such strings are very useful to efficiently interact with the shell, which
61 typically only understands whitespace-separated options for commands."""
61 typically only understands whitespace-separated options for commands."""
62
62
63 def get_list(self):
63 def get_list(self):
64 try:
64 try:
65 return self.__list
65 return self.__list
66 except AttributeError:
66 except AttributeError:
67 self.__list = self.split('\n')
67 self.__list = self.split('\n')
68 return self.__list
68 return self.__list
69
69
70 l = list = property(get_list)
70 l = list = property(get_list)
71
71
72 def get_spstr(self):
72 def get_spstr(self):
73 try:
73 try:
74 return self.__spstr
74 return self.__spstr
75 except AttributeError:
75 except AttributeError:
76 self.__spstr = self.replace('\n',' ')
76 self.__spstr = self.replace('\n',' ')
77 return self.__spstr
77 return self.__spstr
78
78
79 s = spstr = property(get_spstr)
79 s = spstr = property(get_spstr)
80
80
81 def get_nlstr(self):
81 def get_nlstr(self):
82 return self
82 return self
83
83
84 n = nlstr = property(get_nlstr)
84 n = nlstr = property(get_nlstr)
85
85
86 def get_paths(self):
86 def get_paths(self):
87 try:
87 try:
88 return self.__paths
88 return self.__paths
89 except AttributeError:
89 except AttributeError:
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 return self.__paths
91 return self.__paths
92
92
93 p = paths = property(get_paths)
93 p = paths = property(get_paths)
94
94
95 # FIXME: We need to reimplement type specific displayhook and then add this
95 # FIXME: We need to reimplement type specific displayhook and then add this
96 # back as a custom printer. This should also be moved outside utils into the
96 # back as a custom printer. This should also be moved outside utils into the
97 # core.
97 # core.
98
98
99 # def print_lsstring(arg):
99 # def print_lsstring(arg):
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 # print "LSString (.p, .n, .l, .s available). Value:"
101 # print "LSString (.p, .n, .l, .s available). Value:"
102 # print arg
102 # print arg
103 #
103 #
104 #
104 #
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106
106
107
107
108 class SList(list):
108 class SList(list):
109 """List derivative with a special access attributes.
109 """List derivative with a special access attributes.
110
110
111 These are normal lists, but with the special attributes:
111 These are normal lists, but with the special attributes:
112
112
113 .l (or .list) : value as list (the list itself).
113 .l (or .list) : value as list (the list itself).
114 .n (or .nlstr): value as a string, joined on newlines.
114 .n (or .nlstr): value as a string, joined on newlines.
115 .s (or .spstr): value as a string, joined on spaces.
115 .s (or .spstr): value as a string, joined on spaces.
116 .p (or .paths): list of path objects
116 .p (or .paths): list of path objects
117
117
118 Any values which require transformations are computed only once and
118 Any values which require transformations are computed only once and
119 cached."""
119 cached."""
120
120
121 def get_list(self):
121 def get_list(self):
122 return self
122 return self
123
123
124 l = list = property(get_list)
124 l = list = property(get_list)
125
125
126 def get_spstr(self):
126 def get_spstr(self):
127 try:
127 try:
128 return self.__spstr
128 return self.__spstr
129 except AttributeError:
129 except AttributeError:
130 self.__spstr = ' '.join(self)
130 self.__spstr = ' '.join(self)
131 return self.__spstr
131 return self.__spstr
132
132
133 s = spstr = property(get_spstr)
133 s = spstr = property(get_spstr)
134
134
135 def get_nlstr(self):
135 def get_nlstr(self):
136 try:
136 try:
137 return self.__nlstr
137 return self.__nlstr
138 except AttributeError:
138 except AttributeError:
139 self.__nlstr = '\n'.join(self)
139 self.__nlstr = '\n'.join(self)
140 return self.__nlstr
140 return self.__nlstr
141
141
142 n = nlstr = property(get_nlstr)
142 n = nlstr = property(get_nlstr)
143
143
144 def get_paths(self):
144 def get_paths(self):
145 try:
145 try:
146 return self.__paths
146 return self.__paths
147 except AttributeError:
147 except AttributeError:
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 return self.__paths
149 return self.__paths
150
150
151 p = paths = property(get_paths)
151 p = paths = property(get_paths)
152
152
153 def grep(self, pattern, prune = False, field = None):
153 def grep(self, pattern, prune = False, field = None):
154 """ Return all strings matching 'pattern' (a regex or callable)
154 """ Return all strings matching 'pattern' (a regex or callable)
155
155
156 This is case-insensitive. If prune is true, return all items
156 This is case-insensitive. If prune is true, return all items
157 NOT matching the pattern.
157 NOT matching the pattern.
158
158
159 If field is specified, the match must occur in the specified
159 If field is specified, the match must occur in the specified
160 whitespace-separated field.
160 whitespace-separated field.
161
161
162 Examples::
162 Examples::
163
163
164 a.grep( lambda x: x.startswith('C') )
164 a.grep( lambda x: x.startswith('C') )
165 a.grep('Cha.*log', prune=1)
165 a.grep('Cha.*log', prune=1)
166 a.grep('chm', field=-1)
166 a.grep('chm', field=-1)
167 """
167 """
168
168
169 def match_target(s):
169 def match_target(s):
170 if field is None:
170 if field is None:
171 return s
171 return s
172 parts = s.split()
172 parts = s.split()
173 try:
173 try:
174 tgt = parts[field]
174 tgt = parts[field]
175 return tgt
175 return tgt
176 except IndexError:
176 except IndexError:
177 return ""
177 return ""
178
178
179 if isinstance(pattern, py3compat.string_types):
179 if isinstance(pattern, py3compat.string_types):
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 else:
181 else:
182 pred = pattern
182 pred = pattern
183 if not prune:
183 if not prune:
184 return SList([el for el in self if pred(match_target(el))])
184 return SList([el for el in self if pred(match_target(el))])
185 else:
185 else:
186 return SList([el for el in self if not pred(match_target(el))])
186 return SList([el for el in self if not pred(match_target(el))])
187
187
188 def fields(self, *fields):
188 def fields(self, *fields):
189 """ Collect whitespace-separated fields from string list
189 """ Collect whitespace-separated fields from string list
190
190
191 Allows quick awk-like usage of string lists.
191 Allows quick awk-like usage of string lists.
192
192
193 Example data (in var a, created by 'a = !ls -l')::
193 Example data (in var a, created by 'a = !ls -l')::
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196
196
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 (note the joining by space).
199 (note the joining by space).
200 a.fields(-1) is ['ChangeLog', 'IPython']
200 a.fields(-1) is ['ChangeLog', 'IPython']
201
201
202 IndexErrors are ignored.
202 IndexErrors are ignored.
203
203
204 Without args, fields() just split()'s the strings.
204 Without args, fields() just split()'s the strings.
205 """
205 """
206 if len(fields) == 0:
206 if len(fields) == 0:
207 return [el.split() for el in self]
207 return [el.split() for el in self]
208
208
209 res = SList()
209 res = SList()
210 for el in [f.split() for f in self]:
210 for el in [f.split() for f in self]:
211 lineparts = []
211 lineparts = []
212
212
213 for fd in fields:
213 for fd in fields:
214 try:
214 try:
215 lineparts.append(el[fd])
215 lineparts.append(el[fd])
216 except IndexError:
216 except IndexError:
217 pass
217 pass
218 if lineparts:
218 if lineparts:
219 res.append(" ".join(lineparts))
219 res.append(" ".join(lineparts))
220
220
221 return res
221 return res
222
222
223 def sort(self,field= None, nums = False):
223 def sort(self,field= None, nums = False):
224 """ sort by specified fields (see fields())
224 """ sort by specified fields (see fields())
225
225
226 Example::
226 Example::
227 a.sort(1, nums = True)
227 a.sort(1, nums = True)
228
228
229 Sorts a by second field, in numerical order (so that 21 > 3)
229 Sorts a by second field, in numerical order (so that 21 > 3)
230
230
231 """
231 """
232
232
233 #decorate, sort, undecorate
233 #decorate, sort, undecorate
234 if field is not None:
234 if field is not None:
235 dsu = [[SList([line]).fields(field), line] for line in self]
235 dsu = [[SList([line]).fields(field), line] for line in self]
236 else:
236 else:
237 dsu = [[line, line] for line in self]
237 dsu = [[line, line] for line in self]
238 if nums:
238 if nums:
239 for i in range(len(dsu)):
239 for i in range(len(dsu)):
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 try:
241 try:
242 n = int(numstr)
242 n = int(numstr)
243 except ValueError:
243 except ValueError:
244 n = 0;
244 n = 0;
245 dsu[i][0] = n
245 dsu[i][0] = n
246
246
247
247
248 dsu.sort()
248 dsu.sort()
249 return SList([t[1] for t in dsu])
249 return SList([t[1] for t in dsu])
250
250
251
251
252 # FIXME: We need to reimplement type specific displayhook and then add this
252 # FIXME: We need to reimplement type specific displayhook and then add this
253 # back as a custom printer. This should also be moved outside utils into the
253 # back as a custom printer. This should also be moved outside utils into the
254 # core.
254 # core.
255
255
256 # def print_slist(arg):
256 # def print_slist(arg):
257 # """ Prettier (non-repr-like) and more informative printer for SList """
257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 # arg.hideonce = False
260 # arg.hideonce = False
261 # return
261 # return
262 #
262 #
263 # nlprint(arg) # This was a nested list printer, now removed.
263 # nlprint(arg) # This was a nested list printer, now removed.
264 #
264 #
265 # print_slist = result_display.when_type(SList)(print_slist)
265 # print_slist = result_display.when_type(SList)(print_slist)
266
266
267
267
268 def indent(instr,nspaces=4, ntabs=0, flatten=False):
268 def indent(instr,nspaces=4, ntabs=0, flatten=False):
269 """Indent a string a given number of spaces or tabstops.
269 """Indent a string a given number of spaces or tabstops.
270
270
271 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
271 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
272
272
273 Parameters
273 Parameters
274 ----------
274 ----------
275
275
276 instr : basestring
276 instr : basestring
277 The string to be indented.
277 The string to be indented.
278 nspaces : int (default: 4)
278 nspaces : int (default: 4)
279 The number of spaces to be indented.
279 The number of spaces to be indented.
280 ntabs : int (default: 0)
280 ntabs : int (default: 0)
281 The number of tabs to be indented.
281 The number of tabs to be indented.
282 flatten : bool (default: False)
282 flatten : bool (default: False)
283 Whether to scrub existing indentation. If True, all lines will be
283 Whether to scrub existing indentation. If True, all lines will be
284 aligned to the same indentation. If False, existing indentation will
284 aligned to the same indentation. If False, existing indentation will
285 be strictly increased.
285 be strictly increased.
286
286
287 Returns
287 Returns
288 -------
288 -------
289
289
290 str|unicode : string indented by ntabs and nspaces.
290 str|unicode : string indented by ntabs and nspaces.
291
291
292 """
292 """
293 if instr is None:
293 if instr is None:
294 return
294 return
295 ind = '\t'*ntabs+' '*nspaces
295 ind = '\t'*ntabs+' '*nspaces
296 if flatten:
296 if flatten:
297 pat = re.compile(r'^\s*', re.MULTILINE)
297 pat = re.compile(r'^\s*', re.MULTILINE)
298 else:
298 else:
299 pat = re.compile(r'^', re.MULTILINE)
299 pat = re.compile(r'^', re.MULTILINE)
300 outstr = re.sub(pat, ind, instr)
300 outstr = re.sub(pat, ind, instr)
301 if outstr.endswith(os.linesep+ind):
301 if outstr.endswith(os.linesep+ind):
302 return outstr[:-len(ind)]
302 return outstr[:-len(ind)]
303 else:
303 else:
304 return outstr
304 return outstr
305
305
306
306
307 def list_strings(arg):
307 def list_strings(arg):
308 """Always return a list of strings, given a string or list of strings
308 """Always return a list of strings, given a string or list of strings
309 as input.
309 as input.
310
310
311 :Examples:
311 :Examples:
312
312
313 In [7]: list_strings('A single string')
313 In [7]: list_strings('A single string')
314 Out[7]: ['A single string']
314 Out[7]: ['A single string']
315
315
316 In [8]: list_strings(['A single string in a list'])
316 In [8]: list_strings(['A single string in a list'])
317 Out[8]: ['A single string in a list']
317 Out[8]: ['A single string in a list']
318
318
319 In [9]: list_strings(['A','list','of','strings'])
319 In [9]: list_strings(['A','list','of','strings'])
320 Out[9]: ['A', 'list', 'of', 'strings']
320 Out[9]: ['A', 'list', 'of', 'strings']
321 """
321 """
322
322
323 if isinstance(arg, py3compat.string_types): return [arg]
323 if isinstance(arg, py3compat.string_types): return [arg]
324 else: return arg
324 else: return arg
325
325
326
326
327 def marquee(txt='',width=78,mark='*'):
327 def marquee(txt='',width=78,mark='*'):
328 """Return the input string centered in a 'marquee'.
328 """Return the input string centered in a 'marquee'.
329
329
330 :Examples:
330 :Examples:
331
331
332 In [16]: marquee('A test',40)
332 In [16]: marquee('A test',40)
333 Out[16]: '**************** A test ****************'
333 Out[16]: '**************** A test ****************'
334
334
335 In [17]: marquee('A test',40,'-')
335 In [17]: marquee('A test',40,'-')
336 Out[17]: '---------------- A test ----------------'
336 Out[17]: '---------------- A test ----------------'
337
337
338 In [18]: marquee('A test',40,' ')
338 In [18]: marquee('A test',40,' ')
339 Out[18]: ' A test '
339 Out[18]: ' A test '
340
340
341 """
341 """
342 if not txt:
342 if not txt:
343 return (mark*width)[:width]
343 return (mark*width)[:width]
344 nmark = (width-len(txt)-2)//len(mark)//2
344 nmark = (width-len(txt)-2)//len(mark)//2
345 if nmark < 0: nmark =0
345 if nmark < 0: nmark =0
346 marks = mark*nmark
346 marks = mark*nmark
347 return '%s %s %s' % (marks,txt,marks)
347 return '%s %s %s' % (marks,txt,marks)
348
348
349
349
350 ini_spaces_re = re.compile(r'^(\s+)')
350 ini_spaces_re = re.compile(r'^(\s+)')
351
351
352 def num_ini_spaces(strng):
352 def num_ini_spaces(strng):
353 """Return the number of initial spaces in a string"""
353 """Return the number of initial spaces in a string"""
354
354
355 ini_spaces = ini_spaces_re.match(strng)
355 ini_spaces = ini_spaces_re.match(strng)
356 if ini_spaces:
356 if ini_spaces:
357 return ini_spaces.end()
357 return ini_spaces.end()
358 else:
358 else:
359 return 0
359 return 0
360
360
361
361
362 def format_screen(strng):
362 def format_screen(strng):
363 """Format a string for screen printing.
363 """Format a string for screen printing.
364
364
365 This removes some latex-type format codes."""
365 This removes some latex-type format codes."""
366 # Paragraph continue
366 # Paragraph continue
367 par_re = re.compile(r'\\$',re.MULTILINE)
367 par_re = re.compile(r'\\$',re.MULTILINE)
368 strng = par_re.sub('',strng)
368 strng = par_re.sub('',strng)
369 return strng
369 return strng
370
370
371
371
372 def dedent(text):
372 def dedent(text):
373 """Equivalent of textwrap.dedent that ignores unindented first line.
373 """Equivalent of textwrap.dedent that ignores unindented first line.
374
374
375 This means it will still dedent strings like:
375 This means it will still dedent strings like:
376 '''foo
376 '''foo
377 is a bar
377 is a bar
378 '''
378 '''
379
379
380 For use in wrap_paragraphs.
380 For use in wrap_paragraphs.
381 """
381 """
382
382
383 if text.startswith('\n'):
383 if text.startswith('\n'):
384 # text starts with blank line, don't ignore the first line
384 # text starts with blank line, don't ignore the first line
385 return textwrap.dedent(text)
385 return textwrap.dedent(text)
386
386
387 # split first line
387 # split first line
388 splits = text.split('\n',1)
388 splits = text.split('\n',1)
389 if len(splits) == 1:
389 if len(splits) == 1:
390 # only one line
390 # only one line
391 return textwrap.dedent(text)
391 return textwrap.dedent(text)
392
392
393 first, rest = splits
393 first, rest = splits
394 # dedent everything but the first line
394 # dedent everything but the first line
395 rest = textwrap.dedent(rest)
395 rest = textwrap.dedent(rest)
396 return '\n'.join([first, rest])
396 return '\n'.join([first, rest])
397
397
398
398
399 def wrap_paragraphs(text, ncols=80):
399 def wrap_paragraphs(text, ncols=80):
400 """Wrap multiple paragraphs to fit a specified width.
400 """Wrap multiple paragraphs to fit a specified width.
401
401
402 This is equivalent to textwrap.wrap, but with support for multiple
402 This is equivalent to textwrap.wrap, but with support for multiple
403 paragraphs, as separated by empty lines.
403 paragraphs, as separated by empty lines.
404
404
405 Returns
405 Returns
406 -------
406 -------
407
407
408 list of complete paragraphs, wrapped to fill `ncols` columns.
408 list of complete paragraphs, wrapped to fill `ncols` columns.
409 """
409 """
410 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
410 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
411 text = dedent(text).strip()
411 text = dedent(text).strip()
412 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
412 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
413 out_ps = []
413 out_ps = []
414 indent_re = re.compile(r'\n\s+', re.MULTILINE)
414 indent_re = re.compile(r'\n\s+', re.MULTILINE)
415 for p in paragraphs:
415 for p in paragraphs:
416 # presume indentation that survives dedent is meaningful formatting,
416 # presume indentation that survives dedent is meaningful formatting,
417 # so don't fill unless text is flush.
417 # so don't fill unless text is flush.
418 if indent_re.search(p) is None:
418 if indent_re.search(p) is None:
419 # wrap paragraph
419 # wrap paragraph
420 p = textwrap.fill(p, ncols)
420 p = textwrap.fill(p, ncols)
421 out_ps.append(p)
421 out_ps.append(p)
422 return out_ps
422 return out_ps
423
423
424
424
425 def long_substr(data):
425 def long_substr(data):
426 """Return the longest common substring in a list of strings.
426 """Return the longest common substring in a list of strings.
427
427
428 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
428 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
429 """
429 """
430 substr = ''
430 substr = ''
431 if len(data) > 1 and len(data[0]) > 0:
431 if len(data) > 1 and len(data[0]) > 0:
432 for i in range(len(data[0])):
432 for i in range(len(data[0])):
433 for j in range(len(data[0])-i+1):
433 for j in range(len(data[0])-i+1):
434 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
434 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
435 substr = data[0][i:i+j]
435 substr = data[0][i:i+j]
436 elif len(data) == 1:
436 elif len(data) == 1:
437 substr = data[0]
437 substr = data[0]
438 return substr
438 return substr
439
439
440
440
441 def strip_email_quotes(text):
441 def strip_email_quotes(text):
442 """Strip leading email quotation characters ('>').
442 """Strip leading email quotation characters ('>').
443
443
444 Removes any combination of leading '>' interspersed with whitespace that
444 Removes any combination of leading '>' interspersed with whitespace that
445 appears *identically* in all lines of the input text.
445 appears *identically* in all lines of the input text.
446
446
447 Parameters
447 Parameters
448 ----------
448 ----------
449 text : str
449 text : str
450
450
451 Examples
451 Examples
452 --------
452 --------
453
453
454 Simple uses::
454 Simple uses::
455
455
456 In [2]: strip_email_quotes('> > text')
456 In [2]: strip_email_quotes('> > text')
457 Out[2]: 'text'
457 Out[2]: 'text'
458
458
459 In [3]: strip_email_quotes('> > text\\n> > more')
459 In [3]: strip_email_quotes('> > text\\n> > more')
460 Out[3]: 'text\\nmore'
460 Out[3]: 'text\\nmore'
461
461
462 Note how only the common prefix that appears in all lines is stripped::
462 Note how only the common prefix that appears in all lines is stripped::
463
463
464 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
464 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
465 Out[4]: '> text\\n> more\\nmore...'
465 Out[4]: '> text\\n> more\\nmore...'
466
466
467 So if any line has no quote marks ('>') , then none are stripped from any
467 So if any line has no quote marks ('>') , then none are stripped from any
468 of them ::
468 of them ::
469
469
470 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
470 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
471 Out[5]: '> > text\\n> > more\\nlast different'
471 Out[5]: '> > text\\n> > more\\nlast different'
472 """
472 """
473 lines = text.splitlines()
473 lines = text.splitlines()
474 matches = set()
474 matches = set()
475 for line in lines:
475 for line in lines:
476 prefix = re.match(r'^(\s*>[ >]*)', line)
476 prefix = re.match(r'^(\s*>[ >]*)', line)
477 if prefix:
477 if prefix:
478 matches.add(prefix.group(1))
478 matches.add(prefix.group(1))
479 else:
479 else:
480 break
480 break
481 else:
481 else:
482 prefix = long_substr(list(matches))
482 prefix = long_substr(list(matches))
483 if prefix:
483 if prefix:
484 strip = len(prefix)
484 strip = len(prefix)
485 text = '\n'.join([ ln[strip:] for ln in lines])
485 text = '\n'.join([ ln[strip:] for ln in lines])
486 return text
486 return text
487
487
488
488
489 class EvalFormatter(Formatter):
489 class EvalFormatter(Formatter):
490 """A String Formatter that allows evaluation of simple expressions.
490 """A String Formatter that allows evaluation of simple expressions.
491
491
492 Note that this version interprets a : as specifying a format string (as per
492 Note that this version interprets a : as specifying a format string (as per
493 standard string formatting), so if slicing is required, you must explicitly
493 standard string formatting), so if slicing is required, you must explicitly
494 create a slice.
494 create a slice.
495
495
496 This is to be used in templating cases, such as the parallel batch
496 This is to be used in templating cases, such as the parallel batch
497 script templates, where simple arithmetic on arguments is useful.
497 script templates, where simple arithmetic on arguments is useful.
498
498
499 Examples
499 Examples
500 --------
500 --------
501
501
502 In [1]: f = EvalFormatter()
502 In [1]: f = EvalFormatter()
503 In [2]: f.format('{n//4}', n=8)
503 In [2]: f.format('{n//4}', n=8)
504 Out [2]: '2'
504 Out [2]: '2'
505
505
506 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
506 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
507 Out [3]: 'll'
507 Out [3]: 'll'
508 """
508 """
509 def get_field(self, name, args, kwargs):
509 def get_field(self, name, args, kwargs):
510 v = eval(name, kwargs)
510 v = eval(name, kwargs)
511 return v, name
511 return v, name
512
512
513
513
514 @skip_doctest_py3
514 @skip_doctest_py3
515 class FullEvalFormatter(Formatter):
515 class FullEvalFormatter(Formatter):
516 """A String Formatter that allows evaluation of simple expressions.
516 """A String Formatter that allows evaluation of simple expressions.
517
517
518 Any time a format key is not found in the kwargs,
518 Any time a format key is not found in the kwargs,
519 it will be tried as an expression in the kwargs namespace.
519 it will be tried as an expression in the kwargs namespace.
520
520
521 Note that this version allows slicing using [1:2], so you cannot specify
521 Note that this version allows slicing using [1:2], so you cannot specify
522 a format string. Use :class:`EvalFormatter` to permit format strings.
522 a format string. Use :class:`EvalFormatter` to permit format strings.
523
523
524 Examples
524 Examples
525 --------
525 --------
526
526
527 In [1]: f = FullEvalFormatter()
527 In [1]: f = FullEvalFormatter()
528 In [2]: f.format('{n//4}', n=8)
528 In [2]: f.format('{n//4}', n=8)
529 Out[2]: u'2'
529 Out[2]: u'2'
530
530
531 In [3]: f.format('{list(range(5))[2:4]}')
531 In [3]: f.format('{list(range(5))[2:4]}')
532 Out[3]: u'[2, 3]'
532 Out[3]: u'[2, 3]'
533
533
534 In [4]: f.format('{3*2}')
534 In [4]: f.format('{3*2}')
535 Out[4]: u'6'
535 Out[4]: u'6'
536 """
536 """
537 # copied from Formatter._vformat with minor changes to allow eval
537 # copied from Formatter._vformat with minor changes to allow eval
538 # and replace the format_spec code with slicing
538 # and replace the format_spec code with slicing
539 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
539 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
540 if recursion_depth < 0:
540 if recursion_depth < 0:
541 raise ValueError('Max string recursion exceeded')
541 raise ValueError('Max string recursion exceeded')
542 result = []
542 result = []
543 for literal_text, field_name, format_spec, conversion in \
543 for literal_text, field_name, format_spec, conversion in \
544 self.parse(format_string):
544 self.parse(format_string):
545
545
546 # output the literal text
546 # output the literal text
547 if literal_text:
547 if literal_text:
548 result.append(literal_text)
548 result.append(literal_text)
549
549
550 # if there's a field, output it
550 # if there's a field, output it
551 if field_name is not None:
551 if field_name is not None:
552 # this is some markup, find the object and do
552 # this is some markup, find the object and do
553 # the formatting
553 # the formatting
554
554
555 if format_spec:
555 if format_spec:
556 # override format spec, to allow slicing:
556 # override format spec, to allow slicing:
557 field_name = ':'.join([field_name, format_spec])
557 field_name = ':'.join([field_name, format_spec])
558
558
559 # eval the contents of the field for the object
559 # eval the contents of the field for the object
560 # to be formatted
560 # to be formatted
561 obj = eval(field_name, kwargs)
561 obj = eval(field_name, kwargs)
562
562
563 # do any conversion on the resulting object
563 # do any conversion on the resulting object
564 obj = self.convert_field(obj, conversion)
564 obj = self.convert_field(obj, conversion)
565
565
566 # format the object and append to the result
566 # format the object and append to the result
567 result.append(self.format_field(obj, ''))
567 result.append(self.format_field(obj, ''))
568
568
569 return u''.join(py3compat.cast_unicode(s) for s in result)
569 return u''.join(py3compat.cast_unicode(s) for s in result)
570
570
571
571
572 @skip_doctest_py3
572 @skip_doctest_py3
573 class DollarFormatter(FullEvalFormatter):
573 class DollarFormatter(FullEvalFormatter):
574 """Formatter allowing Itpl style $foo replacement, for names and attribute
574 """Formatter allowing Itpl style $foo replacement, for names and attribute
575 access only. Standard {foo} replacement also works, and allows full
575 access only. Standard {foo} replacement also works, and allows full
576 evaluation of its arguments.
576 evaluation of its arguments.
577
577
578 Examples
578 Examples
579 --------
579 --------
580 In [1]: f = DollarFormatter()
580 In [1]: f = DollarFormatter()
581 In [2]: f.format('{n//4}', n=8)
581 In [2]: f.format('{n//4}', n=8)
582 Out[2]: u'2'
582 Out[2]: u'2'
583
583
584 In [3]: f.format('23 * 76 is $result', result=23*76)
584 In [3]: f.format('23 * 76 is $result', result=23*76)
585 Out[3]: u'23 * 76 is 1748'
585 Out[3]: u'23 * 76 is 1748'
586
586
587 In [4]: f.format('$a or {b}', a=1, b=2)
587 In [4]: f.format('$a or {b}', a=1, b=2)
588 Out[4]: u'1 or 2'
588 Out[4]: u'1 or 2'
589 """
589 """
590 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
590 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
591 def parse(self, fmt_string):
591 def parse(self, fmt_string):
592 for literal_txt, field_name, format_spec, conversion \
592 for literal_txt, field_name, format_spec, conversion \
593 in Formatter.parse(self, fmt_string):
593 in Formatter.parse(self, fmt_string):
594
594
595 # Find $foo patterns in the literal text.
595 # Find $foo patterns in the literal text.
596 continue_from = 0
596 continue_from = 0
597 txt = ""
597 txt = ""
598 for m in self._dollar_pattern.finditer(literal_txt):
598 for m in self._dollar_pattern.finditer(literal_txt):
599 new_txt, new_field = m.group(1,2)
599 new_txt, new_field = m.group(1,2)
600 # $$foo --> $foo
600 # $$foo --> $foo
601 if new_field.startswith("$"):
601 if new_field.startswith("$"):
602 txt += new_txt + new_field
602 txt += new_txt + new_field
603 else:
603 else:
604 yield (txt + new_txt, new_field, "", None)
604 yield (txt + new_txt, new_field, "", None)
605 txt = ""
605 txt = ""
606 continue_from = m.end()
606 continue_from = m.end()
607
607
608 # Re-yield the {foo} style pattern
608 # Re-yield the {foo} style pattern
609 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
609 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
610
610
611 #-----------------------------------------------------------------------------
611 #-----------------------------------------------------------------------------
612 # Utils to columnize a list of string
612 # Utils to columnize a list of string
613 #-----------------------------------------------------------------------------
613 #-----------------------------------------------------------------------------
614
614
615 def _chunks(l, n):
615 def _chunks(l, n):
616 """Yield successive n-sized chunks from l."""
616 """Yield successive n-sized chunks from l."""
617 for i in py3compat.xrange(0, len(l), n):
617 for i in py3compat.xrange(0, len(l), n):
618 yield l[i:i+n]
618 yield l[i:i+n]
619
619
620
620
621 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
621 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
622 """Calculate optimal info to columnize a list of string"""
622 """Calculate optimal info to columnize a list of string"""
623 for nrow in range(1, len(rlist)+1) :
623 for nrow in range(1, len(rlist)+1) :
624 chk = map(max,_chunks(rlist, nrow))
624 chk = list(map(max,_chunks(rlist, nrow)))
625 sumlength = sum(chk)
625 sumlength = sum(chk)
626 ncols = len(chk)
626 ncols = len(chk)
627 if sumlength+separator_size*(ncols-1) <= displaywidth :
627 if sumlength+separator_size*(ncols-1) <= displaywidth :
628 break;
628 break;
629 return {'columns_numbers' : ncols,
629 return {'columns_numbers' : ncols,
630 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
630 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
631 'rows_numbers' : nrow,
631 'rows_numbers' : nrow,
632 'columns_width' : chk
632 'columns_width' : chk
633 }
633 }
634
634
635
635
636 def _get_or_default(mylist, i, default=None):
636 def _get_or_default(mylist, i, default=None):
637 """return list item number, or default if don't exist"""
637 """return list item number, or default if don't exist"""
638 if i >= len(mylist):
638 if i >= len(mylist):
639 return default
639 return default
640 else :
640 else :
641 return mylist[i]
641 return mylist[i]
642
642
643
643
644 @skip_doctest
644 @skip_doctest
645 def compute_item_matrix(items, empty=None, *args, **kwargs) :
645 def compute_item_matrix(items, empty=None, *args, **kwargs) :
646 """Returns a nested list, and info to columnize items
646 """Returns a nested list, and info to columnize items
647
647
648 Parameters
648 Parameters
649 ----------
649 ----------
650
650
651 items :
651 items :
652 list of strings to columize
652 list of strings to columize
653 empty : (default None)
653 empty : (default None)
654 default value to fill list if needed
654 default value to fill list if needed
655 separator_size : int (default=2)
655 separator_size : int (default=2)
656 How much caracters will be used as a separation between each columns.
656 How much caracters will be used as a separation between each columns.
657 displaywidth : int (default=80)
657 displaywidth : int (default=80)
658 The width of the area onto wich the columns should enter
658 The width of the area onto wich the columns should enter
659
659
660 Returns
660 Returns
661 -------
661 -------
662
662
663 Returns a tuple of (strings_matrix, dict_info)
663 Returns a tuple of (strings_matrix, dict_info)
664
664
665 strings_matrix :
665 strings_matrix :
666
666
667 nested list of string, the outer most list contains as many list as
667 nested list of string, the outer most list contains as many list as
668 rows, the innermost lists have each as many element as colums. If the
668 rows, the innermost lists have each as many element as colums. If the
669 total number of elements in `items` does not equal the product of
669 total number of elements in `items` does not equal the product of
670 rows*columns, the last element of some lists are filled with `None`.
670 rows*columns, the last element of some lists are filled with `None`.
671
671
672 dict_info :
672 dict_info :
673 some info to make columnize easier:
673 some info to make columnize easier:
674
674
675 columns_numbers : number of columns
675 columns_numbers : number of columns
676 rows_numbers : number of rows
676 rows_numbers : number of rows
677 columns_width : list of with of each columns
677 columns_width : list of with of each columns
678 optimal_separator_width : best separator width between columns
678 optimal_separator_width : best separator width between columns
679
679
680 Examples
680 Examples
681 --------
681 --------
682
682
683 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
683 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
684 ...: compute_item_matrix(l,displaywidth=12)
684 ...: compute_item_matrix(l,displaywidth=12)
685 Out[1]:
685 Out[1]:
686 ([['aaa', 'f', 'k'],
686 ([['aaa', 'f', 'k'],
687 ['b', 'g', 'l'],
687 ['b', 'g', 'l'],
688 ['cc', 'h', None],
688 ['cc', 'h', None],
689 ['d', 'i', None],
689 ['d', 'i', None],
690 ['eeeee', 'j', None]],
690 ['eeeee', 'j', None]],
691 {'columns_numbers': 3,
691 {'columns_numbers': 3,
692 'columns_width': [5, 1, 1],
692 'columns_width': [5, 1, 1],
693 'optimal_separator_width': 2,
693 'optimal_separator_width': 2,
694 'rows_numbers': 5})
694 'rows_numbers': 5})
695
695
696 """
696 """
697 info = _find_optimal(map(len, items), *args, **kwargs)
697 info = _find_optimal(list(map(len, items)), *args, **kwargs)
698 nrow, ncol = info['rows_numbers'], info['columns_numbers']
698 nrow, ncol = info['rows_numbers'], info['columns_numbers']
699 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
699 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
700
700
701
701
702 def columnize(items, separator=' ', displaywidth=80):
702 def columnize(items, separator=' ', displaywidth=80):
703 """ Transform a list of strings into a single string with columns.
703 """ Transform a list of strings into a single string with columns.
704
704
705 Parameters
705 Parameters
706 ----------
706 ----------
707 items : sequence of strings
707 items : sequence of strings
708 The strings to process.
708 The strings to process.
709
709
710 separator : str, optional [default is two spaces]
710 separator : str, optional [default is two spaces]
711 The string that separates columns.
711 The string that separates columns.
712
712
713 displaywidth : int, optional [default is 80]
713 displaywidth : int, optional [default is 80]
714 Width of the display in number of characters.
714 Width of the display in number of characters.
715
715
716 Returns
716 Returns
717 -------
717 -------
718 The formatted string.
718 The formatted string.
719 """
719 """
720 if not items :
720 if not items :
721 return '\n'
721 return '\n'
722 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
722 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
723 fmatrix = [filter(None, x) for x in matrix]
723 fmatrix = [filter(None, x) for x in matrix]
724 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
724 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
725 return '\n'.join(map(sjoin, fmatrix))+'\n'
725 return '\n'.join(map(sjoin, fmatrix))+'\n'
726
726
727
727
728 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
728 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
729 """
729 """
730 Return a string with a natural enumeration of items
730 Return a string with a natural enumeration of items
731
731
732 >>> get_text_list(['a', 'b', 'c', 'd'])
732 >>> get_text_list(['a', 'b', 'c', 'd'])
733 'a, b, c and d'
733 'a, b, c and d'
734 >>> get_text_list(['a', 'b', 'c'], ' or ')
734 >>> get_text_list(['a', 'b', 'c'], ' or ')
735 'a, b or c'
735 'a, b or c'
736 >>> get_text_list(['a', 'b', 'c'], ', ')
736 >>> get_text_list(['a', 'b', 'c'], ', ')
737 'a, b, c'
737 'a, b, c'
738 >>> get_text_list(['a', 'b'], ' or ')
738 >>> get_text_list(['a', 'b'], ' or ')
739 'a or b'
739 'a or b'
740 >>> get_text_list(['a'])
740 >>> get_text_list(['a'])
741 'a'
741 'a'
742 >>> get_text_list([])
742 >>> get_text_list([])
743 ''
743 ''
744 >>> get_text_list(['a', 'b'], wrap_item_with="`")
744 >>> get_text_list(['a', 'b'], wrap_item_with="`")
745 '`a` and `b`'
745 '`a` and `b`'
746 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
746 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
747 'a + b + c = d'
747 'a + b + c = d'
748 """
748 """
749 if len(list_) == 0:
749 if len(list_) == 0:
750 return ''
750 return ''
751 if wrap_item_with:
751 if wrap_item_with:
752 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
752 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
753 item in list_]
753 item in list_]
754 if len(list_) == 1:
754 if len(list_) == 1:
755 return list_[0]
755 return list_[0]
756 return '%s%s%s' % (
756 return '%s%s%s' % (
757 sep.join(i for i in list_[:-1]),
757 sep.join(i for i in list_[:-1]),
758 last_sep, list_[-1]) No newline at end of file
758 last_sep, list_[-1])
General Comments 0
You need to be logged in to leave comments. Login now