##// END OF EJS Templates
Fix tests in utils
Thomas Kluyver -
Show More
@@ -1,764 +1,764 b''
1 1 """Nose Plugin that supports IPython doctests.
2 2
3 3 Limitations:
4 4
5 5 - When generating examples for use as doctests, make sure that you have
6 6 pretty-printing OFF. This can be done either by setting the
7 7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
8 8 by interactively disabling it with %Pprint. This is required so that IPython
9 9 output matches that of normal Python, which is used by doctest for internal
10 10 execution.
11 11
12 12 - Do not rely on specific prompt numbers for results (such as using
13 13 '_34==True', for example). For IPython tests run via an external process the
14 14 prompt numbers may be different, and IPython tests run as normal python code
15 15 won't even have these special _NN variables set at all.
16 16 """
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Module imports
20 20
21 21 # From the standard library
22 22 import doctest
23 23 import inspect
24 24 import logging
25 25 import os
26 26 import re
27 27 import sys
28 28 import traceback
29 29 import unittest
30 30
31 31 from inspect import getmodule
32 32
33 33 # We are overriding the default doctest runner, so we need to import a few
34 34 # things from doctest directly
35 35 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
36 36 _unittest_reportflags, DocTestRunner,
37 37 _extract_future_flags, pdb, _OutputRedirectingPdb,
38 38 _exception_traceback,
39 39 linecache)
40 40
41 41 # Third-party modules
42 42 import nose.core
43 43
44 44 from nose.plugins import doctests, Plugin
45 45 from nose.util import anyp, getpackage, test_address, resolve_name, tolist
46 46
47 47 # Our own imports
48 48 from IPython.utils.py3compat import builtin_mod, PY3
49 49
50 50 if PY3:
51 51 from io import StringIO
52 52 else:
53 53 from StringIO import StringIO
54 54
55 55 #-----------------------------------------------------------------------------
56 56 # Module globals and other constants
57 57 #-----------------------------------------------------------------------------
58 58
59 59 log = logging.getLogger(__name__)
60 60
61 61
62 62 #-----------------------------------------------------------------------------
63 63 # Classes and functions
64 64 #-----------------------------------------------------------------------------
65 65
66 66 def is_extension_module(filename):
67 67 """Return whether the given filename is an extension module.
68 68
69 69 This simply checks that the extension is either .so or .pyd.
70 70 """
71 71 return os.path.splitext(filename)[1].lower() in ('.so','.pyd')
72 72
73 73
74 74 class DocTestSkip(object):
75 75 """Object wrapper for doctests to be skipped."""
76 76
77 77 ds_skip = """Doctest to skip.
78 78 >>> 1 #doctest: +SKIP
79 79 """
80 80
81 81 def __init__(self,obj):
82 82 self.obj = obj
83 83
84 84 def __getattribute__(self,key):
85 85 if key == '__doc__':
86 86 return DocTestSkip.ds_skip
87 87 else:
88 88 return getattr(object.__getattribute__(self,'obj'),key)
89 89
90 90 # Modified version of the one in the stdlib, that fixes a python bug (doctests
91 91 # not found in extension modules, http://bugs.python.org/issue3158)
92 92 class DocTestFinder(doctest.DocTestFinder):
93 93
94 94 def _from_module(self, module, object):
95 95 """
96 96 Return true if the given object is defined in the given
97 97 module.
98 98 """
99 99 if module is None:
100 100 return True
101 101 elif inspect.isfunction(object):
102 102 return module.__dict__ is object.__globals__
103 103 elif inspect.isbuiltin(object):
104 104 return module.__name__ == object.__module__
105 105 elif inspect.isclass(object):
106 106 return module.__name__ == object.__module__
107 107 elif inspect.ismethod(object):
108 108 # This one may be a bug in cython that fails to correctly set the
109 109 # __module__ attribute of methods, but since the same error is easy
110 110 # to make by extension code writers, having this safety in place
111 111 # isn't such a bad idea
112 112 return module.__name__ == object.__self__.__class__.__module__
113 113 elif inspect.getmodule(object) is not None:
114 114 return module is inspect.getmodule(object)
115 115 elif hasattr(object, '__module__'):
116 116 return module.__name__ == object.__module__
117 117 elif isinstance(object, property):
118 118 return True # [XX] no way not be sure.
119 119 else:
120 raise ValueError("object must be a class or function")
120 raise ValueError("object must be a class or function, got %r" % object)
121 121
122 122 def _find(self, tests, obj, name, module, source_lines, globs, seen):
123 123 """
124 124 Find tests for the given object and any contained objects, and
125 125 add them to `tests`.
126 126 """
127 127 #print '_find for:', obj, name, module # dbg
128 128 if hasattr(obj,"skip_doctest"):
129 129 #print 'SKIPPING DOCTEST FOR:',obj # dbg
130 130 obj = DocTestSkip(obj)
131 131
132 132 doctest.DocTestFinder._find(self,tests, obj, name, module,
133 133 source_lines, globs, seen)
134 134
135 135 # Below we re-run pieces of the above method with manual modifications,
136 136 # because the original code is buggy and fails to correctly identify
137 137 # doctests in extension modules.
138 138
139 139 # Local shorthands
140 140 from inspect import isroutine, isclass, ismodule
141 141
142 142 # Look for tests in a module's contained objects.
143 143 if inspect.ismodule(obj) and self._recurse:
144 144 for valname, val in obj.__dict__.items():
145 145 valname1 = '%s.%s' % (name, valname)
146 146 if ( (isroutine(val) or isclass(val))
147 147 and self._from_module(module, val) ):
148 148
149 149 self._find(tests, val, valname1, module, source_lines,
150 150 globs, seen)
151 151
152 152 # Look for tests in a class's contained objects.
153 153 if inspect.isclass(obj) and self._recurse:
154 154 #print 'RECURSE into class:',obj # dbg
155 155 for valname, val in obj.__dict__.items():
156 156 # Special handling for staticmethod/classmethod.
157 157 if isinstance(val, staticmethod):
158 158 val = getattr(obj, valname)
159 159 if isinstance(val, classmethod):
160 160 val = getattr(obj, valname).__func__
161 161
162 162 # Recurse to methods, properties, and nested classes.
163 163 if ((inspect.isfunction(val) or inspect.isclass(val) or
164 164 inspect.ismethod(val) or
165 165 isinstance(val, property)) and
166 166 self._from_module(module, val)):
167 167 valname = '%s.%s' % (name, valname)
168 168 self._find(tests, val, valname, module, source_lines,
169 169 globs, seen)
170 170
171 171
172 172 class IPDoctestOutputChecker(doctest.OutputChecker):
173 173 """Second-chance checker with support for random tests.
174 174
175 175 If the default comparison doesn't pass, this checker looks in the expected
176 176 output string for flags that tell us to ignore the output.
177 177 """
178 178
179 179 random_re = re.compile(r'#\s*random\s+')
180 180
181 181 def check_output(self, want, got, optionflags):
182 182 """Check output, accepting special markers embedded in the output.
183 183
184 184 If the output didn't pass the default validation but the special string
185 185 '#random' is included, we accept it."""
186 186
187 187 # Let the original tester verify first, in case people have valid tests
188 188 # that happen to have a comment saying '#random' embedded in.
189 189 ret = doctest.OutputChecker.check_output(self, want, got,
190 190 optionflags)
191 191 if not ret and self.random_re.search(want):
192 192 #print >> sys.stderr, 'RANDOM OK:',want # dbg
193 193 return True
194 194
195 195 return ret
196 196
197 197
198 198 class DocTestCase(doctests.DocTestCase):
199 199 """Proxy for DocTestCase: provides an address() method that
200 200 returns the correct address for the doctest case. Otherwise
201 201 acts as a proxy to the test case. To provide hints for address(),
202 202 an obj may also be passed -- this will be used as the test object
203 203 for purposes of determining the test address, if it is provided.
204 204 """
205 205
206 206 # Note: this method was taken from numpy's nosetester module.
207 207
208 208 # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
209 209 # its constructor that blocks non-default arguments from being passed
210 210 # down into doctest.DocTestCase
211 211
212 212 def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
213 213 checker=None, obj=None, result_var='_'):
214 214 self._result_var = result_var
215 215 doctests.DocTestCase.__init__(self, test,
216 216 optionflags=optionflags,
217 217 setUp=setUp, tearDown=tearDown,
218 218 checker=checker)
219 219 # Now we must actually copy the original constructor from the stdlib
220 220 # doctest class, because we can't call it directly and a bug in nose
221 221 # means it never gets passed the right arguments.
222 222
223 223 self._dt_optionflags = optionflags
224 224 self._dt_checker = checker
225 225 self._dt_test = test
226 226 self._dt_test_globs_ori = test.globs
227 227 self._dt_setUp = setUp
228 228 self._dt_tearDown = tearDown
229 229
230 230 # XXX - store this runner once in the object!
231 231 runner = IPDocTestRunner(optionflags=optionflags,
232 232 checker=checker, verbose=False)
233 233 self._dt_runner = runner
234 234
235 235
236 236 # Each doctest should remember the directory it was loaded from, so
237 237 # things like %run work without too many contortions
238 238 self._ori_dir = os.path.dirname(test.filename)
239 239
240 240 # Modified runTest from the default stdlib
241 241 def runTest(self):
242 242 test = self._dt_test
243 243 runner = self._dt_runner
244 244
245 245 old = sys.stdout
246 246 new = StringIO()
247 247 optionflags = self._dt_optionflags
248 248
249 249 if not (optionflags & REPORTING_FLAGS):
250 250 # The option flags don't include any reporting flags,
251 251 # so add the default reporting flags
252 252 optionflags |= _unittest_reportflags
253 253
254 254 try:
255 255 # Save our current directory and switch out to the one where the
256 256 # test was originally created, in case another doctest did a
257 257 # directory change. We'll restore this in the finally clause.
258 258 curdir = os.getcwdu()
259 259 #print 'runTest in dir:', self._ori_dir # dbg
260 260 os.chdir(self._ori_dir)
261 261
262 262 runner.DIVIDER = "-"*70
263 263 failures, tries = runner.run(test,out=new.write,
264 264 clear_globs=False)
265 265 finally:
266 266 sys.stdout = old
267 267 os.chdir(curdir)
268 268
269 269 if failures:
270 270 raise self.failureException(self.format_failure(new.getvalue()))
271 271
272 272 def setUp(self):
273 273 """Modified test setup that syncs with ipython namespace"""
274 274 #print "setUp test", self._dt_test.examples # dbg
275 275 if isinstance(self._dt_test.examples[0], IPExample):
276 276 # for IPython examples *only*, we swap the globals with the ipython
277 277 # namespace, after updating it with the globals (which doctest
278 278 # fills with the necessary info from the module being tested).
279 279 self.user_ns_orig = {}
280 280 self.user_ns_orig.update(_ip.user_ns)
281 281 _ip.user_ns.update(self._dt_test.globs)
282 282 # We must remove the _ key in the namespace, so that Python's
283 283 # doctest code sets it naturally
284 284 _ip.user_ns.pop('_', None)
285 285 _ip.user_ns['__builtins__'] = builtin_mod
286 286 self._dt_test.globs = _ip.user_ns
287 287
288 288 super(DocTestCase, self).setUp()
289 289
290 290 def tearDown(self):
291 291
292 292 # Undo the test.globs reassignment we made, so that the parent class
293 293 # teardown doesn't destroy the ipython namespace
294 294 if isinstance(self._dt_test.examples[0], IPExample):
295 295 self._dt_test.globs = self._dt_test_globs_ori
296 296 _ip.user_ns.clear()
297 297 _ip.user_ns.update(self.user_ns_orig)
298 298
299 299 # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
300 300 # it does look like one to me: its tearDown method tries to run
301 301 #
302 302 # delattr(builtin_mod, self._result_var)
303 303 #
304 304 # without checking that the attribute really is there; it implicitly
305 305 # assumes it should have been set via displayhook. But if the
306 306 # displayhook was never called, this doesn't necessarily happen. I
307 307 # haven't been able to find a little self-contained example outside of
308 308 # ipython that would show the problem so I can report it to the nose
309 309 # team, but it does happen a lot in our code.
310 310 #
311 311 # So here, we just protect as narrowly as possible by trapping an
312 312 # attribute error whose message would be the name of self._result_var,
313 313 # and letting any other error propagate.
314 314 try:
315 315 super(DocTestCase, self).tearDown()
316 316 except AttributeError as exc:
317 317 if exc.args[0] != self._result_var:
318 318 raise
319 319
320 320
321 321 # A simple subclassing of the original with a different class name, so we can
322 322 # distinguish and treat differently IPython examples from pure python ones.
323 323 class IPExample(doctest.Example): pass
324 324
325 325
326 326 class IPExternalExample(doctest.Example):
327 327 """Doctest examples to be run in an external process."""
328 328
329 329 def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
330 330 options=None):
331 331 # Parent constructor
332 332 doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options)
333 333
334 334 # An EXTRA newline is needed to prevent pexpect hangs
335 335 self.source += '\n'
336 336
337 337
338 338 class IPDocTestParser(doctest.DocTestParser):
339 339 """
340 340 A class used to parse strings containing doctest examples.
341 341
342 342 Note: This is a version modified to properly recognize IPython input and
343 343 convert any IPython examples into valid Python ones.
344 344 """
345 345 # This regular expression is used to find doctest examples in a
346 346 # string. It defines three groups: `source` is the source code
347 347 # (including leading indentation and prompts); `indent` is the
348 348 # indentation of the first (PS1) line of the source code; and
349 349 # `want` is the expected output (including leading indentation).
350 350
351 351 # Classic Python prompts or default IPython ones
352 352 _PS1_PY = r'>>>'
353 353 _PS2_PY = r'\.\.\.'
354 354
355 355 _PS1_IP = r'In\ \[\d+\]:'
356 356 _PS2_IP = r'\ \ \ \.\.\.+:'
357 357
358 358 _RE_TPL = r'''
359 359 # Source consists of a PS1 line followed by zero or more PS2 lines.
360 360 (?P<source>
361 361 (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*) # PS1 line
362 362 (?:\n [ ]* (?P<ps2> %s) .*)*) # PS2 lines
363 363 \n? # a newline
364 364 # Want consists of any non-blank lines that do not start with PS1.
365 365 (?P<want> (?:(?![ ]*$) # Not a blank line
366 366 (?![ ]*%s) # Not a line starting with PS1
367 367 (?![ ]*%s) # Not a line starting with PS2
368 368 .*$\n? # But any other line
369 369 )*)
370 370 '''
371 371
372 372 _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
373 373 re.MULTILINE | re.VERBOSE)
374 374
375 375 _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
376 376 re.MULTILINE | re.VERBOSE)
377 377
378 378 # Mark a test as being fully random. In this case, we simply append the
379 379 # random marker ('#random') to each individual example's output. This way
380 380 # we don't need to modify any other code.
381 381 _RANDOM_TEST = re.compile(r'#\s*all-random\s+')
382 382
383 383 # Mark tests to be executed in an external process - currently unsupported.
384 384 _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')
385 385
386 386 def ip2py(self,source):
387 387 """Convert input IPython source into valid Python."""
388 388 block = _ip.input_transformer_manager.transform_cell(source)
389 389 if len(block.splitlines()) == 1:
390 390 return _ip.prefilter(block)
391 391 else:
392 392 return block
393 393
394 394 def parse(self, string, name='<string>'):
395 395 """
396 396 Divide the given string into examples and intervening text,
397 397 and return them as a list of alternating Examples and strings.
398 398 Line numbers for the Examples are 0-based. The optional
399 399 argument `name` is a name identifying this string, and is only
400 400 used for error messages.
401 401 """
402 402
403 403 #print 'Parse string:\n',string # dbg
404 404
405 405 string = string.expandtabs()
406 406 # If all lines begin with the same indentation, then strip it.
407 407 min_indent = self._min_indent(string)
408 408 if min_indent > 0:
409 409 string = '\n'.join([l[min_indent:] for l in string.split('\n')])
410 410
411 411 output = []
412 412 charno, lineno = 0, 0
413 413
414 414 # We make 'all random' tests by adding the '# random' mark to every
415 415 # block of output in the test.
416 416 if self._RANDOM_TEST.search(string):
417 417 random_marker = '\n# random'
418 418 else:
419 419 random_marker = ''
420 420
421 421 # Whether to convert the input from ipython to python syntax
422 422 ip2py = False
423 423 # Find all doctest examples in the string. First, try them as Python
424 424 # examples, then as IPython ones
425 425 terms = list(self._EXAMPLE_RE_PY.finditer(string))
426 426 if terms:
427 427 # Normal Python example
428 428 #print '-'*70 # dbg
429 429 #print 'PyExample, Source:\n',string # dbg
430 430 #print '-'*70 # dbg
431 431 Example = doctest.Example
432 432 else:
433 433 # It's an ipython example. Note that IPExamples are run
434 434 # in-process, so their syntax must be turned into valid python.
435 435 # IPExternalExamples are run out-of-process (via pexpect) so they
436 436 # don't need any filtering (a real ipython will be executing them).
437 437 terms = list(self._EXAMPLE_RE_IP.finditer(string))
438 438 if self._EXTERNAL_IP.search(string):
439 439 #print '-'*70 # dbg
440 440 #print 'IPExternalExample, Source:\n',string # dbg
441 441 #print '-'*70 # dbg
442 442 Example = IPExternalExample
443 443 else:
444 444 #print '-'*70 # dbg
445 445 #print 'IPExample, Source:\n',string # dbg
446 446 #print '-'*70 # dbg
447 447 Example = IPExample
448 448 ip2py = True
449 449
450 450 for m in terms:
451 451 # Add the pre-example text to `output`.
452 452 output.append(string[charno:m.start()])
453 453 # Update lineno (lines before this example)
454 454 lineno += string.count('\n', charno, m.start())
455 455 # Extract info from the regexp match.
456 456 (source, options, want, exc_msg) = \
457 457 self._parse_example(m, name, lineno,ip2py)
458 458
459 459 # Append the random-output marker (it defaults to empty in most
460 460 # cases, it's only non-empty for 'all-random' tests):
461 461 want += random_marker
462 462
463 463 if Example is IPExternalExample:
464 464 options[doctest.NORMALIZE_WHITESPACE] = True
465 465 want += '\n'
466 466
467 467 # Create an Example, and add it to the list.
468 468 if not self._IS_BLANK_OR_COMMENT(source):
469 469 output.append(Example(source, want, exc_msg,
470 470 lineno=lineno,
471 471 indent=min_indent+len(m.group('indent')),
472 472 options=options))
473 473 # Update lineno (lines inside this example)
474 474 lineno += string.count('\n', m.start(), m.end())
475 475 # Update charno.
476 476 charno = m.end()
477 477 # Add any remaining post-example text to `output`.
478 478 output.append(string[charno:])
479 479 return output
480 480
481 481 def _parse_example(self, m, name, lineno,ip2py=False):
482 482 """
483 483 Given a regular expression match from `_EXAMPLE_RE` (`m`),
484 484 return a pair `(source, want)`, where `source` is the matched
485 485 example's source code (with prompts and indentation stripped);
486 486 and `want` is the example's expected output (with indentation
487 487 stripped).
488 488
489 489 `name` is the string's name, and `lineno` is the line number
490 490 where the example starts; both are used for error messages.
491 491
492 492 Optional:
493 493 `ip2py`: if true, filter the input via IPython to convert the syntax
494 494 into valid python.
495 495 """
496 496
497 497 # Get the example's indentation level.
498 498 indent = len(m.group('indent'))
499 499
500 500 # Divide source into lines; check that they're properly
501 501 # indented; and then strip their indentation & prompts.
502 502 source_lines = m.group('source').split('\n')
503 503
504 504 # We're using variable-length input prompts
505 505 ps1 = m.group('ps1')
506 506 ps2 = m.group('ps2')
507 507 ps1_len = len(ps1)
508 508
509 509 self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
510 510 if ps2:
511 511 self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)
512 512
513 513 source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])
514 514
515 515 if ip2py:
516 516 # Convert source input from IPython into valid Python syntax
517 517 source = self.ip2py(source)
518 518
519 519 # Divide want into lines; check that it's properly indented; and
520 520 # then strip the indentation. Spaces before the last newline should
521 521 # be preserved, so plain rstrip() isn't good enough.
522 522 want = m.group('want')
523 523 want_lines = want.split('\n')
524 524 if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
525 525 del want_lines[-1] # forget final newline & spaces after it
526 526 self._check_prefix(want_lines, ' '*indent, name,
527 527 lineno + len(source_lines))
528 528
529 529 # Remove ipython output prompt that might be present in the first line
530 530 want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])
531 531
532 532 want = '\n'.join([wl[indent:] for wl in want_lines])
533 533
534 534 # If `want` contains a traceback message, then extract it.
535 535 m = self._EXCEPTION_RE.match(want)
536 536 if m:
537 537 exc_msg = m.group('msg')
538 538 else:
539 539 exc_msg = None
540 540
541 541 # Extract options from the source.
542 542 options = self._find_options(source, name, lineno)
543 543
544 544 return source, options, want, exc_msg
545 545
546 546 def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
547 547 """
548 548 Given the lines of a source string (including prompts and
549 549 leading indentation), check to make sure that every prompt is
550 550 followed by a space character. If any line is not followed by
551 551 a space character, then raise ValueError.
552 552
553 553 Note: IPython-modified version which takes the input prompt length as a
554 554 parameter, so that prompts of variable length can be dealt with.
555 555 """
556 556 space_idx = indent+ps1_len
557 557 min_len = space_idx+1
558 558 for i, line in enumerate(lines):
559 559 if len(line) >= min_len and line[space_idx] != ' ':
560 560 raise ValueError('line %r of the docstring for %s '
561 561 'lacks blank after %s: %r' %
562 562 (lineno+i+1, name,
563 563 line[indent:space_idx], line))
564 564
565 565
566 566 SKIP = doctest.register_optionflag('SKIP')
567 567
568 568
569 569 class IPDocTestRunner(doctest.DocTestRunner,object):
570 570 """Test runner that synchronizes the IPython namespace with test globals.
571 571 """
572 572
573 573 def run(self, test, compileflags=None, out=None, clear_globs=True):
574 574
575 575 # Hack: ipython needs access to the execution context of the example,
576 576 # so that it can propagate user variables loaded by %run into
577 577 # test.globs. We put them here into our modified %run as a function
578 578 # attribute. Our new %run will then only make the namespace update
579 579 # when called (rather than unconconditionally updating test.globs here
580 580 # for all examples, most of which won't be calling %run anyway).
581 581 #_ip._ipdoctest_test_globs = test.globs
582 582 #_ip._ipdoctest_test_filename = test.filename
583 583
584 584 test.globs.update(_ip.user_ns)
585 585
586 586 return super(IPDocTestRunner,self).run(test,
587 587 compileflags,out,clear_globs)
588 588
589 589
590 590 class DocFileCase(doctest.DocFileCase):
591 591 """Overrides to provide filename
592 592 """
593 593 def address(self):
594 594 return (self._dt_test.filename, None, None)
595 595
596 596
597 597 class ExtensionDoctest(doctests.Doctest):
598 598 """Nose Plugin that supports doctests in extension modules.
599 599 """
600 600 name = 'extdoctest' # call nosetests with --with-extdoctest
601 601 enabled = True
602 602
603 603 def options(self, parser, env=os.environ):
604 604 Plugin.options(self, parser, env)
605 605 parser.add_option('--doctest-tests', action='store_true',
606 606 dest='doctest_tests',
607 607 default=env.get('NOSE_DOCTEST_TESTS',True),
608 608 help="Also look for doctests in test modules. "
609 609 "Note that classes, methods and functions should "
610 610 "have either doctests or non-doctest tests, "
611 611 "not both. [NOSE_DOCTEST_TESTS]")
612 612 parser.add_option('--doctest-extension', action="append",
613 613 dest="doctestExtension",
614 614 help="Also look for doctests in files with "
615 615 "this extension [NOSE_DOCTEST_EXTENSION]")
616 616 # Set the default as a list, if given in env; otherwise
617 617 # an additional value set on the command line will cause
618 618 # an error.
619 619 env_setting = env.get('NOSE_DOCTEST_EXTENSION')
620 620 if env_setting is not None:
621 621 parser.set_defaults(doctestExtension=tolist(env_setting))
622 622
623 623
624 624 def configure(self, options, config):
625 625 Plugin.configure(self, options, config)
626 626 # Pull standard doctest plugin out of config; we will do doctesting
627 627 config.plugins.plugins = [p for p in config.plugins.plugins
628 628 if p.name != 'doctest']
629 629 self.doctest_tests = options.doctest_tests
630 630 self.extension = tolist(options.doctestExtension)
631 631
632 632 self.parser = doctest.DocTestParser()
633 633 self.finder = DocTestFinder()
634 634 self.checker = IPDoctestOutputChecker()
635 635 self.globs = None
636 636 self.extraglobs = None
637 637
638 638
639 639 def loadTestsFromExtensionModule(self,filename):
640 640 bpath,mod = os.path.split(filename)
641 641 modname = os.path.splitext(mod)[0]
642 642 try:
643 643 sys.path.append(bpath)
644 644 module = __import__(modname)
645 645 tests = list(self.loadTestsFromModule(module))
646 646 finally:
647 647 sys.path.pop()
648 648 return tests
649 649
650 650 # NOTE: the method below is almost a copy of the original one in nose, with
651 651 # a few modifications to control output checking.
652 652
653 653 def loadTestsFromModule(self, module):
654 654 #print '*** ipdoctest - lTM',module # dbg
655 655
656 656 if not self.matches(module.__name__):
657 657 log.debug("Doctest doesn't want module %s", module)
658 658 return
659 659
660 660 tests = self.finder.find(module,globs=self.globs,
661 661 extraglobs=self.extraglobs)
662 662 if not tests:
663 663 return
664 664
665 665 # always use whitespace and ellipsis options
666 666 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
667 667
668 668 tests.sort()
669 669 module_file = module.__file__
670 670 if module_file[-4:] in ('.pyc', '.pyo'):
671 671 module_file = module_file[:-1]
672 672 for test in tests:
673 673 if not test.examples:
674 674 continue
675 675 if not test.filename:
676 676 test.filename = module_file
677 677
678 678 yield DocTestCase(test,
679 679 optionflags=optionflags,
680 680 checker=self.checker)
681 681
682 682
683 683 def loadTestsFromFile(self, filename):
684 684 #print "ipdoctest - from file", filename # dbg
685 685 if is_extension_module(filename):
686 686 for t in self.loadTestsFromExtensionModule(filename):
687 687 yield t
688 688 else:
689 689 if self.extension and anyp(filename.endswith, self.extension):
690 690 name = os.path.basename(filename)
691 691 dh = open(filename)
692 692 try:
693 693 doc = dh.read()
694 694 finally:
695 695 dh.close()
696 696 test = self.parser.get_doctest(
697 697 doc, globs={'__file__': filename}, name=name,
698 698 filename=filename, lineno=0)
699 699 if test.examples:
700 700 #print 'FileCase:',test.examples # dbg
701 701 yield DocFileCase(test)
702 702 else:
703 703 yield False # no tests to load
704 704
705 705
706 706 class IPythonDoctest(ExtensionDoctest):
707 707 """Nose Plugin that supports doctests in extension modules.
708 708 """
709 709 name = 'ipdoctest' # call nosetests with --with-ipdoctest
710 710 enabled = True
711 711
712 712 def makeTest(self, obj, parent):
713 713 """Look for doctests in the given object, which will be a
714 714 function, method or class.
715 715 """
716 716 #print 'Plugin analyzing:', obj, parent # dbg
717 717 # always use whitespace and ellipsis options
718 718 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
719 719
720 720 doctests = self.finder.find(obj, module=getmodule(parent))
721 721 if doctests:
722 722 for test in doctests:
723 723 if len(test.examples) == 0:
724 724 continue
725 725
726 726 yield DocTestCase(test, obj=obj,
727 727 optionflags=optionflags,
728 728 checker=self.checker)
729 729
730 730 def options(self, parser, env=os.environ):
731 731 #print "Options for nose plugin:", self.name # dbg
732 732 Plugin.options(self, parser, env)
733 733 parser.add_option('--ipdoctest-tests', action='store_true',
734 734 dest='ipdoctest_tests',
735 735 default=env.get('NOSE_IPDOCTEST_TESTS',True),
736 736 help="Also look for doctests in test modules. "
737 737 "Note that classes, methods and functions should "
738 738 "have either doctests or non-doctest tests, "
739 739 "not both. [NOSE_IPDOCTEST_TESTS]")
740 740 parser.add_option('--ipdoctest-extension', action="append",
741 741 dest="ipdoctest_extension",
742 742 help="Also look for doctests in files with "
743 743 "this extension [NOSE_IPDOCTEST_EXTENSION]")
744 744 # Set the default as a list, if given in env; otherwise
745 745 # an additional value set on the command line will cause
746 746 # an error.
747 747 env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
748 748 if env_setting is not None:
749 749 parser.set_defaults(ipdoctest_extension=tolist(env_setting))
750 750
751 751 def configure(self, options, config):
752 752 #print "Configuring nose plugin:", self.name # dbg
753 753 Plugin.configure(self, options, config)
754 754 # Pull standard doctest plugin out of config; we will do doctesting
755 755 config.plugins.plugins = [p for p in config.plugins.plugins
756 756 if p.name != 'doctest']
757 757 self.doctest_tests = options.ipdoctest_tests
758 758 self.extension = tolist(options.ipdoctest_extension)
759 759
760 760 self.parser = IPDocTestParser()
761 761 self.finder = DocTestFinder(parser=self.parser)
762 762 self.checker = IPDoctestOutputChecker()
763 763 self.globs = None
764 764 self.extraglobs = None
@@ -1,95 +1,95 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with stack frames.
4 4 """
5 5 from __future__ import print_function
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import sys
19 19 from IPython.utils import py3compat
20 20
21 21 #-----------------------------------------------------------------------------
22 22 # Code
23 23 #-----------------------------------------------------------------------------
24 24
25 25 @py3compat.doctest_refactor_print
26 26 def extract_vars(*names,**kw):
27 27 """Extract a set of variables by name from another frame.
28 28
29 29 :Parameters:
30 30 - `*names`: strings
31 31 One or more variable names which will be extracted from the caller's
32 32 frame.
33 33
34 34 :Keywords:
35 35 - `depth`: integer (0)
36 36 How many frames in the stack to walk when looking for your variables.
37 37
38 38
39 39 Examples:
40 40
41 41 In [2]: def func(x):
42 42 ...: y = 1
43 ...: print sorted(extract_vars('x','y').items())
43 ...: print(sorted(extract_vars('x','y').items()))
44 44 ...:
45 45
46 46 In [3]: func('hello')
47 47 [('x', 'hello'), ('y', 1)]
48 48 """
49 49
50 50 depth = kw.get('depth',0)
51 51
52 52 callerNS = sys._getframe(depth+1).f_locals
53 53 return dict((k,callerNS[k]) for k in names)
54 54
55 55
56 56 def extract_vars_above(*names):
57 57 """Extract a set of variables by name from another frame.
58 58
59 59 Similar to extractVars(), but with a specified depth of 1, so that names
60 60 are exctracted exactly from above the caller.
61 61
62 62 This is simply a convenience function so that the very common case (for us)
63 63 of skipping exactly 1 frame doesn't have to construct a special dict for
64 64 keyword passing."""
65 65
66 66 callerNS = sys._getframe(2).f_locals
67 67 return dict((k,callerNS[k]) for k in names)
68 68
69 69
70 70 def debugx(expr,pre_msg=''):
71 71 """Print the value of an expression from the caller's frame.
72 72
73 73 Takes an expression, evaluates it in the caller's frame and prints both
74 74 the given expression and the resulting value (as well as a debug mark
75 75 indicating the name of the calling function. The input must be of a form
76 76 suitable for eval().
77 77
78 78 An optional message can be passed, which will be prepended to the printed
79 79 expr->value pair."""
80 80
81 81 cf = sys._getframe(1)
82 82 print('[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
83 83 eval(expr,cf.f_globals,cf.f_locals)))
84 84
85 85
86 86 # deactivate it by uncommenting the following line, which makes it a no-op
87 87 #def debugx(expr,pre_msg=''): pass
88 88
89 89 def extract_module_locals(depth=0):
90 90 """Returns (module, locals) of the funciton `depth` frames away from the caller"""
91 91 f = sys._getframe(depth + 1)
92 92 global_ns = f.f_globals
93 93 module = sys.modules[global_ns['__name__']]
94 94 return (module, f.f_locals)
95 95
@@ -1,229 +1,229 b''
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010-2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import math
15 15 import re
16 16 import types
17 17 from datetime import datetime
18 18
19 19 try:
20 20 # base64.encodestring is deprecated in Python 3.x
21 21 from base64 import encodebytes
22 22 except ImportError:
23 23 # Python 2.x
24 24 from base64 import encodestring as encodebytes
25 25
26 26 from IPython.utils import py3compat
27 27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
28 28 from IPython.utils.encoding import DEFAULT_ENCODING
29 29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Globals and constants
33 33 #-----------------------------------------------------------------------------
34 34
35 35 # timestamp formats
36 36 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
37 37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+)Z?([\+\-]\d{2}:?\d{2})?$")
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Classes and functions
41 41 #-----------------------------------------------------------------------------
42 42
43 43 def rekey(dikt):
44 44 """Rekey a dict that has been forced to use str keys where there should be
45 45 ints by json."""
46 46 for k in dikt:
47 47 if isinstance(k, string_types):
48 48 ik=fk=None
49 49 try:
50 50 ik = int(k)
51 51 except ValueError:
52 52 try:
53 53 fk = float(k)
54 54 except ValueError:
55 55 continue
56 56 if ik is not None:
57 57 nk = ik
58 58 else:
59 59 nk = fk
60 60 if nk in dikt:
61 61 raise KeyError("already have key %r"%nk)
62 62 dikt[nk] = dikt.pop(k)
63 63 return dikt
64 64
65 65
66 66 def extract_dates(obj):
67 67 """extract ISO8601 dates from unpacked JSON"""
68 68 if isinstance(obj, dict):
69 69 obj = dict(obj) # don't clobber
70 70 for k,v in iteritems(obj):
71 71 obj[k] = extract_dates(v)
72 72 elif isinstance(obj, (list, tuple)):
73 73 obj = [ extract_dates(o) for o in obj ]
74 74 elif isinstance(obj, string_types):
75 75 m = ISO8601_PAT.match(obj)
76 76 if m:
77 77 # FIXME: add actual timezone support
78 78 # this just drops the timezone info
79 79 notz = m.groups()[0]
80 80 obj = datetime.strptime(notz, ISO8601)
81 81 return obj
82 82
83 83 def squash_dates(obj):
84 84 """squash datetime objects into ISO8601 strings"""
85 85 if isinstance(obj, dict):
86 86 obj = dict(obj) # don't clobber
87 87 for k,v in iteritems(obj):
88 88 obj[k] = squash_dates(v)
89 89 elif isinstance(obj, (list, tuple)):
90 90 obj = [ squash_dates(o) for o in obj ]
91 91 elif isinstance(obj, datetime):
92 92 obj = obj.isoformat()
93 93 return obj
94 94
95 95 def date_default(obj):
96 96 """default function for packing datetime objects in JSON."""
97 97 if isinstance(obj, datetime):
98 98 return obj.isoformat()
99 99 else:
100 100 raise TypeError("%r is not JSON serializable"%obj)
101 101
102 102
103 103 # constants for identifying png/jpeg data
104 104 PNG = b'\x89PNG\r\n\x1a\n'
105 105 # front of PNG base64-encoded
106 106 PNG64 = b'iVBORw0KG'
107 107 JPEG = b'\xff\xd8'
108 108 # front of JPEG base64-encoded
109 109 JPEG64 = b'/9'
110 110
111 111 def encode_images(format_dict):
112 112 """b64-encodes images in a displaypub format dict
113 113
114 114 Perhaps this should be handled in json_clean itself?
115 115
116 116 Parameters
117 117 ----------
118 118
119 119 format_dict : dict
120 120 A dictionary of display data keyed by mime-type
121 121
122 122 Returns
123 123 -------
124 124
125 125 format_dict : dict
126 126 A copy of the same dictionary,
127 127 but binary image data ('image/png' or 'image/jpeg')
128 128 is base64-encoded.
129 129
130 130 """
131 131 encoded = format_dict.copy()
132 132
133 133 pngdata = format_dict.get('image/png')
134 134 if isinstance(pngdata, bytes):
135 135 # make sure we don't double-encode
136 136 if not pngdata.startswith(PNG64):
137 137 pngdata = encodebytes(pngdata)
138 138 encoded['image/png'] = pngdata.decode('ascii')
139 139
140 140 jpegdata = format_dict.get('image/jpeg')
141 141 if isinstance(jpegdata, bytes):
142 142 # make sure we don't double-encode
143 143 if not jpegdata.startswith(JPEG64):
144 144 jpegdata = encodebytes(jpegdata)
145 145 encoded['image/jpeg'] = jpegdata.decode('ascii')
146 146
147 147 return encoded
148 148
149 149
150 150 def json_clean(obj):
151 151 """Clean an object to ensure it's safe to encode in JSON.
152 152
153 153 Atomic, immutable objects are returned unmodified. Sets and tuples are
154 154 converted to lists, lists are copied and dicts are also copied.
155 155
156 156 Note: dicts whose keys could cause collisions upon encoding (such as a dict
157 157 with both the number 1 and the string '1' as keys) will cause a ValueError
158 158 to be raised.
159 159
160 160 Parameters
161 161 ----------
162 162 obj : any python object
163 163
164 164 Returns
165 165 -------
166 166 out : object
167 167
168 168 A version of the input which will not cause an encoding error when
169 169 encoded as JSON. Note that this function does not *encode* its inputs,
170 170 it simply sanitizes it so that there will be no encoding errors later.
171 171
172 172 Examples
173 173 --------
174 174 >>> json_clean(4)
175 175 4
176 >>> json_clean(range(10))
176 >>> json_clean(list(range(10)))
177 177 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
178 178 >>> sorted(json_clean(dict(x=1, y=2)).items())
179 179 [('x', 1), ('y', 2)]
180 180 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
181 181 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
182 182 >>> json_clean(True)
183 183 True
184 184 """
185 185 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
186 186 # listed explicitly because bools pass as int instances
187 187 atomic_ok = (unicode_type, int, type(None))
188 188
189 189 # containers that we need to convert into lists
190 190 container_to_list = (tuple, set, types.GeneratorType)
191 191
192 192 if isinstance(obj, float):
193 193 # cast out-of-range floats to their reprs
194 194 if math.isnan(obj) or math.isinf(obj):
195 195 return repr(obj)
196 196 return obj
197 197
198 198 if isinstance(obj, atomic_ok):
199 199 return obj
200 200
201 201 if isinstance(obj, bytes):
202 202 return obj.decode(DEFAULT_ENCODING, 'replace')
203 203
204 204 if isinstance(obj, container_to_list) or (
205 205 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
206 206 obj = list(obj)
207 207
208 208 if isinstance(obj, list):
209 209 return [json_clean(x) for x in obj]
210 210
211 211 if isinstance(obj, dict):
212 212 # First, validate that the dict won't lose data in conversion due to
213 213 # key collisions after stringification. This can happen with keys like
214 214 # True and 'true' or 1 and '1', which collide in JSON.
215 215 nkeys = len(obj)
216 216 nkeys_collapsed = len(set(map(str, obj)))
217 217 if nkeys != nkeys_collapsed:
218 218 raise ValueError('dict can not be safely converted to JSON: '
219 219 'key collision would lead to dropped values')
220 220 # If all OK, proceed by making the new dict that will be json-safe
221 221 out = {}
222 222 for k,v in iteritems(obj):
223 223 out[str(k)] = json_clean(v)
224 224 return out
225 225
226 226 # If we get here, we don't know how to handle the object, so we just get
227 227 # its repr and return that. This will catch lambdas, open sockets, class
228 228 # objects, and any other complicated contraption that json can't encode
229 229 return repr(obj)
@@ -1,239 +1,239 b''
1 1 # coding: utf-8
2 2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
3 3 import functools
4 4 import sys
5 5 import re
6 6 import types
7 7
8 8 from .encoding import DEFAULT_ENCODING
9 9
10 10 orig_open = open
11 11
12 12 def no_code(x, encoding=None):
13 13 return x
14 14
15 15 def decode(s, encoding=None):
16 16 encoding = encoding or DEFAULT_ENCODING
17 17 return s.decode(encoding, "replace")
18 18
19 19 def encode(u, encoding=None):
20 20 encoding = encoding or DEFAULT_ENCODING
21 21 return u.encode(encoding, "replace")
22 22
23 23
24 24 def cast_unicode(s, encoding=None):
25 25 if isinstance(s, bytes):
26 26 return decode(s, encoding)
27 27 return s
28 28
29 29 def cast_bytes(s, encoding=None):
30 30 if not isinstance(s, bytes):
31 31 return encode(s, encoding)
32 32 return s
33 33
34 34 def _modify_str_or_docstring(str_change_func):
35 35 @functools.wraps(str_change_func)
36 36 def wrapper(func_or_str):
37 37 if isinstance(func_or_str, string_types):
38 38 func = None
39 39 doc = func_or_str
40 40 else:
41 41 func = func_or_str
42 42 doc = func.__doc__
43 43
44 44 doc = str_change_func(doc)
45 45
46 46 if func:
47 47 func.__doc__ = doc
48 48 return func
49 49 return doc
50 50 return wrapper
51 51
52 52 def safe_unicode(e):
53 53 """unicode(e) with various fallbacks. Used for exceptions, which may not be
54 54 safe to call unicode() on.
55 55 """
56 56 try:
57 57 return unicode_type(e)
58 58 except UnicodeError:
59 59 pass
60 60
61 61 try:
62 62 return str_to_unicode(str(e))
63 63 except UnicodeError:
64 64 pass
65 65
66 66 try:
67 67 return str_to_unicode(repr(e))
68 68 except UnicodeError:
69 69 pass
70 70
71 71 return u'Unrecoverably corrupt evalue'
72 72
73 73 if sys.version_info[0] >= 3:
74 74 PY3 = True
75 75
76 76 input = input
77 77 builtin_mod_name = "builtins"
78 78 import builtins as builtin_mod
79 79
80 80 str_to_unicode = no_code
81 81 unicode_to_str = no_code
82 82 str_to_bytes = encode
83 83 bytes_to_str = decode
84 84 cast_bytes_py2 = no_code
85 85
86 86 string_types = (str,)
87 87 unicode_type = str
88 88
89 89 def isidentifier(s, dotted=False):
90 90 if dotted:
91 91 return all(isidentifier(a) for a in s.split("."))
92 92 return s.isidentifier()
93 93
94 94 open = orig_open
95 95 xrange = range
96 iteritems = dict.items
97 itervalues = dict.values
96 def iteritems(d): return iter(d.items())
97 def itervalues(d): return iter(d.values())
98 98
99 99 MethodType = types.MethodType
100 100
101 101 def execfile(fname, glob, loc=None):
102 102 loc = loc if (loc is not None) else glob
103 103 with open(fname, 'rb') as f:
104 104 exec(compile(f.read(), fname, 'exec'), glob, loc)
105 105
106 106 # Refactor print statements in doctests.
107 107 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
108 108 def _print_statement_sub(match):
109 109 expr = match.groups('expr')
110 110 return "print(%s)" % expr
111 111
112 112 @_modify_str_or_docstring
113 113 def doctest_refactor_print(doc):
114 114 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
115 115 unfortunately doesn't pick up on our doctests.
116 116
117 117 Can accept a string or a function, so it can be used as a decorator."""
118 118 return _print_statement_re.sub(_print_statement_sub, doc)
119 119
120 120 # Abstract u'abc' syntax:
121 121 @_modify_str_or_docstring
122 122 def u_format(s):
123 123 """"{u}'abc'" --> "'abc'" (Python 3)
124 124
125 125 Accepts a string or a function, so it can be used as a decorator."""
126 126 return s.format(u='')
127 127
128 128 else:
129 129 PY3 = False
130 130
131 131 input = raw_input
132 132 builtin_mod_name = "__builtin__"
133 133 import __builtin__ as builtin_mod
134 134
135 135 str_to_unicode = decode
136 136 unicode_to_str = encode
137 137 str_to_bytes = no_code
138 138 bytes_to_str = no_code
139 139 cast_bytes_py2 = cast_bytes
140 140
141 141 string_types = (str, unicode)
142 142 unicode_type = unicode
143 143
144 144 import re
145 145 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
146 146 def isidentifier(s, dotted=False):
147 147 if dotted:
148 148 return all(isidentifier(a) for a in s.split("."))
149 149 return bool(_name_re.match(s))
150 150
151 151 class open(object):
152 152 """Wrapper providing key part of Python 3 open() interface."""
153 153 def __init__(self, fname, mode="r", encoding="utf-8"):
154 154 self.f = orig_open(fname, mode)
155 155 self.enc = encoding
156 156
157 157 def write(self, s):
158 158 return self.f.write(s.encode(self.enc))
159 159
160 160 def read(self, size=-1):
161 161 return self.f.read(size).decode(self.enc)
162 162
163 163 def close(self):
164 164 return self.f.close()
165 165
166 166 def __enter__(self):
167 167 return self
168 168
169 169 def __exit__(self, etype, value, traceback):
170 170 self.f.close()
171 171
172 172 xrange = xrange
173 iteritems = dict.iteritems
174 itervalues = dict.itervalues
173 def iteritems(d): return d.iteritems()
174 def itervalues(d): return d.itervalues()
175 175
176 176 def MethodType(func, instance):
177 177 return types.MethodType(func, instance, type(instance))
178 178
179 179 # don't override system execfile on 2.x:
180 180 execfile = execfile
181 181
182 182 def doctest_refactor_print(func_or_str):
183 183 return func_or_str
184 184
185 185
186 186 # Abstract u'abc' syntax:
187 187 @_modify_str_or_docstring
188 188 def u_format(s):
189 189 """"{u}'abc'" --> "u'abc'" (Python 2)
190 190
191 191 Accepts a string or a function, so it can be used as a decorator."""
192 192 return s.format(u='u')
193 193
194 194 if sys.platform == 'win32':
195 195 def execfile(fname, glob=None, loc=None):
196 196 loc = loc if (loc is not None) else glob
197 197 # The rstrip() is necessary b/c trailing whitespace in files will
198 198 # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
199 199 # but we still support 2.6). See issue 1027.
200 200 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
201 201 # compile converts unicode filename to str assuming
202 202 # ascii. Let's do the conversion before calling compile
203 203 if isinstance(fname, unicode):
204 204 filename = unicode_to_str(fname)
205 205 else:
206 206 filename = fname
207 207 exec(compile(scripttext, filename, 'exec'), glob, loc)
208 208 else:
209 209 def execfile(fname, *where):
210 210 if isinstance(fname, unicode):
211 211 filename = fname.encode(sys.getfilesystemencoding())
212 212 else:
213 213 filename = fname
214 214 builtin_mod.execfile(filename, *where)
215 215
216 216 # Parts below taken from six:
217 217 # Copyright (c) 2010-2013 Benjamin Peterson
218 218 #
219 219 # Permission is hereby granted, free of charge, to any person obtaining a copy
220 220 # of this software and associated documentation files (the "Software"), to deal
221 221 # in the Software without restriction, including without limitation the rights
222 222 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
223 223 # copies of the Software, and to permit persons to whom the Software is
224 224 # furnished to do so, subject to the following conditions:
225 225 #
226 226 # The above copyright notice and this permission notice shall be included in all
227 227 # copies or substantial portions of the Software.
228 228 #
229 229 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
230 230 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
231 231 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
232 232 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
233 233 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
234 234 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
235 235 # SOFTWARE.
236 236
237 237 def with_metaclass(meta, *bases):
238 238 """Create a base class with a metaclass."""
239 239 return meta("NewBase", bases, {})
@@ -1,643 +1,643 b''
1 1 # encoding: utf-8
2 2 """Tests for IPython.utils.path.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 from __future__ import with_statement
16 16
17 17 import os
18 18 import shutil
19 19 import sys
20 20 import tempfile
21 21 from contextlib import contextmanager
22 22
23 23 from os.path import join, abspath, split
24 24
25 25 import nose.tools as nt
26 26
27 27 from nose import with_setup
28 28
29 29 import IPython
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.decorators import (skip_if_not_win32, skip_win32,
32 32 onlyif_unicode_paths,)
33 33 from IPython.testing.tools import make_tempfile, AssertPrints
34 34 from IPython.utils import path
35 35 from IPython.utils import py3compat
36 36 from IPython.utils.tempdir import TemporaryDirectory
37 37
38 38 # Platform-dependent imports
39 39 try:
40 40 import winreg as wreg # Py 3
41 41 except ImportError:
42 42 try:
43 43 import _winreg as wreg # Py 2
44 44 except ImportError:
45 45 #Fake _winreg module on none windows platforms
46 46 import types
47 47 wr_name = "winreg" if py3compat.PY3 else "_winreg"
48 48 sys.modules[wr_name] = types.ModuleType(wr_name)
49 49 try:
50 50 import winreg as wreg
51 51 except ImportError:
52 52 import _winreg as wreg
53 53 #Add entries that needs to be stubbed by the testing code
54 54 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
55 55
56 56 try:
57 57 reload
58 58 except NameError: # Python 3
59 59 from imp import reload
60 60
61 61 #-----------------------------------------------------------------------------
62 62 # Globals
63 63 #-----------------------------------------------------------------------------
64 64 env = os.environ
65 65 TEST_FILE_PATH = split(abspath(__file__))[0]
66 66 TMP_TEST_DIR = tempfile.mkdtemp()
67 67 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
68 68 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
69 69 XDG_CACHE_DIR = join(HOME_TEST_DIR, "xdg_cache_dir")
70 70 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
71 71 #
72 72 # Setup/teardown functions/decorators
73 73 #
74 74
75 75 def setup():
76 76 """Setup testenvironment for the module:
77 77
78 78 - Adds dummy home dir tree
79 79 """
80 80 # Do not mask exceptions here. In particular, catching WindowsError is a
81 81 # problem because that exception is only defined on Windows...
82 82 os.makedirs(IP_TEST_DIR)
83 83 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
84 84 os.makedirs(os.path.join(XDG_CACHE_DIR, 'ipython'))
85 85
86 86
87 87 def teardown():
88 88 """Teardown testenvironment for the module:
89 89
90 90 - Remove dummy home dir tree
91 91 """
92 92 # Note: we remove the parent test dir, which is the root of all test
93 93 # subdirs we may have created. Use shutil instead of os.removedirs, so
94 94 # that non-empty directories are all recursively removed.
95 95 shutil.rmtree(TMP_TEST_DIR)
96 96
97 97
98 98 def setup_environment():
99 99 """Setup testenvironment for some functions that are tested
100 100 in this module. In particular this functions stores attributes
101 101 and other things that we need to stub in some test functions.
102 102 This needs to be done on a function level and not module level because
103 103 each testfunction needs a pristine environment.
104 104 """
105 105 global oldstuff, platformstuff
106 106 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
107 107
108 108 if os.name == 'nt':
109 109 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
110 110
111 111
112 112 def teardown_environment():
113 113 """Restore things that were remembered by the setup_environment function
114 114 """
115 115 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
116 116 os.chdir(old_wd)
117 117 reload(path)
118 118
119 for key in env.keys():
119 for key in list(env):
120 120 if key not in oldenv:
121 121 del env[key]
122 122 env.update(oldenv)
123 123 if hasattr(sys, 'frozen'):
124 124 del sys.frozen
125 125 if os.name == 'nt':
126 126 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
127 127
128 128 # Build decorator that uses the setup_environment/setup_environment
129 129 with_environment = with_setup(setup_environment, teardown_environment)
130 130
131 131 @skip_if_not_win32
132 132 @with_environment
133 133 def test_get_home_dir_1():
134 134 """Testcase for py2exe logic, un-compressed lib
135 135 """
136 136 unfrozen = path.get_home_dir()
137 137 sys.frozen = True
138 138
139 139 #fake filename for IPython.__init__
140 140 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
141 141
142 142 home_dir = path.get_home_dir()
143 143 nt.assert_equal(home_dir, unfrozen)
144 144
145 145
146 146 @skip_if_not_win32
147 147 @with_environment
148 148 def test_get_home_dir_2():
149 149 """Testcase for py2exe logic, compressed lib
150 150 """
151 151 unfrozen = path.get_home_dir()
152 152 sys.frozen = True
153 153 #fake filename for IPython.__init__
154 154 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
155 155
156 156 home_dir = path.get_home_dir(True)
157 157 nt.assert_equal(home_dir, unfrozen)
158 158
159 159
160 160 @with_environment
161 161 def test_get_home_dir_3():
162 162 """get_home_dir() uses $HOME if set"""
163 163 env["HOME"] = HOME_TEST_DIR
164 164 home_dir = path.get_home_dir(True)
165 165 # get_home_dir expands symlinks
166 166 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
167 167
168 168
169 169 @with_environment
170 170 def test_get_home_dir_4():
171 171 """get_home_dir() still works if $HOME is not set"""
172 172
173 173 if 'HOME' in env: del env['HOME']
174 174 # this should still succeed, but we don't care what the answer is
175 175 home = path.get_home_dir(False)
176 176
177 177 @with_environment
178 178 def test_get_home_dir_5():
179 179 """raise HomeDirError if $HOME is specified, but not a writable dir"""
180 180 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
181 181 # set os.name = posix, to prevent My Documents fallback on Windows
182 182 os.name = 'posix'
183 183 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
184 184
185 185
186 186 # Should we stub wreg fully so we can run the test on all platforms?
187 187 @skip_if_not_win32
188 188 @with_environment
189 189 def test_get_home_dir_8():
190 190 """Using registry hack for 'My Documents', os=='nt'
191 191
192 192 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
193 193 """
194 194 os.name = 'nt'
195 195 # Remove from stub environment all keys that may be set
196 196 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
197 197 env.pop(key, None)
198 198
199 199 #Stub windows registry functions
200 200 def OpenKey(x, y):
201 201 class key:
202 202 def Close(self):
203 203 pass
204 204 return key()
205 205 def QueryValueEx(x, y):
206 206 return [abspath(HOME_TEST_DIR)]
207 207
208 208 wreg.OpenKey = OpenKey
209 209 wreg.QueryValueEx = QueryValueEx
210 210
211 211 home_dir = path.get_home_dir()
212 212 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
213 213
214 214
215 215 @with_environment
216 216 def test_get_ipython_dir_1():
217 217 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
218 218 env_ipdir = os.path.join("someplace", ".ipython")
219 219 path._writable_dir = lambda path: True
220 220 env['IPYTHONDIR'] = env_ipdir
221 221 ipdir = path.get_ipython_dir()
222 222 nt.assert_equal(ipdir, env_ipdir)
223 223
224 224
225 225 @with_environment
226 226 def test_get_ipython_dir_2():
227 227 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
228 228 path.get_home_dir = lambda : "someplace"
229 229 path.get_xdg_dir = lambda : None
230 230 path._writable_dir = lambda path: True
231 231 os.name = "posix"
232 232 env.pop('IPYTHON_DIR', None)
233 233 env.pop('IPYTHONDIR', None)
234 234 env.pop('XDG_CONFIG_HOME', None)
235 235 ipdir = path.get_ipython_dir()
236 236 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
237 237
238 238 @with_environment
239 239 def test_get_ipython_dir_3():
240 240 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
241 241 path.get_home_dir = lambda : "someplace"
242 242 path._writable_dir = lambda path: True
243 243 os.name = "posix"
244 244 env.pop('IPYTHON_DIR', None)
245 245 env.pop('IPYTHONDIR', None)
246 246 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
247 247 ipdir = path.get_ipython_dir()
248 248 if sys.platform == "darwin":
249 249 expected = os.path.join("someplace", ".ipython")
250 250 else:
251 251 expected = os.path.join(XDG_TEST_DIR, "ipython")
252 252 nt.assert_equal(ipdir, expected)
253 253
254 254 @with_environment
255 255 def test_get_ipython_dir_4():
256 256 """test_get_ipython_dir_4, use XDG if both exist."""
257 257 path.get_home_dir = lambda : HOME_TEST_DIR
258 258 os.name = "posix"
259 259 env.pop('IPYTHON_DIR', None)
260 260 env.pop('IPYTHONDIR', None)
261 261 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
262 262 ipdir = path.get_ipython_dir()
263 263 if sys.platform == "darwin":
264 264 expected = os.path.join(HOME_TEST_DIR, ".ipython")
265 265 else:
266 266 expected = os.path.join(XDG_TEST_DIR, "ipython")
267 267 nt.assert_equal(ipdir, expected)
268 268
269 269 @with_environment
270 270 def test_get_ipython_dir_5():
271 271 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
272 272 path.get_home_dir = lambda : HOME_TEST_DIR
273 273 os.name = "posix"
274 274 env.pop('IPYTHON_DIR', None)
275 275 env.pop('IPYTHONDIR', None)
276 276 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
277 277 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
278 278 ipdir = path.get_ipython_dir()
279 279 nt.assert_equal(ipdir, IP_TEST_DIR)
280 280
281 281 @with_environment
282 282 def test_get_ipython_dir_6():
283 283 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
284 284 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
285 285 os.mkdir(xdg)
286 286 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
287 287 path.get_home_dir = lambda : HOME_TEST_DIR
288 288 path.get_xdg_dir = lambda : xdg
289 289 os.name = "posix"
290 290 env.pop('IPYTHON_DIR', None)
291 291 env.pop('IPYTHONDIR', None)
292 292 env.pop('XDG_CONFIG_HOME', None)
293 293 xdg_ipdir = os.path.join(xdg, "ipython")
294 294 ipdir = path.get_ipython_dir()
295 295 nt.assert_equal(ipdir, xdg_ipdir)
296 296
297 297 @with_environment
298 298 def test_get_ipython_dir_7():
299 299 """test_get_ipython_dir_7, test home directory expansion on IPYTHONDIR"""
300 300 path._writable_dir = lambda path: True
301 301 home_dir = os.path.normpath(os.path.expanduser('~'))
302 302 env['IPYTHONDIR'] = os.path.join('~', 'somewhere')
303 303 ipdir = path.get_ipython_dir()
304 304 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
305 305
306 306 @skip_win32
307 307 @with_environment
308 308 def test_get_ipython_dir_8():
309 309 """test_get_ipython_dir_8, test / home directory"""
310 310 old = path._writable_dir, path.get_xdg_dir
311 311 try:
312 312 path._writable_dir = lambda path: bool(path)
313 313 path.get_xdg_dir = lambda: None
314 314 env.pop('IPYTHON_DIR', None)
315 315 env.pop('IPYTHONDIR', None)
316 316 env['HOME'] = '/'
317 317 nt.assert_equal(path.get_ipython_dir(), '/.ipython')
318 318 finally:
319 319 path._writable_dir, path.get_xdg_dir = old
320 320
321 321 @with_environment
322 322 def test_get_xdg_dir_0():
323 323 """test_get_xdg_dir_0, check xdg_dir"""
324 324 reload(path)
325 325 path._writable_dir = lambda path: True
326 326 path.get_home_dir = lambda : 'somewhere'
327 327 os.name = "posix"
328 328 sys.platform = "linux2"
329 329 env.pop('IPYTHON_DIR', None)
330 330 env.pop('IPYTHONDIR', None)
331 331 env.pop('XDG_CONFIG_HOME', None)
332 332
333 333 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
334 334
335 335
336 336 @with_environment
337 337 def test_get_xdg_dir_1():
338 338 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
339 339 reload(path)
340 340 path.get_home_dir = lambda : HOME_TEST_DIR
341 341 os.name = "posix"
342 342 sys.platform = "linux2"
343 343 env.pop('IPYTHON_DIR', None)
344 344 env.pop('IPYTHONDIR', None)
345 345 env.pop('XDG_CONFIG_HOME', None)
346 346 nt.assert_equal(path.get_xdg_dir(), None)
347 347
348 348 @with_environment
349 349 def test_get_xdg_dir_2():
350 350 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
351 351 reload(path)
352 352 path.get_home_dir = lambda : HOME_TEST_DIR
353 353 os.name = "posix"
354 354 sys.platform = "linux2"
355 355 env.pop('IPYTHON_DIR', None)
356 356 env.pop('IPYTHONDIR', None)
357 357 env.pop('XDG_CONFIG_HOME', None)
358 358 cfgdir=os.path.join(path.get_home_dir(), '.config')
359 359 if not os.path.exists(cfgdir):
360 360 os.makedirs(cfgdir)
361 361
362 362 nt.assert_equal(path.get_xdg_dir(), cfgdir)
363 363
364 364 @with_environment
365 365 def test_get_xdg_dir_3():
366 366 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
367 367 reload(path)
368 368 path.get_home_dir = lambda : HOME_TEST_DIR
369 369 os.name = "posix"
370 370 sys.platform = "darwin"
371 371 env.pop('IPYTHON_DIR', None)
372 372 env.pop('IPYTHONDIR', None)
373 373 env.pop('XDG_CONFIG_HOME', None)
374 374 cfgdir=os.path.join(path.get_home_dir(), '.config')
375 375 if not os.path.exists(cfgdir):
376 376 os.makedirs(cfgdir)
377 377
378 378 nt.assert_equal(path.get_xdg_dir(), None)
379 379
380 380 def test_filefind():
381 381 """Various tests for filefind"""
382 382 f = tempfile.NamedTemporaryFile()
383 383 # print 'fname:',f.name
384 384 alt_dirs = path.get_ipython_dir()
385 385 t = path.filefind(f.name, alt_dirs)
386 386 # print 'found:',t
387 387
388 388 @with_environment
389 389 def test_get_ipython_cache_dir():
390 390 os.environ["HOME"] = HOME_TEST_DIR
391 391 if os.name == 'posix' and sys.platform != 'darwin':
392 392 # test default
393 393 os.makedirs(os.path.join(HOME_TEST_DIR, ".cache"))
394 394 os.environ.pop("XDG_CACHE_HOME", None)
395 395 ipdir = path.get_ipython_cache_dir()
396 396 nt.assert_equal(os.path.join(HOME_TEST_DIR, ".cache", "ipython"),
397 397 ipdir)
398 398 nt.assert_true(os.path.isdir(ipdir))
399 399
400 400 # test env override
401 401 os.environ["XDG_CACHE_HOME"] = XDG_CACHE_DIR
402 402 ipdir = path.get_ipython_cache_dir()
403 403 nt.assert_true(os.path.isdir(ipdir))
404 404 nt.assert_equal(ipdir, os.path.join(XDG_CACHE_DIR, "ipython"))
405 405 else:
406 406 nt.assert_equal(path.get_ipython_cache_dir(),
407 407 path.get_ipython_dir())
408 408
409 409 def test_get_ipython_package_dir():
410 410 ipdir = path.get_ipython_package_dir()
411 411 nt.assert_true(os.path.isdir(ipdir))
412 412
413 413
414 414 def test_get_ipython_module_path():
415 415 ipapp_path = path.get_ipython_module_path('IPython.terminal.ipapp')
416 416 nt.assert_true(os.path.isfile(ipapp_path))
417 417
418 418
419 419 @dec.skip_if_not_win32
420 420 def test_get_long_path_name_win32():
421 421 with TemporaryDirectory() as tmpdir:
422 422
423 423 # Make a long path.
424 424 long_path = os.path.join(tmpdir, u'this is my long path name')
425 425 os.makedirs(long_path)
426 426
427 427 # Test to see if the short path evaluates correctly.
428 428 short_path = os.path.join(tmpdir, u'THISIS~1')
429 429 evaluated_path = path.get_long_path_name(short_path)
430 430 nt.assert_equal(evaluated_path.lower(), long_path.lower())
431 431
432 432
433 433 @dec.skip_win32
434 434 def test_get_long_path_name():
435 435 p = path.get_long_path_name('/usr/local')
436 436 nt.assert_equal(p,'/usr/local')
437 437
438 438 @dec.skip_win32 # can't create not-user-writable dir on win
439 439 @with_environment
440 440 def test_not_writable_ipdir():
441 441 tmpdir = tempfile.mkdtemp()
442 442 os.name = "posix"
443 443 env.pop('IPYTHON_DIR', None)
444 444 env.pop('IPYTHONDIR', None)
445 445 env.pop('XDG_CONFIG_HOME', None)
446 446 env['HOME'] = tmpdir
447 447 ipdir = os.path.join(tmpdir, '.ipython')
448 448 os.mkdir(ipdir)
449 449 os.chmod(ipdir, 600)
450 450 with AssertPrints('is not a writable location', channel='stderr'):
451 451 ipdir = path.get_ipython_dir()
452 452 env.pop('IPYTHON_DIR', None)
453 453
454 454 def test_unquote_filename():
455 455 for win32 in (True, False):
456 456 nt.assert_equal(path.unquote_filename('foo.py', win32=win32), 'foo.py')
457 457 nt.assert_equal(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
458 458 nt.assert_equal(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
459 459 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
460 460 nt.assert_equal(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
461 461 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
462 462 nt.assert_equal(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
463 463 nt.assert_equal(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
464 464 nt.assert_equal(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
465 465 nt.assert_equal(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
466 466
467 467 @with_environment
468 468 def test_get_py_filename():
469 469 os.chdir(TMP_TEST_DIR)
470 470 for win32 in (True, False):
471 471 with make_tempfile('foo.py'):
472 472 nt.assert_equal(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
473 473 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo.py')
474 474 with make_tempfile('foo'):
475 475 nt.assert_equal(path.get_py_filename('foo', force_win32=win32), 'foo')
476 476 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
477 477 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
478 478 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
479 479 true_fn = 'foo with spaces.py'
480 480 with make_tempfile(true_fn):
481 481 nt.assert_equal(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
482 482 nt.assert_equal(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
483 483 if win32:
484 484 nt.assert_equal(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
485 485 nt.assert_equal(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
486 486 else:
487 487 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
488 488 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
489 489
490 490 @onlyif_unicode_paths
491 491 def test_unicode_in_filename():
492 492 """When a file doesn't exist, the exception raised should be safe to call
493 493 str() on - i.e. in Python 2 it must only have ASCII characters.
494 494
495 495 https://github.com/ipython/ipython/issues/875
496 496 """
497 497 try:
498 498 # these calls should not throw unicode encode exceptions
499 499 path.get_py_filename(u'fooéè.py', force_win32=False)
500 500 except IOError as ex:
501 501 str(ex)
502 502
503 503
504 504 class TestShellGlob(object):
505 505
506 506 @classmethod
507 507 def setUpClass(cls):
508 cls.filenames_start_with_a = map('a{0}'.format, range(3))
509 cls.filenames_end_with_b = map('{0}b'.format, range(3))
508 cls.filenames_start_with_a = ['a0', 'a1', 'a2']
509 cls.filenames_end_with_b = ['0b', '1b', '2b']
510 510 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
511 511 cls.tempdir = TemporaryDirectory()
512 512 td = cls.tempdir.name
513 513
514 514 with cls.in_tempdir():
515 515 # Create empty files
516 516 for fname in cls.filenames:
517 517 open(os.path.join(td, fname), 'w').close()
518 518
519 519 @classmethod
520 520 def tearDownClass(cls):
521 521 cls.tempdir.cleanup()
522 522
523 523 @classmethod
524 524 @contextmanager
525 525 def in_tempdir(cls):
526 526 save = os.getcwdu()
527 527 try:
528 528 os.chdir(cls.tempdir.name)
529 529 yield
530 530 finally:
531 531 os.chdir(save)
532 532
533 533 def check_match(self, patterns, matches):
534 534 with self.in_tempdir():
535 535 # glob returns unordered list. that's why sorted is required.
536 536 nt.assert_equals(sorted(path.shellglob(patterns)),
537 537 sorted(matches))
538 538
539 539 def common_cases(self):
540 540 return [
541 541 (['*'], self.filenames),
542 542 (['a*'], self.filenames_start_with_a),
543 543 (['*c'], ['*c']),
544 544 (['*', 'a*', '*b', '*c'], self.filenames
545 545 + self.filenames_start_with_a
546 546 + self.filenames_end_with_b
547 547 + ['*c']),
548 548 (['a[012]'], self.filenames_start_with_a),
549 549 ]
550 550
551 551 @skip_win32
552 552 def test_match_posix(self):
553 553 for (patterns, matches) in self.common_cases() + [
554 554 ([r'\*'], ['*']),
555 555 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
556 556 ([r'a\[012]'], ['a[012]']),
557 557 ]:
558 558 yield (self.check_match, patterns, matches)
559 559
560 560 @skip_if_not_win32
561 561 def test_match_windows(self):
562 562 for (patterns, matches) in self.common_cases() + [
563 563 # In windows, backslash is interpreted as path
564 564 # separator. Therefore, you can't escape glob
565 565 # using it.
566 566 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
567 567 ([r'a\[012]'], [r'a\[012]']),
568 568 ]:
569 569 yield (self.check_match, patterns, matches)
570 570
571 571
572 572 def test_unescape_glob():
573 573 nt.assert_equals(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
574 574 nt.assert_equals(path.unescape_glob(r'\\*'), r'\*')
575 575 nt.assert_equals(path.unescape_glob(r'\\\*'), r'\*')
576 576 nt.assert_equals(path.unescape_glob(r'\\a'), r'\a')
577 577 nt.assert_equals(path.unescape_glob(r'\a'), r'\a')
578 578
579 579
580 580 class TestLinkOrCopy(object):
581 581 def setUp(self):
582 582 self.tempdir = TemporaryDirectory()
583 583 self.src = self.dst("src")
584 584 with open(self.src, "w") as f:
585 585 f.write("Hello, world!")
586 586
587 587 def tearDown(self):
588 588 self.tempdir.cleanup()
589 589
590 590 def dst(self, *args):
591 591 return os.path.join(self.tempdir.name, *args)
592 592
593 593 def assert_inode_not_equal(self, a, b):
594 594 nt.assert_not_equals(os.stat(a).st_ino, os.stat(b).st_ino,
595 595 "%r and %r do reference the same indoes" %(a, b))
596 596
597 597 def assert_inode_equal(self, a, b):
598 598 nt.assert_equals(os.stat(a).st_ino, os.stat(b).st_ino,
599 599 "%r and %r do not reference the same indoes" %(a, b))
600 600
601 601 def assert_content_equal(self, a, b):
602 602 with open(a) as a_f:
603 603 with open(b) as b_f:
604 604 nt.assert_equals(a_f.read(), b_f.read())
605 605
606 606 @skip_win32
607 607 def test_link_successful(self):
608 608 dst = self.dst("target")
609 609 path.link_or_copy(self.src, dst)
610 610 self.assert_inode_equal(self.src, dst)
611 611
612 612 @skip_win32
613 613 def test_link_into_dir(self):
614 614 dst = self.dst("some_dir")
615 615 os.mkdir(dst)
616 616 path.link_or_copy(self.src, dst)
617 617 expected_dst = self.dst("some_dir", os.path.basename(self.src))
618 618 self.assert_inode_equal(self.src, expected_dst)
619 619
620 620 @skip_win32
621 621 def test_target_exists(self):
622 622 dst = self.dst("target")
623 623 open(dst, "w").close()
624 624 path.link_or_copy(self.src, dst)
625 625 self.assert_inode_equal(self.src, dst)
626 626
627 627 @skip_win32
628 628 def test_no_link(self):
629 629 real_link = os.link
630 630 try:
631 631 del os.link
632 632 dst = self.dst("target")
633 633 path.link_or_copy(self.src, dst)
634 634 self.assert_content_equal(self.src, dst)
635 635 self.assert_inode_not_equal(self.src, dst)
636 636 finally:
637 637 os.link = real_link
638 638
639 639 @skip_if_not_win32
640 640 def test_windows(self):
641 641 dst = self.dst("target")
642 642 path.link_or_copy(self.src, dst)
643 643 self.assert_content_equal(self.src, dst)
@@ -1,975 +1,975 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.utils.traitlets.
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 9 and is licensed under the BSD license. Also, many of the ideas also come
10 10 from enthought.traits even though our implementation is very different.
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import re
25 25 import sys
26 26 from unittest import TestCase
27 27
28 28 import nose.tools as nt
29 29 from nose import SkipTest
30 30
31 31 from IPython.utils.traitlets import (
32 32 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
33 33 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
34 34 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
35 35 ObjectName, DottedObjectName, CRegExp
36 36 )
37 37 from IPython.utils import py3compat
38 38 from IPython.testing.decorators import skipif
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Helper classes for testing
42 42 #-----------------------------------------------------------------------------
43 43
44 44
45 45 class HasTraitsStub(HasTraits):
46 46
47 47 def _notify_trait(self, name, old, new):
48 48 self._notify_name = name
49 49 self._notify_old = old
50 50 self._notify_new = new
51 51
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Test classes
55 55 #-----------------------------------------------------------------------------
56 56
57 57
58 58 class TestTraitType(TestCase):
59 59
60 60 def test_get_undefined(self):
61 61 class A(HasTraits):
62 62 a = TraitType
63 63 a = A()
64 64 self.assertEqual(a.a, Undefined)
65 65
66 66 def test_set(self):
67 67 class A(HasTraitsStub):
68 68 a = TraitType
69 69
70 70 a = A()
71 71 a.a = 10
72 72 self.assertEqual(a.a, 10)
73 73 self.assertEqual(a._notify_name, 'a')
74 74 self.assertEqual(a._notify_old, Undefined)
75 75 self.assertEqual(a._notify_new, 10)
76 76
77 77 def test_validate(self):
78 78 class MyTT(TraitType):
79 79 def validate(self, inst, value):
80 80 return -1
81 81 class A(HasTraitsStub):
82 82 tt = MyTT
83 83
84 84 a = A()
85 85 a.tt = 10
86 86 self.assertEqual(a.tt, -1)
87 87
88 88 def test_default_validate(self):
89 89 class MyIntTT(TraitType):
90 90 def validate(self, obj, value):
91 91 if isinstance(value, int):
92 92 return value
93 93 self.error(obj, value)
94 94 class A(HasTraits):
95 95 tt = MyIntTT(10)
96 96 a = A()
97 97 self.assertEqual(a.tt, 10)
98 98
99 99 # Defaults are validated when the HasTraits is instantiated
100 100 class B(HasTraits):
101 101 tt = MyIntTT('bad default')
102 102 self.assertRaises(TraitError, B)
103 103
104 104 def test_is_valid_for(self):
105 105 class MyTT(TraitType):
106 106 def is_valid_for(self, value):
107 107 return True
108 108 class A(HasTraits):
109 109 tt = MyTT
110 110
111 111 a = A()
112 112 a.tt = 10
113 113 self.assertEqual(a.tt, 10)
114 114
115 115 def test_value_for(self):
116 116 class MyTT(TraitType):
117 117 def value_for(self, value):
118 118 return 20
119 119 class A(HasTraits):
120 120 tt = MyTT
121 121
122 122 a = A()
123 123 a.tt = 10
124 124 self.assertEqual(a.tt, 20)
125 125
126 126 def test_info(self):
127 127 class A(HasTraits):
128 128 tt = TraitType
129 129 a = A()
130 130 self.assertEqual(A.tt.info(), 'any value')
131 131
132 132 def test_error(self):
133 133 class A(HasTraits):
134 134 tt = TraitType
135 135 a = A()
136 136 self.assertRaises(TraitError, A.tt.error, a, 10)
137 137
138 138 def test_dynamic_initializer(self):
139 139 class A(HasTraits):
140 140 x = Int(10)
141 141 def _x_default(self):
142 142 return 11
143 143 class B(A):
144 144 x = Int(20)
145 145 class C(A):
146 146 def _x_default(self):
147 147 return 21
148 148
149 149 a = A()
150 150 self.assertEqual(a._trait_values, {})
151 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
152 152 self.assertEqual(a.x, 11)
153 153 self.assertEqual(a._trait_values, {'x': 11})
154 154 b = B()
155 155 self.assertEqual(b._trait_values, {'x': 20})
156 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
157 157 self.assertEqual(b.x, 20)
158 158 c = C()
159 159 self.assertEqual(c._trait_values, {})
160 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
161 161 self.assertEqual(c.x, 21)
162 162 self.assertEqual(c._trait_values, {'x': 21})
163 163 # Ensure that the base class remains unmolested when the _default
164 164 # initializer gets overridden in a subclass.
165 165 a = A()
166 166 c = C()
167 167 self.assertEqual(a._trait_values, {})
168 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
169 169 self.assertEqual(a.x, 11)
170 170 self.assertEqual(a._trait_values, {'x': 11})
171 171
172 172
173 173
174 174 class TestHasTraitsMeta(TestCase):
175 175
176 176 def test_metaclass(self):
177 177 self.assertEqual(type(HasTraits), MetaHasTraits)
178 178
179 179 class A(HasTraits):
180 180 a = Int
181 181
182 182 a = A()
183 183 self.assertEqual(type(a.__class__), MetaHasTraits)
184 184 self.assertEqual(a.a,0)
185 185 a.a = 10
186 186 self.assertEqual(a.a,10)
187 187
188 188 class B(HasTraits):
189 189 b = Int()
190 190
191 191 b = B()
192 192 self.assertEqual(b.b,0)
193 193 b.b = 10
194 194 self.assertEqual(b.b,10)
195 195
196 196 class C(HasTraits):
197 197 c = Int(30)
198 198
199 199 c = C()
200 200 self.assertEqual(c.c,30)
201 201 c.c = 10
202 202 self.assertEqual(c.c,10)
203 203
204 204 def test_this_class(self):
205 205 class A(HasTraits):
206 206 t = This()
207 207 tt = This()
208 208 class B(A):
209 209 tt = This()
210 210 ttt = This()
211 211 self.assertEqual(A.t.this_class, A)
212 212 self.assertEqual(B.t.this_class, A)
213 213 self.assertEqual(B.tt.this_class, B)
214 214 self.assertEqual(B.ttt.this_class, B)
215 215
216 216 class TestHasTraitsNotify(TestCase):
217 217
218 218 def setUp(self):
219 219 self._notify1 = []
220 220 self._notify2 = []
221 221
222 222 def notify1(self, name, old, new):
223 223 self._notify1.append((name, old, new))
224 224
225 225 def notify2(self, name, old, new):
226 226 self._notify2.append((name, old, new))
227 227
228 228 def test_notify_all(self):
229 229
230 230 class A(HasTraits):
231 231 a = Int
232 232 b = Float
233 233
234 234 a = A()
235 235 a.on_trait_change(self.notify1)
236 236 a.a = 0
237 237 self.assertEqual(len(self._notify1),0)
238 238 a.b = 0.0
239 239 self.assertEqual(len(self._notify1),0)
240 240 a.a = 10
241 241 self.assertTrue(('a',0,10) in self._notify1)
242 242 a.b = 10.0
243 243 self.assertTrue(('b',0.0,10.0) in self._notify1)
244 244 self.assertRaises(TraitError,setattr,a,'a','bad string')
245 245 self.assertRaises(TraitError,setattr,a,'b','bad string')
246 246 self._notify1 = []
247 247 a.on_trait_change(self.notify1,remove=True)
248 248 a.a = 20
249 249 a.b = 20.0
250 250 self.assertEqual(len(self._notify1),0)
251 251
252 252 def test_notify_one(self):
253 253
254 254 class A(HasTraits):
255 255 a = Int
256 256 b = Float
257 257
258 258 a = A()
259 259 a.on_trait_change(self.notify1, 'a')
260 260 a.a = 0
261 261 self.assertEqual(len(self._notify1),0)
262 262 a.a = 10
263 263 self.assertTrue(('a',0,10) in self._notify1)
264 264 self.assertRaises(TraitError,setattr,a,'a','bad string')
265 265
266 266 def test_subclass(self):
267 267
268 268 class A(HasTraits):
269 269 a = Int
270 270
271 271 class B(A):
272 272 b = Float
273 273
274 274 b = B()
275 275 self.assertEqual(b.a,0)
276 276 self.assertEqual(b.b,0.0)
277 277 b.a = 100
278 278 b.b = 100.0
279 279 self.assertEqual(b.a,100)
280 280 self.assertEqual(b.b,100.0)
281 281
282 282 def test_notify_subclass(self):
283 283
284 284 class A(HasTraits):
285 285 a = Int
286 286
287 287 class B(A):
288 288 b = Float
289 289
290 290 b = B()
291 291 b.on_trait_change(self.notify1, 'a')
292 292 b.on_trait_change(self.notify2, 'b')
293 293 b.a = 0
294 294 b.b = 0.0
295 295 self.assertEqual(len(self._notify1),0)
296 296 self.assertEqual(len(self._notify2),0)
297 297 b.a = 10
298 298 b.b = 10.0
299 299 self.assertTrue(('a',0,10) in self._notify1)
300 300 self.assertTrue(('b',0.0,10.0) in self._notify2)
301 301
302 302 def test_static_notify(self):
303 303
304 304 class A(HasTraits):
305 305 a = Int
306 306 _notify1 = []
307 307 def _a_changed(self, name, old, new):
308 308 self._notify1.append((name, old, new))
309 309
310 310 a = A()
311 311 a.a = 0
312 312 # This is broken!!!
313 313 self.assertEqual(len(a._notify1),0)
314 314 a.a = 10
315 315 self.assertTrue(('a',0,10) in a._notify1)
316 316
317 317 class B(A):
318 318 b = Float
319 319 _notify2 = []
320 320 def _b_changed(self, name, old, new):
321 321 self._notify2.append((name, old, new))
322 322
323 323 b = B()
324 324 b.a = 10
325 325 b.b = 10.0
326 326 self.assertTrue(('a',0,10) in b._notify1)
327 327 self.assertTrue(('b',0.0,10.0) in b._notify2)
328 328
329 329 def test_notify_args(self):
330 330
331 331 def callback0():
332 332 self.cb = ()
333 333 def callback1(name):
334 334 self.cb = (name,)
335 335 def callback2(name, new):
336 336 self.cb = (name, new)
337 337 def callback3(name, old, new):
338 338 self.cb = (name, old, new)
339 339
340 340 class A(HasTraits):
341 341 a = Int
342 342
343 343 a = A()
344 344 a.on_trait_change(callback0, 'a')
345 345 a.a = 10
346 346 self.assertEqual(self.cb,())
347 347 a.on_trait_change(callback0, 'a', remove=True)
348 348
349 349 a.on_trait_change(callback1, 'a')
350 350 a.a = 100
351 351 self.assertEqual(self.cb,('a',))
352 352 a.on_trait_change(callback1, 'a', remove=True)
353 353
354 354 a.on_trait_change(callback2, 'a')
355 355 a.a = 1000
356 356 self.assertEqual(self.cb,('a',1000))
357 357 a.on_trait_change(callback2, 'a', remove=True)
358 358
359 359 a.on_trait_change(callback3, 'a')
360 360 a.a = 10000
361 361 self.assertEqual(self.cb,('a',1000,10000))
362 362 a.on_trait_change(callback3, 'a', remove=True)
363 363
364 364 self.assertEqual(len(a._trait_notifiers['a']),0)
365 365
366 366 def test_notify_only_once(self):
367 367
368 368 class A(HasTraits):
369 369 listen_to = ['a']
370 370
371 371 a = Int(0)
372 372 b = 0
373 373
374 374 def __init__(self, **kwargs):
375 375 super(A, self).__init__(**kwargs)
376 376 self.on_trait_change(self.listener1, ['a'])
377 377
378 378 def listener1(self, name, old, new):
379 379 self.b += 1
380 380
381 381 class B(A):
382 382
383 383 c = 0
384 384 d = 0
385 385
386 386 def __init__(self, **kwargs):
387 387 super(B, self).__init__(**kwargs)
388 388 self.on_trait_change(self.listener2)
389 389
390 390 def listener2(self, name, old, new):
391 391 self.c += 1
392 392
393 393 def _a_changed(self, name, old, new):
394 394 self.d += 1
395 395
396 396 b = B()
397 397 b.a += 1
398 398 self.assertEqual(b.b, b.c)
399 399 self.assertEqual(b.b, b.d)
400 400 b.a += 1
401 401 self.assertEqual(b.b, b.c)
402 402 self.assertEqual(b.b, b.d)
403 403
404 404
405 405 class TestHasTraits(TestCase):
406 406
407 407 def test_trait_names(self):
408 408 class A(HasTraits):
409 409 i = Int
410 410 f = Float
411 411 a = A()
412 412 self.assertEqual(sorted(a.trait_names()),['f','i'])
413 413 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
414 414
415 415 def test_trait_metadata(self):
416 416 class A(HasTraits):
417 417 i = Int(config_key='MY_VALUE')
418 418 a = A()
419 419 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
420 420
421 421 def test_traits(self):
422 422 class A(HasTraits):
423 423 i = Int
424 424 f = Float
425 425 a = A()
426 426 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
427 427 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
428 428
429 429 def test_traits_metadata(self):
430 430 class A(HasTraits):
431 431 i = Int(config_key='VALUE1', other_thing='VALUE2')
432 432 f = Float(config_key='VALUE3', other_thing='VALUE2')
433 433 j = Int(0)
434 434 a = A()
435 435 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
436 436 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
437 437 self.assertEqual(traits, dict(i=A.i))
438 438
439 439 # This passes, but it shouldn't because I am replicating a bug in
440 440 # traits.
441 441 traits = a.traits(config_key=lambda v: True)
442 442 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
443 443
444 444 def test_init(self):
445 445 class A(HasTraits):
446 446 i = Int()
447 447 x = Float()
448 448 a = A(i=1, x=10.0)
449 449 self.assertEqual(a.i, 1)
450 450 self.assertEqual(a.x, 10.0)
451 451
452 452 def test_positional_args(self):
453 453 class A(HasTraits):
454 454 i = Int(0)
455 455 def __init__(self, i):
456 456 super(A, self).__init__()
457 457 self.i = i
458 458
459 459 a = A(5)
460 460 self.assertEqual(a.i, 5)
461 461 # should raise TypeError if no positional arg given
462 462 self.assertRaises(TypeError, A)
463 463
464 464 #-----------------------------------------------------------------------------
465 465 # Tests for specific trait types
466 466 #-----------------------------------------------------------------------------
467 467
468 468
469 469 class TestType(TestCase):
470 470
471 471 def test_default(self):
472 472
473 473 class B(object): pass
474 474 class A(HasTraits):
475 475 klass = Type
476 476
477 477 a = A()
478 478 self.assertEqual(a.klass, None)
479 479
480 480 a.klass = B
481 481 self.assertEqual(a.klass, B)
482 482 self.assertRaises(TraitError, setattr, a, 'klass', 10)
483 483
484 484 def test_value(self):
485 485
486 486 class B(object): pass
487 487 class C(object): pass
488 488 class A(HasTraits):
489 489 klass = Type(B)
490 490
491 491 a = A()
492 492 self.assertEqual(a.klass, B)
493 493 self.assertRaises(TraitError, setattr, a, 'klass', C)
494 494 self.assertRaises(TraitError, setattr, a, 'klass', object)
495 495 a.klass = B
496 496
497 497 def test_allow_none(self):
498 498
499 499 class B(object): pass
500 500 class C(B): pass
501 501 class A(HasTraits):
502 502 klass = Type(B, allow_none=False)
503 503
504 504 a = A()
505 505 self.assertEqual(a.klass, B)
506 506 self.assertRaises(TraitError, setattr, a, 'klass', None)
507 507 a.klass = C
508 508 self.assertEqual(a.klass, C)
509 509
510 510 def test_validate_klass(self):
511 511
512 512 class A(HasTraits):
513 513 klass = Type('no strings allowed')
514 514
515 515 self.assertRaises(ImportError, A)
516 516
517 517 class A(HasTraits):
518 518 klass = Type('rub.adub.Duck')
519 519
520 520 self.assertRaises(ImportError, A)
521 521
522 522 def test_validate_default(self):
523 523
524 524 class B(object): pass
525 525 class A(HasTraits):
526 526 klass = Type('bad default', B)
527 527
528 528 self.assertRaises(ImportError, A)
529 529
530 530 class C(HasTraits):
531 531 klass = Type(None, B, allow_none=False)
532 532
533 533 self.assertRaises(TraitError, C)
534 534
535 535 def test_str_klass(self):
536 536
537 537 class A(HasTraits):
538 538 klass = Type('IPython.utils.ipstruct.Struct')
539 539
540 540 from IPython.utils.ipstruct import Struct
541 541 a = A()
542 542 a.klass = Struct
543 543 self.assertEqual(a.klass, Struct)
544 544
545 545 self.assertRaises(TraitError, setattr, a, 'klass', 10)
546 546
547 547 class TestInstance(TestCase):
548 548
549 549 def test_basic(self):
550 550 class Foo(object): pass
551 551 class Bar(Foo): pass
552 552 class Bah(object): pass
553 553
554 554 class A(HasTraits):
555 555 inst = Instance(Foo)
556 556
557 557 a = A()
558 558 self.assertTrue(a.inst is None)
559 559 a.inst = Foo()
560 560 self.assertTrue(isinstance(a.inst, Foo))
561 561 a.inst = Bar()
562 562 self.assertTrue(isinstance(a.inst, Foo))
563 563 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
564 564 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
565 565 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
566 566
567 567 def test_unique_default_value(self):
568 568 class Foo(object): pass
569 569 class A(HasTraits):
570 570 inst = Instance(Foo,(),{})
571 571
572 572 a = A()
573 573 b = A()
574 574 self.assertTrue(a.inst is not b.inst)
575 575
576 576 def test_args_kw(self):
577 577 class Foo(object):
578 578 def __init__(self, c): self.c = c
579 579 class Bar(object): pass
580 580 class Bah(object):
581 581 def __init__(self, c, d):
582 582 self.c = c; self.d = d
583 583
584 584 class A(HasTraits):
585 585 inst = Instance(Foo, (10,))
586 586 a = A()
587 587 self.assertEqual(a.inst.c, 10)
588 588
589 589 class B(HasTraits):
590 590 inst = Instance(Bah, args=(10,), kw=dict(d=20))
591 591 b = B()
592 592 self.assertEqual(b.inst.c, 10)
593 593 self.assertEqual(b.inst.d, 20)
594 594
595 595 class C(HasTraits):
596 596 inst = Instance(Foo)
597 597 c = C()
598 598 self.assertTrue(c.inst is None)
599 599
600 600 def test_bad_default(self):
601 601 class Foo(object): pass
602 602
603 603 class A(HasTraits):
604 604 inst = Instance(Foo, allow_none=False)
605 605
606 606 self.assertRaises(TraitError, A)
607 607
608 608 def test_instance(self):
609 609 class Foo(object): pass
610 610
611 611 def inner():
612 612 class A(HasTraits):
613 613 inst = Instance(Foo())
614 614
615 615 self.assertRaises(TraitError, inner)
616 616
617 617
618 618 class TestThis(TestCase):
619 619
620 620 def test_this_class(self):
621 621 class Foo(HasTraits):
622 622 this = This
623 623
624 624 f = Foo()
625 625 self.assertEqual(f.this, None)
626 626 g = Foo()
627 627 f.this = g
628 628 self.assertEqual(f.this, g)
629 629 self.assertRaises(TraitError, setattr, f, 'this', 10)
630 630
631 631 def test_this_inst(self):
632 632 class Foo(HasTraits):
633 633 this = This()
634 634
635 635 f = Foo()
636 636 f.this = Foo()
637 637 self.assertTrue(isinstance(f.this, Foo))
638 638
639 639 def test_subclass(self):
640 640 class Foo(HasTraits):
641 641 t = This()
642 642 class Bar(Foo):
643 643 pass
644 644 f = Foo()
645 645 b = Bar()
646 646 f.t = b
647 647 b.t = f
648 648 self.assertEqual(f.t, b)
649 649 self.assertEqual(b.t, f)
650 650
651 651 def test_subclass_override(self):
652 652 class Foo(HasTraits):
653 653 t = This()
654 654 class Bar(Foo):
655 655 t = This()
656 656 f = Foo()
657 657 b = Bar()
658 658 f.t = b
659 659 self.assertEqual(f.t, b)
660 660 self.assertRaises(TraitError, setattr, b, 't', f)
661 661
662 662 class TraitTestBase(TestCase):
663 663 """A best testing class for basic trait types."""
664 664
665 665 def assign(self, value):
666 666 self.obj.value = value
667 667
668 668 def coerce(self, value):
669 669 return value
670 670
671 671 def test_good_values(self):
672 672 if hasattr(self, '_good_values'):
673 673 for value in self._good_values:
674 674 self.assign(value)
675 675 self.assertEqual(self.obj.value, self.coerce(value))
676 676
677 677 def test_bad_values(self):
678 678 if hasattr(self, '_bad_values'):
679 679 for value in self._bad_values:
680 680 try:
681 681 self.assertRaises(TraitError, self.assign, value)
682 682 except AssertionError:
683 683 assert False, value
684 684
685 685 def test_default_value(self):
686 686 if hasattr(self, '_default_value'):
687 687 self.assertEqual(self._default_value, self.obj.value)
688 688
689 689 def tearDown(self):
690 690 # restore default value after tests, if set
691 691 if hasattr(self, '_default_value'):
692 692 self.obj.value = self._default_value
693 693
694 694
695 695 class AnyTrait(HasTraits):
696 696
697 697 value = Any
698 698
699 699 class AnyTraitTest(TraitTestBase):
700 700
701 701 obj = AnyTrait()
702 702
703 703 _default_value = None
704 704 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
705 705 _bad_values = []
706 706
707 707
708 708 class IntTrait(HasTraits):
709 709
710 710 value = Int(99)
711 711
712 712 class TestInt(TraitTestBase):
713 713
714 714 obj = IntTrait()
715 715 _default_value = 99
716 716 _good_values = [10, -10]
717 717 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
718 718 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
719 719 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
720 720 if not py3compat.PY3:
721 721 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
722 722
723 723
724 724 class LongTrait(HasTraits):
725 725
726 726 value = Long(99 if py3compat.PY3 else long(99))
727 727
728 728 class TestLong(TraitTestBase):
729 729
730 730 obj = LongTrait()
731 731
732 732 _default_value = 99 if py3compat.PY3 else long(99)
733 733 _good_values = [10, -10]
734 734 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
735 735 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
736 736 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
737 737 u'-10.1']
738 738 if not py3compat.PY3:
739 739 # maxint undefined on py3, because int == long
740 740 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
741 741 _bad_values.extend([[long(10)], (long(10),)])
742 742
743 743 @skipif(py3compat.PY3, "not relevant on py3")
744 744 def test_cast_small(self):
745 745 """Long casts ints to long"""
746 746 self.obj.value = 10
747 747 self.assertEqual(type(self.obj.value), long)
748 748
749 749
750 750 class IntegerTrait(HasTraits):
751 751 value = Integer(1)
752 752
753 753 class TestInteger(TestLong):
754 754 obj = IntegerTrait()
755 755 _default_value = 1
756 756
757 757 def coerce(self, n):
758 758 return int(n)
759 759
760 760 @skipif(py3compat.PY3, "not relevant on py3")
761 761 def test_cast_small(self):
762 762 """Integer casts small longs to int"""
763 763 if py3compat.PY3:
764 764 raise SkipTest("not relevant on py3")
765 765
766 766 self.obj.value = long(100)
767 767 self.assertEqual(type(self.obj.value), int)
768 768
769 769
770 770 class FloatTrait(HasTraits):
771 771
772 772 value = Float(99.0)
773 773
774 774 class TestFloat(TraitTestBase):
775 775
776 776 obj = FloatTrait()
777 777
778 778 _default_value = 99.0
779 779 _good_values = [10, -10, 10.1, -10.1]
780 780 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
781 781 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
782 782 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
783 783 if not py3compat.PY3:
784 784 _bad_values.extend([long(10), long(-10)])
785 785
786 786
787 787 class ComplexTrait(HasTraits):
788 788
789 789 value = Complex(99.0-99.0j)
790 790
791 791 class TestComplex(TraitTestBase):
792 792
793 793 obj = ComplexTrait()
794 794
795 795 _default_value = 99.0-99.0j
796 796 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
797 797 10.1j, 10.1+10.1j, 10.1-10.1j]
798 798 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
799 799 if not py3compat.PY3:
800 800 _bad_values.extend([long(10), long(-10)])
801 801
802 802
803 803 class BytesTrait(HasTraits):
804 804
805 805 value = Bytes(b'string')
806 806
807 807 class TestBytes(TraitTestBase):
808 808
809 809 obj = BytesTrait()
810 810
811 811 _default_value = b'string'
812 812 _good_values = [b'10', b'-10', b'10L',
813 813 b'-10L', b'10.1', b'-10.1', b'string']
814 814 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
815 815 ['ten'],{'ten': 10},(10,), None, u'string']
816 816 if not py3compat.PY3:
817 817 _bad_values.extend([long(10), long(-10)])
818 818
819 819
820 820 class UnicodeTrait(HasTraits):
821 821
822 822 value = Unicode(u'unicode')
823 823
824 824 class TestUnicode(TraitTestBase):
825 825
826 826 obj = UnicodeTrait()
827 827
828 828 _default_value = u'unicode'
829 829 _good_values = ['10', '-10', '10L', '-10L', '10.1',
830 830 '-10.1', '', u'', 'string', u'string', u"€"]
831 831 _bad_values = [10, -10, 10.1, -10.1, 1j,
832 832 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
833 833 if not py3compat.PY3:
834 834 _bad_values.extend([long(10), long(-10)])
835 835
836 836
837 837 class ObjectNameTrait(HasTraits):
838 838 value = ObjectName("abc")
839 839
840 840 class TestObjectName(TraitTestBase):
841 841 obj = ObjectNameTrait()
842 842
843 843 _default_value = "abc"
844 844 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
845 845 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
846 846 object(), object]
847 847 if sys.version_info[0] < 3:
848 848 _bad_values.append(u"ΓΎ")
849 849 else:
850 850 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
851 851
852 852
853 853 class DottedObjectNameTrait(HasTraits):
854 854 value = DottedObjectName("a.b")
855 855
856 856 class TestDottedObjectName(TraitTestBase):
857 857 obj = DottedObjectNameTrait()
858 858
859 859 _default_value = "a.b"
860 860 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
861 861 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
862 862 if sys.version_info[0] < 3:
863 863 _bad_values.append(u"t.ΓΎ")
864 864 else:
865 865 _good_values.append(u"t.ΓΎ")
866 866
867 867
868 868 class TCPAddressTrait(HasTraits):
869 869
870 870 value = TCPAddress()
871 871
872 872 class TestTCPAddress(TraitTestBase):
873 873
874 874 obj = TCPAddressTrait()
875 875
876 876 _default_value = ('127.0.0.1',0)
877 877 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
878 878 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
879 879
880 880 class ListTrait(HasTraits):
881 881
882 882 value = List(Int)
883 883
884 884 class TestList(TraitTestBase):
885 885
886 886 obj = ListTrait()
887 887
888 888 _default_value = []
889 _good_values = [[], [1], range(10)]
889 _good_values = [[], [1], list(range(10))]
890 890 _bad_values = [10, [1,'a'], 'a', (1,2)]
891 891
892 892 class LenListTrait(HasTraits):
893 893
894 894 value = List(Int, [0], minlen=1, maxlen=2)
895 895
896 896 class TestLenList(TraitTestBase):
897 897
898 898 obj = LenListTrait()
899 899
900 900 _default_value = [0]
901 _good_values = [[1], range(2)]
902 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
901 _good_values = [[1], list(range(2))]
902 _bad_values = [10, [1,'a'], 'a', (1,2), [], list(range(3))]
903 903
904 904 class TupleTrait(HasTraits):
905 905
906 906 value = Tuple(Int)
907 907
908 908 class TestTupleTrait(TraitTestBase):
909 909
910 910 obj = TupleTrait()
911 911
912 912 _default_value = None
913 913 _good_values = [(1,), None,(0,)]
914 914 _bad_values = [10, (1,2), [1],('a'), ()]
915 915
916 916 def test_invalid_args(self):
917 917 self.assertRaises(TypeError, Tuple, 5)
918 918 self.assertRaises(TypeError, Tuple, default_value='hello')
919 919 t = Tuple(Int, CBytes, default_value=(1,5))
920 920
921 921 class LooseTupleTrait(HasTraits):
922 922
923 923 value = Tuple((1,2,3))
924 924
925 925 class TestLooseTupleTrait(TraitTestBase):
926 926
927 927 obj = LooseTupleTrait()
928 928
929 929 _default_value = (1,2,3)
930 930 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
931 931 _bad_values = [10, 'hello', [1], []]
932 932
933 933 def test_invalid_args(self):
934 934 self.assertRaises(TypeError, Tuple, 5)
935 935 self.assertRaises(TypeError, Tuple, default_value='hello')
936 936 t = Tuple(Int, CBytes, default_value=(1,5))
937 937
938 938
939 939 class MultiTupleTrait(HasTraits):
940 940
941 941 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
942 942
943 943 class TestMultiTuple(TraitTestBase):
944 944
945 945 obj = MultiTupleTrait()
946 946
947 947 _default_value = (99,b'bottles')
948 948 _good_values = [(1,b'a'), (2,b'b')]
949 949 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
950 950
951 951 class CRegExpTrait(HasTraits):
952 952
953 953 value = CRegExp(r'')
954 954
955 955 class TestCRegExp(TraitTestBase):
956 956
957 957 def coerce(self, value):
958 958 return re.compile(value)
959 959
960 960 obj = CRegExpTrait()
961 961
962 962 _default_value = re.compile(r'')
963 963 _good_values = [r'\d+', re.compile(r'\d+')]
964 964 _bad_values = [r'(', None, ()]
965 965
966 966 class DictTrait(HasTraits):
967 967 value = Dict()
968 968
969 969 def test_dict_assignment():
970 970 d = dict()
971 971 c = DictTrait()
972 972 c.value = d
973 973 d['a'] = 5
974 974 nt.assert_equal(d, c.value)
975 975 nt.assert_true(c.value is d)
@@ -1,758 +1,758 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4
5 5 Inheritance diagram:
6 6
7 7 .. inheritance-diagram:: IPython.utils.text
8 8 :parts: 3
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2008-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 import os
23 23 import re
24 24 import sys
25 25 import textwrap
26 26 from string import Formatter
27 27
28 28 from IPython.external.path import path
29 29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
30 30 from IPython.utils import py3compat
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Declarations
34 34 #-----------------------------------------------------------------------------
35 35
36 36 # datetime.strftime date format for ipython
37 37 if sys.platform == 'win32':
38 38 date_format = "%B %d, %Y"
39 39 else:
40 40 date_format = "%B %-d, %Y"
41 41
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 class LSString(str):
48 48 """String derivative with a special access attributes.
49 49
50 50 These are normal strings, but with the special attributes:
51 51
52 52 .l (or .list) : value as list (split on newlines).
53 53 .n (or .nlstr): original value (the string itself).
54 54 .s (or .spstr): value as whitespace-separated string.
55 55 .p (or .paths): list of path objects
56 56
57 57 Any values which require transformations are computed only once and
58 58 cached.
59 59
60 60 Such strings are very useful to efficiently interact with the shell, which
61 61 typically only understands whitespace-separated options for commands."""
62 62
63 63 def get_list(self):
64 64 try:
65 65 return self.__list
66 66 except AttributeError:
67 67 self.__list = self.split('\n')
68 68 return self.__list
69 69
70 70 l = list = property(get_list)
71 71
72 72 def get_spstr(self):
73 73 try:
74 74 return self.__spstr
75 75 except AttributeError:
76 76 self.__spstr = self.replace('\n',' ')
77 77 return self.__spstr
78 78
79 79 s = spstr = property(get_spstr)
80 80
81 81 def get_nlstr(self):
82 82 return self
83 83
84 84 n = nlstr = property(get_nlstr)
85 85
86 86 def get_paths(self):
87 87 try:
88 88 return self.__paths
89 89 except AttributeError:
90 90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 91 return self.__paths
92 92
93 93 p = paths = property(get_paths)
94 94
95 95 # FIXME: We need to reimplement type specific displayhook and then add this
96 96 # back as a custom printer. This should also be moved outside utils into the
97 97 # core.
98 98
99 99 # def print_lsstring(arg):
100 100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 101 # print "LSString (.p, .n, .l, .s available). Value:"
102 102 # print arg
103 103 #
104 104 #
105 105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106 106
107 107
108 108 class SList(list):
109 109 """List derivative with a special access attributes.
110 110
111 111 These are normal lists, but with the special attributes:
112 112
113 113 .l (or .list) : value as list (the list itself).
114 114 .n (or .nlstr): value as a string, joined on newlines.
115 115 .s (or .spstr): value as a string, joined on spaces.
116 116 .p (or .paths): list of path objects
117 117
118 118 Any values which require transformations are computed only once and
119 119 cached."""
120 120
121 121 def get_list(self):
122 122 return self
123 123
124 124 l = list = property(get_list)
125 125
126 126 def get_spstr(self):
127 127 try:
128 128 return self.__spstr
129 129 except AttributeError:
130 130 self.__spstr = ' '.join(self)
131 131 return self.__spstr
132 132
133 133 s = spstr = property(get_spstr)
134 134
135 135 def get_nlstr(self):
136 136 try:
137 137 return self.__nlstr
138 138 except AttributeError:
139 139 self.__nlstr = '\n'.join(self)
140 140 return self.__nlstr
141 141
142 142 n = nlstr = property(get_nlstr)
143 143
144 144 def get_paths(self):
145 145 try:
146 146 return self.__paths
147 147 except AttributeError:
148 148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 149 return self.__paths
150 150
151 151 p = paths = property(get_paths)
152 152
153 153 def grep(self, pattern, prune = False, field = None):
154 154 """ Return all strings matching 'pattern' (a regex or callable)
155 155
156 156 This is case-insensitive. If prune is true, return all items
157 157 NOT matching the pattern.
158 158
159 159 If field is specified, the match must occur in the specified
160 160 whitespace-separated field.
161 161
162 162 Examples::
163 163
164 164 a.grep( lambda x: x.startswith('C') )
165 165 a.grep('Cha.*log', prune=1)
166 166 a.grep('chm', field=-1)
167 167 """
168 168
169 169 def match_target(s):
170 170 if field is None:
171 171 return s
172 172 parts = s.split()
173 173 try:
174 174 tgt = parts[field]
175 175 return tgt
176 176 except IndexError:
177 177 return ""
178 178
179 179 if isinstance(pattern, py3compat.string_types):
180 180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 181 else:
182 182 pred = pattern
183 183 if not prune:
184 184 return SList([el for el in self if pred(match_target(el))])
185 185 else:
186 186 return SList([el for el in self if not pred(match_target(el))])
187 187
188 188 def fields(self, *fields):
189 189 """ Collect whitespace-separated fields from string list
190 190
191 191 Allows quick awk-like usage of string lists.
192 192
193 193 Example data (in var a, created by 'a = !ls -l')::
194 194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196 196
197 197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 199 (note the joining by space).
200 200 a.fields(-1) is ['ChangeLog', 'IPython']
201 201
202 202 IndexErrors are ignored.
203 203
204 204 Without args, fields() just split()'s the strings.
205 205 """
206 206 if len(fields) == 0:
207 207 return [el.split() for el in self]
208 208
209 209 res = SList()
210 210 for el in [f.split() for f in self]:
211 211 lineparts = []
212 212
213 213 for fd in fields:
214 214 try:
215 215 lineparts.append(el[fd])
216 216 except IndexError:
217 217 pass
218 218 if lineparts:
219 219 res.append(" ".join(lineparts))
220 220
221 221 return res
222 222
223 223 def sort(self,field= None, nums = False):
224 224 """ sort by specified fields (see fields())
225 225
226 226 Example::
227 227 a.sort(1, nums = True)
228 228
229 229 Sorts a by second field, in numerical order (so that 21 > 3)
230 230
231 231 """
232 232
233 233 #decorate, sort, undecorate
234 234 if field is not None:
235 235 dsu = [[SList([line]).fields(field), line] for line in self]
236 236 else:
237 237 dsu = [[line, line] for line in self]
238 238 if nums:
239 239 for i in range(len(dsu)):
240 240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 241 try:
242 242 n = int(numstr)
243 243 except ValueError:
244 244 n = 0;
245 245 dsu[i][0] = n
246 246
247 247
248 248 dsu.sort()
249 249 return SList([t[1] for t in dsu])
250 250
251 251
252 252 # FIXME: We need to reimplement type specific displayhook and then add this
253 253 # back as a custom printer. This should also be moved outside utils into the
254 254 # core.
255 255
256 256 # def print_slist(arg):
257 257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 260 # arg.hideonce = False
261 261 # return
262 262 #
263 263 # nlprint(arg) # This was a nested list printer, now removed.
264 264 #
265 265 # print_slist = result_display.when_type(SList)(print_slist)
266 266
267 267
268 268 def indent(instr,nspaces=4, ntabs=0, flatten=False):
269 269 """Indent a string a given number of spaces or tabstops.
270 270
271 271 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
272 272
273 273 Parameters
274 274 ----------
275 275
276 276 instr : basestring
277 277 The string to be indented.
278 278 nspaces : int (default: 4)
279 279 The number of spaces to be indented.
280 280 ntabs : int (default: 0)
281 281 The number of tabs to be indented.
282 282 flatten : bool (default: False)
283 283 Whether to scrub existing indentation. If True, all lines will be
284 284 aligned to the same indentation. If False, existing indentation will
285 285 be strictly increased.
286 286
287 287 Returns
288 288 -------
289 289
290 290 str|unicode : string indented by ntabs and nspaces.
291 291
292 292 """
293 293 if instr is None:
294 294 return
295 295 ind = '\t'*ntabs+' '*nspaces
296 296 if flatten:
297 297 pat = re.compile(r'^\s*', re.MULTILINE)
298 298 else:
299 299 pat = re.compile(r'^', re.MULTILINE)
300 300 outstr = re.sub(pat, ind, instr)
301 301 if outstr.endswith(os.linesep+ind):
302 302 return outstr[:-len(ind)]
303 303 else:
304 304 return outstr
305 305
306 306
307 307 def list_strings(arg):
308 308 """Always return a list of strings, given a string or list of strings
309 309 as input.
310 310
311 311 :Examples:
312 312
313 313 In [7]: list_strings('A single string')
314 314 Out[7]: ['A single string']
315 315
316 316 In [8]: list_strings(['A single string in a list'])
317 317 Out[8]: ['A single string in a list']
318 318
319 319 In [9]: list_strings(['A','list','of','strings'])
320 320 Out[9]: ['A', 'list', 'of', 'strings']
321 321 """
322 322
323 323 if isinstance(arg, py3compat.string_types): return [arg]
324 324 else: return arg
325 325
326 326
327 327 def marquee(txt='',width=78,mark='*'):
328 328 """Return the input string centered in a 'marquee'.
329 329
330 330 :Examples:
331 331
332 332 In [16]: marquee('A test',40)
333 333 Out[16]: '**************** A test ****************'
334 334
335 335 In [17]: marquee('A test',40,'-')
336 336 Out[17]: '---------------- A test ----------------'
337 337
338 338 In [18]: marquee('A test',40,' ')
339 339 Out[18]: ' A test '
340 340
341 341 """
342 342 if not txt:
343 343 return (mark*width)[:width]
344 344 nmark = (width-len(txt)-2)//len(mark)//2
345 345 if nmark < 0: nmark =0
346 346 marks = mark*nmark
347 347 return '%s %s %s' % (marks,txt,marks)
348 348
349 349
350 350 ini_spaces_re = re.compile(r'^(\s+)')
351 351
352 352 def num_ini_spaces(strng):
353 353 """Return the number of initial spaces in a string"""
354 354
355 355 ini_spaces = ini_spaces_re.match(strng)
356 356 if ini_spaces:
357 357 return ini_spaces.end()
358 358 else:
359 359 return 0
360 360
361 361
362 362 def format_screen(strng):
363 363 """Format a string for screen printing.
364 364
365 365 This removes some latex-type format codes."""
366 366 # Paragraph continue
367 367 par_re = re.compile(r'\\$',re.MULTILINE)
368 368 strng = par_re.sub('',strng)
369 369 return strng
370 370
371 371
372 372 def dedent(text):
373 373 """Equivalent of textwrap.dedent that ignores unindented first line.
374 374
375 375 This means it will still dedent strings like:
376 376 '''foo
377 377 is a bar
378 378 '''
379 379
380 380 For use in wrap_paragraphs.
381 381 """
382 382
383 383 if text.startswith('\n'):
384 384 # text starts with blank line, don't ignore the first line
385 385 return textwrap.dedent(text)
386 386
387 387 # split first line
388 388 splits = text.split('\n',1)
389 389 if len(splits) == 1:
390 390 # only one line
391 391 return textwrap.dedent(text)
392 392
393 393 first, rest = splits
394 394 # dedent everything but the first line
395 395 rest = textwrap.dedent(rest)
396 396 return '\n'.join([first, rest])
397 397
398 398
399 399 def wrap_paragraphs(text, ncols=80):
400 400 """Wrap multiple paragraphs to fit a specified width.
401 401
402 402 This is equivalent to textwrap.wrap, but with support for multiple
403 403 paragraphs, as separated by empty lines.
404 404
405 405 Returns
406 406 -------
407 407
408 408 list of complete paragraphs, wrapped to fill `ncols` columns.
409 409 """
410 410 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
411 411 text = dedent(text).strip()
412 412 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
413 413 out_ps = []
414 414 indent_re = re.compile(r'\n\s+', re.MULTILINE)
415 415 for p in paragraphs:
416 416 # presume indentation that survives dedent is meaningful formatting,
417 417 # so don't fill unless text is flush.
418 418 if indent_re.search(p) is None:
419 419 # wrap paragraph
420 420 p = textwrap.fill(p, ncols)
421 421 out_ps.append(p)
422 422 return out_ps
423 423
424 424
425 425 def long_substr(data):
426 426 """Return the longest common substring in a list of strings.
427 427
428 428 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
429 429 """
430 430 substr = ''
431 431 if len(data) > 1 and len(data[0]) > 0:
432 432 for i in range(len(data[0])):
433 433 for j in range(len(data[0])-i+1):
434 434 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
435 435 substr = data[0][i:i+j]
436 436 elif len(data) == 1:
437 437 substr = data[0]
438 438 return substr
439 439
440 440
441 441 def strip_email_quotes(text):
442 442 """Strip leading email quotation characters ('>').
443 443
444 444 Removes any combination of leading '>' interspersed with whitespace that
445 445 appears *identically* in all lines of the input text.
446 446
447 447 Parameters
448 448 ----------
449 449 text : str
450 450
451 451 Examples
452 452 --------
453 453
454 454 Simple uses::
455 455
456 456 In [2]: strip_email_quotes('> > text')
457 457 Out[2]: 'text'
458 458
459 459 In [3]: strip_email_quotes('> > text\\n> > more')
460 460 Out[3]: 'text\\nmore'
461 461
462 462 Note how only the common prefix that appears in all lines is stripped::
463 463
464 464 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
465 465 Out[4]: '> text\\n> more\\nmore...'
466 466
467 467 So if any line has no quote marks ('>') , then none are stripped from any
468 468 of them ::
469 469
470 470 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
471 471 Out[5]: '> > text\\n> > more\\nlast different'
472 472 """
473 473 lines = text.splitlines()
474 474 matches = set()
475 475 for line in lines:
476 476 prefix = re.match(r'^(\s*>[ >]*)', line)
477 477 if prefix:
478 478 matches.add(prefix.group(1))
479 479 else:
480 480 break
481 481 else:
482 482 prefix = long_substr(list(matches))
483 483 if prefix:
484 484 strip = len(prefix)
485 485 text = '\n'.join([ ln[strip:] for ln in lines])
486 486 return text
487 487
488 488
489 489 class EvalFormatter(Formatter):
490 490 """A String Formatter that allows evaluation of simple expressions.
491 491
492 492 Note that this version interprets a : as specifying a format string (as per
493 493 standard string formatting), so if slicing is required, you must explicitly
494 494 create a slice.
495 495
496 496 This is to be used in templating cases, such as the parallel batch
497 497 script templates, where simple arithmetic on arguments is useful.
498 498
499 499 Examples
500 500 --------
501 501
502 502 In [1]: f = EvalFormatter()
503 503 In [2]: f.format('{n//4}', n=8)
504 504 Out [2]: '2'
505 505
506 506 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
507 507 Out [3]: 'll'
508 508 """
509 509 def get_field(self, name, args, kwargs):
510 510 v = eval(name, kwargs)
511 511 return v, name
512 512
513 513
514 514 @skip_doctest_py3
515 515 class FullEvalFormatter(Formatter):
516 516 """A String Formatter that allows evaluation of simple expressions.
517 517
518 518 Any time a format key is not found in the kwargs,
519 519 it will be tried as an expression in the kwargs namespace.
520 520
521 521 Note that this version allows slicing using [1:2], so you cannot specify
522 522 a format string. Use :class:`EvalFormatter` to permit format strings.
523 523
524 524 Examples
525 525 --------
526 526
527 527 In [1]: f = FullEvalFormatter()
528 528 In [2]: f.format('{n//4}', n=8)
529 529 Out[2]: u'2'
530 530
531 531 In [3]: f.format('{list(range(5))[2:4]}')
532 532 Out[3]: u'[2, 3]'
533 533
534 534 In [4]: f.format('{3*2}')
535 535 Out[4]: u'6'
536 536 """
537 537 # copied from Formatter._vformat with minor changes to allow eval
538 538 # and replace the format_spec code with slicing
539 539 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
540 540 if recursion_depth < 0:
541 541 raise ValueError('Max string recursion exceeded')
542 542 result = []
543 543 for literal_text, field_name, format_spec, conversion in \
544 544 self.parse(format_string):
545 545
546 546 # output the literal text
547 547 if literal_text:
548 548 result.append(literal_text)
549 549
550 550 # if there's a field, output it
551 551 if field_name is not None:
552 552 # this is some markup, find the object and do
553 553 # the formatting
554 554
555 555 if format_spec:
556 556 # override format spec, to allow slicing:
557 557 field_name = ':'.join([field_name, format_spec])
558 558
559 559 # eval the contents of the field for the object
560 560 # to be formatted
561 561 obj = eval(field_name, kwargs)
562 562
563 563 # do any conversion on the resulting object
564 564 obj = self.convert_field(obj, conversion)
565 565
566 566 # format the object and append to the result
567 567 result.append(self.format_field(obj, ''))
568 568
569 569 return u''.join(py3compat.cast_unicode(s) for s in result)
570 570
571 571
572 572 @skip_doctest_py3
573 573 class DollarFormatter(FullEvalFormatter):
574 574 """Formatter allowing Itpl style $foo replacement, for names and attribute
575 575 access only. Standard {foo} replacement also works, and allows full
576 576 evaluation of its arguments.
577 577
578 578 Examples
579 579 --------
580 580 In [1]: f = DollarFormatter()
581 581 In [2]: f.format('{n//4}', n=8)
582 582 Out[2]: u'2'
583 583
584 584 In [3]: f.format('23 * 76 is $result', result=23*76)
585 585 Out[3]: u'23 * 76 is 1748'
586 586
587 587 In [4]: f.format('$a or {b}', a=1, b=2)
588 588 Out[4]: u'1 or 2'
589 589 """
590 590 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
591 591 def parse(self, fmt_string):
592 592 for literal_txt, field_name, format_spec, conversion \
593 593 in Formatter.parse(self, fmt_string):
594 594
595 595 # Find $foo patterns in the literal text.
596 596 continue_from = 0
597 597 txt = ""
598 598 for m in self._dollar_pattern.finditer(literal_txt):
599 599 new_txt, new_field = m.group(1,2)
600 600 # $$foo --> $foo
601 601 if new_field.startswith("$"):
602 602 txt += new_txt + new_field
603 603 else:
604 604 yield (txt + new_txt, new_field, "", None)
605 605 txt = ""
606 606 continue_from = m.end()
607 607
608 608 # Re-yield the {foo} style pattern
609 609 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
610 610
611 611 #-----------------------------------------------------------------------------
612 612 # Utils to columnize a list of string
613 613 #-----------------------------------------------------------------------------
614 614
615 615 def _chunks(l, n):
616 616 """Yield successive n-sized chunks from l."""
617 617 for i in py3compat.xrange(0, len(l), n):
618 618 yield l[i:i+n]
619 619
620 620
621 621 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
622 622 """Calculate optimal info to columnize a list of string"""
623 623 for nrow in range(1, len(rlist)+1) :
624 chk = map(max,_chunks(rlist, nrow))
624 chk = list(map(max,_chunks(rlist, nrow)))
625 625 sumlength = sum(chk)
626 626 ncols = len(chk)
627 627 if sumlength+separator_size*(ncols-1) <= displaywidth :
628 628 break;
629 629 return {'columns_numbers' : ncols,
630 630 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
631 631 'rows_numbers' : nrow,
632 632 'columns_width' : chk
633 633 }
634 634
635 635
636 636 def _get_or_default(mylist, i, default=None):
637 637 """return list item number, or default if don't exist"""
638 638 if i >= len(mylist):
639 639 return default
640 640 else :
641 641 return mylist[i]
642 642
643 643
644 644 @skip_doctest
645 645 def compute_item_matrix(items, empty=None, *args, **kwargs) :
646 646 """Returns a nested list, and info to columnize items
647 647
648 648 Parameters
649 649 ----------
650 650
651 651 items :
652 652 list of strings to columize
653 653 empty : (default None)
654 654 default value to fill list if needed
655 655 separator_size : int (default=2)
656 656 How much caracters will be used as a separation between each columns.
657 657 displaywidth : int (default=80)
658 658 The width of the area onto wich the columns should enter
659 659
660 660 Returns
661 661 -------
662 662
663 663 Returns a tuple of (strings_matrix, dict_info)
664 664
665 665 strings_matrix :
666 666
667 667 nested list of string, the outer most list contains as many list as
668 668 rows, the innermost lists have each as many element as colums. If the
669 669 total number of elements in `items` does not equal the product of
670 670 rows*columns, the last element of some lists are filled with `None`.
671 671
672 672 dict_info :
673 673 some info to make columnize easier:
674 674
675 675 columns_numbers : number of columns
676 676 rows_numbers : number of rows
677 677 columns_width : list of with of each columns
678 678 optimal_separator_width : best separator width between columns
679 679
680 680 Examples
681 681 --------
682 682
683 683 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
684 684 ...: compute_item_matrix(l,displaywidth=12)
685 685 Out[1]:
686 686 ([['aaa', 'f', 'k'],
687 687 ['b', 'g', 'l'],
688 688 ['cc', 'h', None],
689 689 ['d', 'i', None],
690 690 ['eeeee', 'j', None]],
691 691 {'columns_numbers': 3,
692 692 'columns_width': [5, 1, 1],
693 693 'optimal_separator_width': 2,
694 694 'rows_numbers': 5})
695 695
696 696 """
697 info = _find_optimal(map(len, items), *args, **kwargs)
697 info = _find_optimal(list(map(len, items)), *args, **kwargs)
698 698 nrow, ncol = info['rows_numbers'], info['columns_numbers']
699 699 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
700 700
701 701
702 702 def columnize(items, separator=' ', displaywidth=80):
703 703 """ Transform a list of strings into a single string with columns.
704 704
705 705 Parameters
706 706 ----------
707 707 items : sequence of strings
708 708 The strings to process.
709 709
710 710 separator : str, optional [default is two spaces]
711 711 The string that separates columns.
712 712
713 713 displaywidth : int, optional [default is 80]
714 714 Width of the display in number of characters.
715 715
716 716 Returns
717 717 -------
718 718 The formatted string.
719 719 """
720 720 if not items :
721 721 return '\n'
722 722 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
723 723 fmatrix = [filter(None, x) for x in matrix]
724 724 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
725 725 return '\n'.join(map(sjoin, fmatrix))+'\n'
726 726
727 727
728 728 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
729 729 """
730 730 Return a string with a natural enumeration of items
731 731
732 732 >>> get_text_list(['a', 'b', 'c', 'd'])
733 733 'a, b, c and d'
734 734 >>> get_text_list(['a', 'b', 'c'], ' or ')
735 735 'a, b or c'
736 736 >>> get_text_list(['a', 'b', 'c'], ', ')
737 737 'a, b, c'
738 738 >>> get_text_list(['a', 'b'], ' or ')
739 739 'a or b'
740 740 >>> get_text_list(['a'])
741 741 'a'
742 742 >>> get_text_list([])
743 743 ''
744 744 >>> get_text_list(['a', 'b'], wrap_item_with="`")
745 745 '`a` and `b`'
746 746 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
747 747 'a + b + c = d'
748 748 """
749 749 if len(list_) == 0:
750 750 return ''
751 751 if wrap_item_with:
752 752 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
753 753 item in list_]
754 754 if len(list_) == 1:
755 755 return list_[0]
756 756 return '%s%s%s' % (
757 757 sep.join(i for i in list_[:-1]),
758 758 last_sep, list_[-1]) No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now