##// END OF EJS Templates
Ignore single-quoted strings when parsing
ICanWaitAndFishAllDay -
Show More
@@ -1,899 +1,901 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the key interactiveshell module.
3 3
4 4 Historically the main classes in interactiveshell have been under-tested. This
5 5 module should grow as many single-method tests as possible to trap many of the
6 6 recurring bugs we seem to encounter with high-level interaction.
7 7 """
8 8
9 9 # Copyright (c) IPython Development Team.
10 10 # Distributed under the terms of the Modified BSD License.
11 11
12 12 import ast
13 13 import os
14 14 import signal
15 15 import shutil
16 16 import sys
17 17 import tempfile
18 18 import unittest
19 19 from unittest import mock
20 20 from io import StringIO
21 21
22 22 from os.path import join
23 23
24 24 import nose.tools as nt
25 25
26 26 from IPython.core.error import InputRejected
27 27 from IPython.core.inputtransformer import InputTransformer
28 28 from IPython.testing.decorators import (
29 29 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
30 30 )
31 31 from IPython.testing import tools as tt
32 32 from IPython.utils.process import find_cmd
33 33 from IPython.utils import py3compat
34 34
35 35 #-----------------------------------------------------------------------------
36 36 # Globals
37 37 #-----------------------------------------------------------------------------
38 38 # This is used by every single test, no point repeating it ad nauseam
39 39 ip = get_ipython()
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Tests
43 43 #-----------------------------------------------------------------------------
44 44
45 45 class DerivedInterrupt(KeyboardInterrupt):
46 46 pass
47 47
48 48 class InteractiveShellTestCase(unittest.TestCase):
49 49 def test_naked_string_cells(self):
50 50 """Test that cells with only naked strings are fully executed"""
51 51 # First, single-line inputs
52 52 ip.run_cell('"a"\n')
53 53 self.assertEqual(ip.user_ns['_'], 'a')
54 54 # And also multi-line cells
55 55 ip.run_cell('"""a\nb"""\n')
56 56 self.assertEqual(ip.user_ns['_'], 'a\nb')
57 57
58 58 def test_run_empty_cell(self):
59 59 """Just make sure we don't get a horrible error with a blank
60 60 cell of input. Yes, I did overlook that."""
61 61 old_xc = ip.execution_count
62 62 res = ip.run_cell('')
63 63 self.assertEqual(ip.execution_count, old_xc)
64 64 self.assertEqual(res.execution_count, None)
65 65
66 66 def test_run_cell_multiline(self):
67 67 """Multi-block, multi-line cells must execute correctly.
68 68 """
69 69 src = '\n'.join(["x=1",
70 70 "y=2",
71 71 "if 1:",
72 72 " x += 1",
73 73 " y += 1",])
74 74 res = ip.run_cell(src)
75 75 self.assertEqual(ip.user_ns['x'], 2)
76 76 self.assertEqual(ip.user_ns['y'], 3)
77 77 self.assertEqual(res.success, True)
78 78 self.assertEqual(res.result, None)
79 79
80 80 def test_multiline_string_cells(self):
81 81 "Code sprinkled with multiline strings should execute (GH-306)"
82 82 ip.run_cell('tmp=0')
83 83 self.assertEqual(ip.user_ns['tmp'], 0)
84 84 res = ip.run_cell('tmp=1;"""a\nb"""\n')
85 85 self.assertEqual(ip.user_ns['tmp'], 1)
86 86 self.assertEqual(res.success, True)
87 87 self.assertEqual(res.result, "a\nb")
88 88
89 89 def test_dont_cache_with_semicolon(self):
90 90 "Ending a line with semicolon should not cache the returned object (GH-307)"
91 91 oldlen = len(ip.user_ns['Out'])
92 92 for cell in ['1;', '1;1;']:
93 93 res = ip.run_cell(cell, store_history=True)
94 94 newlen = len(ip.user_ns['Out'])
95 95 self.assertEqual(oldlen, newlen)
96 96 self.assertIsNone(res.result)
97 97 i = 0
98 98 #also test the default caching behavior
99 99 for cell in ['1', '1;1']:
100 100 ip.run_cell(cell, store_history=True)
101 101 newlen = len(ip.user_ns['Out'])
102 102 i += 1
103 103 self.assertEqual(oldlen+i, newlen)
104 104
105 105 def test_syntax_error(self):
106 106 res = ip.run_cell("raise = 3")
107 107 self.assertIsInstance(res.error_before_exec, SyntaxError)
108 108
109 109 def test_In_variable(self):
110 110 "Verify that In variable grows with user input (GH-284)"
111 111 oldlen = len(ip.user_ns['In'])
112 112 ip.run_cell('1;', store_history=True)
113 113 newlen = len(ip.user_ns['In'])
114 114 self.assertEqual(oldlen+1, newlen)
115 115 self.assertEqual(ip.user_ns['In'][-1],'1;')
116 116
117 117 def test_magic_names_in_string(self):
118 118 ip.run_cell('a = """\n%exit\n"""')
119 119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
120 120
121 121 def test_trailing_newline(self):
122 122 """test that running !(command) does not raise a SyntaxError"""
123 123 ip.run_cell('!(true)\n', False)
124 124 ip.run_cell('!(true)\n\n\n', False)
125 125
126 126 def test_gh_597(self):
127 127 """Pretty-printing lists of objects with non-ascii reprs may cause
128 128 problems."""
129 129 class Spam(object):
130 130 def __repr__(self):
131 131 return "\xe9"*50
132 132 import IPython.core.formatters
133 133 f = IPython.core.formatters.PlainTextFormatter()
134 134 f([Spam(),Spam()])
135 135
136 136
137 137 def test_future_flags(self):
138 138 """Check that future flags are used for parsing code (gh-777)"""
139 139 ip.run_cell('from __future__ import barry_as_FLUFL')
140 140 try:
141 141 ip.run_cell('prfunc_return_val = 1 <> 2')
142 142 assert 'prfunc_return_val' in ip.user_ns
143 143 finally:
144 144 # Reset compiler flags so we don't mess up other tests.
145 145 ip.compile.reset_compiler_flags()
146 146
147 147 def test_can_pickle(self):
148 148 "Can we pickle objects defined interactively (GH-29)"
149 149 ip = get_ipython()
150 150 ip.reset()
151 151 ip.run_cell(("class Mylist(list):\n"
152 152 " def __init__(self,x=[]):\n"
153 153 " list.__init__(self,x)"))
154 154 ip.run_cell("w=Mylist([1,2,3])")
155 155
156 156 from pickle import dumps
157 157
158 158 # We need to swap in our main module - this is only necessary
159 159 # inside the test framework, because IPython puts the interactive module
160 160 # in place (but the test framework undoes this).
161 161 _main = sys.modules['__main__']
162 162 sys.modules['__main__'] = ip.user_module
163 163 try:
164 164 res = dumps(ip.user_ns["w"])
165 165 finally:
166 166 sys.modules['__main__'] = _main
167 167 self.assertTrue(isinstance(res, bytes))
168 168
169 169 def test_global_ns(self):
170 170 "Code in functions must be able to access variables outside them."
171 171 ip = get_ipython()
172 172 ip.run_cell("a = 10")
173 173 ip.run_cell(("def f(x):\n"
174 174 " return x + a"))
175 175 ip.run_cell("b = f(12)")
176 176 self.assertEqual(ip.user_ns["b"], 22)
177 177
178 178 def test_bad_custom_tb(self):
179 179 """Check that InteractiveShell is protected from bad custom exception handlers"""
180 180 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
181 181 self.assertEqual(ip.custom_exceptions, (IOError,))
182 182 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
183 183 ip.run_cell(u'raise IOError("foo")')
184 184 self.assertEqual(ip.custom_exceptions, ())
185 185
186 186 def test_bad_custom_tb_return(self):
187 187 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
188 188 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
189 189 self.assertEqual(ip.custom_exceptions, (NameError,))
190 190 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
191 191 ip.run_cell(u'a=abracadabra')
192 192 self.assertEqual(ip.custom_exceptions, ())
193 193
194 194 def test_drop_by_id(self):
195 195 myvars = {"a":object(), "b":object(), "c": object()}
196 196 ip.push(myvars, interactive=False)
197 197 for name in myvars:
198 198 assert name in ip.user_ns, name
199 199 assert name in ip.user_ns_hidden, name
200 200 ip.user_ns['b'] = 12
201 201 ip.drop_by_id(myvars)
202 202 for name in ["a", "c"]:
203 203 assert name not in ip.user_ns, name
204 204 assert name not in ip.user_ns_hidden, name
205 205 assert ip.user_ns['b'] == 12
206 206 ip.reset()
207 207
208 208 def test_var_expand(self):
209 209 ip.user_ns['f'] = u'Ca\xf1o'
210 210 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
211 211 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
212 212 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
213 213 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
214
215 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
214 216
215 217 ip.user_ns['f'] = b'Ca\xc3\xb1o'
216 218 # This should not raise any exception:
217 219 ip.var_expand(u'echo $f')
218
220
219 221 def test_var_expand_local(self):
220 222 """Test local variable expansion in !system and %magic calls"""
221 223 # !system
222 224 ip.run_cell('def test():\n'
223 225 ' lvar = "ttt"\n'
224 226 ' ret = !echo {lvar}\n'
225 227 ' return ret[0]\n')
226 228 res = ip.user_ns['test']()
227 229 nt.assert_in('ttt', res)
228 230
229 231 # %magic
230 232 ip.run_cell('def makemacro():\n'
231 233 ' macroname = "macro_var_expand_locals"\n'
232 234 ' %macro {macroname} codestr\n')
233 235 ip.user_ns['codestr'] = "str(12)"
234 236 ip.run_cell('makemacro()')
235 237 nt.assert_in('macro_var_expand_locals', ip.user_ns)
236 238
237 239 def test_var_expand_self(self):
238 240 """Test variable expansion with the name 'self', which was failing.
239 241
240 242 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
241 243 """
242 244 ip.run_cell('class cTest:\n'
243 245 ' classvar="see me"\n'
244 246 ' def test(self):\n'
245 247 ' res = !echo Variable: {self.classvar}\n'
246 248 ' return res[0]\n')
247 249 nt.assert_in('see me', ip.user_ns['cTest']().test())
248 250
249 251 def test_bad_var_expand(self):
250 252 """var_expand on invalid formats shouldn't raise"""
251 253 # SyntaxError
252 254 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
253 255 # NameError
254 256 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
255 257 # ZeroDivisionError
256 258 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
257 259
258 260 def test_silent_postexec(self):
259 261 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
260 262 pre_explicit = mock.Mock()
261 263 pre_always = mock.Mock()
262 264 post_explicit = mock.Mock()
263 265 post_always = mock.Mock()
264 266
265 267 ip.events.register('pre_run_cell', pre_explicit)
266 268 ip.events.register('pre_execute', pre_always)
267 269 ip.events.register('post_run_cell', post_explicit)
268 270 ip.events.register('post_execute', post_always)
269 271
270 272 try:
271 273 ip.run_cell("1", silent=True)
272 274 assert pre_always.called
273 275 assert not pre_explicit.called
274 276 assert post_always.called
275 277 assert not post_explicit.called
276 278 # double-check that non-silent exec did what we expected
277 279 # silent to avoid
278 280 ip.run_cell("1")
279 281 assert pre_explicit.called
280 282 assert post_explicit.called
281 283 finally:
282 284 # remove post-exec
283 285 ip.events.unregister('pre_run_cell', pre_explicit)
284 286 ip.events.unregister('pre_execute', pre_always)
285 287 ip.events.unregister('post_run_cell', post_explicit)
286 288 ip.events.unregister('post_execute', post_always)
287 289
288 290 def test_silent_noadvance(self):
289 291 """run_cell(silent=True) doesn't advance execution_count"""
290 292 ec = ip.execution_count
291 293 # silent should force store_history=False
292 294 ip.run_cell("1", store_history=True, silent=True)
293 295
294 296 self.assertEqual(ec, ip.execution_count)
295 297 # double-check that non-silent exec did what we expected
296 298 # silent to avoid
297 299 ip.run_cell("1", store_history=True)
298 300 self.assertEqual(ec+1, ip.execution_count)
299 301
300 302 def test_silent_nodisplayhook(self):
301 303 """run_cell(silent=True) doesn't trigger displayhook"""
302 304 d = dict(called=False)
303 305
304 306 trap = ip.display_trap
305 307 save_hook = trap.hook
306 308
307 309 def failing_hook(*args, **kwargs):
308 310 d['called'] = True
309 311
310 312 try:
311 313 trap.hook = failing_hook
312 314 res = ip.run_cell("1", silent=True)
313 315 self.assertFalse(d['called'])
314 316 self.assertIsNone(res.result)
315 317 # double-check that non-silent exec did what we expected
316 318 # silent to avoid
317 319 ip.run_cell("1")
318 320 self.assertTrue(d['called'])
319 321 finally:
320 322 trap.hook = save_hook
321 323
322 324 def test_ofind_line_magic(self):
323 325 from IPython.core.magic import register_line_magic
324 326
325 327 @register_line_magic
326 328 def lmagic(line):
327 329 "A line magic"
328 330
329 331 # Get info on line magic
330 332 lfind = ip._ofind('lmagic')
331 333 info = dict(found=True, isalias=False, ismagic=True,
332 334 namespace = 'IPython internal', obj= lmagic.__wrapped__,
333 335 parent = None)
334 336 nt.assert_equal(lfind, info)
335 337
336 338 def test_ofind_cell_magic(self):
337 339 from IPython.core.magic import register_cell_magic
338 340
339 341 @register_cell_magic
340 342 def cmagic(line, cell):
341 343 "A cell magic"
342 344
343 345 # Get info on cell magic
344 346 find = ip._ofind('cmagic')
345 347 info = dict(found=True, isalias=False, ismagic=True,
346 348 namespace = 'IPython internal', obj= cmagic.__wrapped__,
347 349 parent = None)
348 350 nt.assert_equal(find, info)
349 351
350 352 def test_ofind_property_with_error(self):
351 353 class A(object):
352 354 @property
353 355 def foo(self):
354 356 raise NotImplementedError()
355 357 a = A()
356 358
357 359 found = ip._ofind('a.foo', [('locals', locals())])
358 360 info = dict(found=True, isalias=False, ismagic=False,
359 361 namespace='locals', obj=A.foo, parent=a)
360 362 nt.assert_equal(found, info)
361 363
362 364 def test_ofind_multiple_attribute_lookups(self):
363 365 class A(object):
364 366 @property
365 367 def foo(self):
366 368 raise NotImplementedError()
367 369
368 370 a = A()
369 371 a.a = A()
370 372 a.a.a = A()
371 373
372 374 found = ip._ofind('a.a.a.foo', [('locals', locals())])
373 375 info = dict(found=True, isalias=False, ismagic=False,
374 376 namespace='locals', obj=A.foo, parent=a.a.a)
375 377 nt.assert_equal(found, info)
376 378
377 379 def test_ofind_slotted_attributes(self):
378 380 class A(object):
379 381 __slots__ = ['foo']
380 382 def __init__(self):
381 383 self.foo = 'bar'
382 384
383 385 a = A()
384 386 found = ip._ofind('a.foo', [('locals', locals())])
385 387 info = dict(found=True, isalias=False, ismagic=False,
386 388 namespace='locals', obj=a.foo, parent=a)
387 389 nt.assert_equal(found, info)
388 390
389 391 found = ip._ofind('a.bar', [('locals', locals())])
390 392 info = dict(found=False, isalias=False, ismagic=False,
391 393 namespace=None, obj=None, parent=a)
392 394 nt.assert_equal(found, info)
393 395
394 396 def test_ofind_prefers_property_to_instance_level_attribute(self):
395 397 class A(object):
396 398 @property
397 399 def foo(self):
398 400 return 'bar'
399 401 a = A()
400 402 a.__dict__['foo'] = 'baz'
401 403 nt.assert_equal(a.foo, 'bar')
402 404 found = ip._ofind('a.foo', [('locals', locals())])
403 405 nt.assert_is(found['obj'], A.foo)
404 406
405 407 def test_custom_syntaxerror_exception(self):
406 408 called = []
407 409 def my_handler(shell, etype, value, tb, tb_offset=None):
408 410 called.append(etype)
409 411 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
410 412
411 413 ip.set_custom_exc((SyntaxError,), my_handler)
412 414 try:
413 415 ip.run_cell("1f")
414 416 # Check that this was called, and only once.
415 417 self.assertEqual(called, [SyntaxError])
416 418 finally:
417 419 # Reset the custom exception hook
418 420 ip.set_custom_exc((), None)
419 421
420 422 def test_custom_exception(self):
421 423 called = []
422 424 def my_handler(shell, etype, value, tb, tb_offset=None):
423 425 called.append(etype)
424 426 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
425 427
426 428 ip.set_custom_exc((ValueError,), my_handler)
427 429 try:
428 430 res = ip.run_cell("raise ValueError('test')")
429 431 # Check that this was called, and only once.
430 432 self.assertEqual(called, [ValueError])
431 433 # Check that the error is on the result object
432 434 self.assertIsInstance(res.error_in_exec, ValueError)
433 435 finally:
434 436 # Reset the custom exception hook
435 437 ip.set_custom_exc((), None)
436 438
437 439 def test_mktempfile(self):
438 440 filename = ip.mktempfile()
439 441 # Check that we can open the file again on Windows
440 442 with open(filename, 'w') as f:
441 443 f.write('abc')
442 444
443 445 filename = ip.mktempfile(data='blah')
444 446 with open(filename, 'r') as f:
445 447 self.assertEqual(f.read(), 'blah')
446 448
447 449 def test_new_main_mod(self):
448 450 # Smoketest to check that this accepts a unicode module name
449 451 name = u'jiefmw'
450 452 mod = ip.new_main_mod(u'%s.py' % name, name)
451 453 self.assertEqual(mod.__name__, name)
452 454
453 455 def test_get_exception_only(self):
454 456 try:
455 457 raise KeyboardInterrupt
456 458 except KeyboardInterrupt:
457 459 msg = ip.get_exception_only()
458 460 self.assertEqual(msg, 'KeyboardInterrupt\n')
459 461
460 462 try:
461 463 raise DerivedInterrupt("foo")
462 464 except KeyboardInterrupt:
463 465 msg = ip.get_exception_only()
464 466 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
465 467
466 468 def test_inspect_text(self):
467 469 ip.run_cell('a = 5')
468 470 text = ip.object_inspect_text('a')
469 471 self.assertIsInstance(text, str)
470 472
471 473
472 474 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
473 475
474 476 @onlyif_unicode_paths
475 477 def setUp(self):
476 478 self.BASETESTDIR = tempfile.mkdtemp()
477 479 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
478 480 os.mkdir(self.TESTDIR)
479 481 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
480 482 sfile.write("pass\n")
481 483 self.oldpath = os.getcwd()
482 484 os.chdir(self.TESTDIR)
483 485 self.fname = u"åäötestscript.py"
484 486
485 487 def tearDown(self):
486 488 os.chdir(self.oldpath)
487 489 shutil.rmtree(self.BASETESTDIR)
488 490
489 491 @onlyif_unicode_paths
490 492 def test_1(self):
491 493 """Test safe_execfile with non-ascii path
492 494 """
493 495 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
494 496
495 497 class ExitCodeChecks(tt.TempFileMixin):
496 498 def test_exit_code_ok(self):
497 499 self.system('exit 0')
498 500 self.assertEqual(ip.user_ns['_exit_code'], 0)
499 501
500 502 def test_exit_code_error(self):
501 503 self.system('exit 1')
502 504 self.assertEqual(ip.user_ns['_exit_code'], 1)
503 505
504 506 @skipif(not hasattr(signal, 'SIGALRM'))
505 507 def test_exit_code_signal(self):
506 508 self.mktmp("import signal, time\n"
507 509 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
508 510 "time.sleep(1)\n")
509 511 self.system("%s %s" % (sys.executable, self.fname))
510 512 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
511 513
512 514 @onlyif_cmds_exist("csh")
513 515 def test_exit_code_signal_csh(self):
514 516 SHELL = os.environ.get('SHELL', None)
515 517 os.environ['SHELL'] = find_cmd("csh")
516 518 try:
517 519 self.test_exit_code_signal()
518 520 finally:
519 521 if SHELL is not None:
520 522 os.environ['SHELL'] = SHELL
521 523 else:
522 524 del os.environ['SHELL']
523 525
524 526 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
525 527 system = ip.system_raw
526 528
527 529 @onlyif_unicode_paths
528 530 def test_1(self):
529 531 """Test system_raw with non-ascii cmd
530 532 """
531 533 cmd = u'''python -c "'åäö'" '''
532 534 ip.system_raw(cmd)
533 535
534 536 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
535 537 @mock.patch('os.system', side_effect=KeyboardInterrupt)
536 538 def test_control_c(self, *mocks):
537 539 try:
538 540 self.system("sleep 1 # wont happen")
539 541 except KeyboardInterrupt:
540 542 self.fail("system call should intercept "
541 543 "keyboard interrupt from subprocess.call")
542 544 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
543 545
544 546 # TODO: Exit codes are currently ignored on Windows.
545 547 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
546 548 system = ip.system_piped
547 549
548 550 @skip_win32
549 551 def test_exit_code_ok(self):
550 552 ExitCodeChecks.test_exit_code_ok(self)
551 553
552 554 @skip_win32
553 555 def test_exit_code_error(self):
554 556 ExitCodeChecks.test_exit_code_error(self)
555 557
556 558 @skip_win32
557 559 def test_exit_code_signal(self):
558 560 ExitCodeChecks.test_exit_code_signal(self)
559 561
560 562 class TestModules(unittest.TestCase, tt.TempFileMixin):
561 563 def test_extraneous_loads(self):
562 564 """Test we're not loading modules on startup that we shouldn't.
563 565 """
564 566 self.mktmp("import sys\n"
565 567 "print('numpy' in sys.modules)\n"
566 568 "print('ipyparallel' in sys.modules)\n"
567 569 "print('ipykernel' in sys.modules)\n"
568 570 )
569 571 out = "False\nFalse\nFalse\n"
570 572 tt.ipexec_validate(self.fname, out)
571 573
572 574 class Negator(ast.NodeTransformer):
573 575 """Negates all number literals in an AST."""
574 576 def visit_Num(self, node):
575 577 node.n = -node.n
576 578 return node
577 579
578 580 class TestAstTransform(unittest.TestCase):
579 581 def setUp(self):
580 582 self.negator = Negator()
581 583 ip.ast_transformers.append(self.negator)
582 584
583 585 def tearDown(self):
584 586 ip.ast_transformers.remove(self.negator)
585 587
586 588 def test_run_cell(self):
587 589 with tt.AssertPrints('-34'):
588 590 ip.run_cell('print (12 + 22)')
589 591
590 592 # A named reference to a number shouldn't be transformed.
591 593 ip.user_ns['n'] = 55
592 594 with tt.AssertNotPrints('-55'):
593 595 ip.run_cell('print (n)')
594 596
595 597 def test_timeit(self):
596 598 called = set()
597 599 def f(x):
598 600 called.add(x)
599 601 ip.push({'f':f})
600 602
601 603 with tt.AssertPrints("std. dev. of"):
602 604 ip.run_line_magic("timeit", "-n1 f(1)")
603 605 self.assertEqual(called, {-1})
604 606 called.clear()
605 607
606 608 with tt.AssertPrints("std. dev. of"):
607 609 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
608 610 self.assertEqual(called, {-2, -3})
609 611
610 612 def test_time(self):
611 613 called = []
612 614 def f(x):
613 615 called.append(x)
614 616 ip.push({'f':f})
615 617
616 618 # Test with an expression
617 619 with tt.AssertPrints("Wall time: "):
618 620 ip.run_line_magic("time", "f(5+9)")
619 621 self.assertEqual(called, [-14])
620 622 called[:] = []
621 623
622 624 # Test with a statement (different code path)
623 625 with tt.AssertPrints("Wall time: "):
624 626 ip.run_line_magic("time", "a = f(-3 + -2)")
625 627 self.assertEqual(called, [5])
626 628
627 629 def test_macro(self):
628 630 ip.push({'a':10})
629 631 # The AST transformation makes this do a+=-1
630 632 ip.define_macro("amacro", "a+=1\nprint(a)")
631 633
632 634 with tt.AssertPrints("9"):
633 635 ip.run_cell("amacro")
634 636 with tt.AssertPrints("8"):
635 637 ip.run_cell("amacro")
636 638
637 639 class IntegerWrapper(ast.NodeTransformer):
638 640 """Wraps all integers in a call to Integer()"""
639 641 def visit_Num(self, node):
640 642 if isinstance(node.n, int):
641 643 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
642 644 args=[node], keywords=[])
643 645 return node
644 646
645 647 class TestAstTransform2(unittest.TestCase):
646 648 def setUp(self):
647 649 self.intwrapper = IntegerWrapper()
648 650 ip.ast_transformers.append(self.intwrapper)
649 651
650 652 self.calls = []
651 653 def Integer(*args):
652 654 self.calls.append(args)
653 655 return args
654 656 ip.push({"Integer": Integer})
655 657
656 658 def tearDown(self):
657 659 ip.ast_transformers.remove(self.intwrapper)
658 660 del ip.user_ns['Integer']
659 661
660 662 def test_run_cell(self):
661 663 ip.run_cell("n = 2")
662 664 self.assertEqual(self.calls, [(2,)])
663 665
664 666 # This shouldn't throw an error
665 667 ip.run_cell("o = 2.0")
666 668 self.assertEqual(ip.user_ns['o'], 2.0)
667 669
668 670 def test_timeit(self):
669 671 called = set()
670 672 def f(x):
671 673 called.add(x)
672 674 ip.push({'f':f})
673 675
674 676 with tt.AssertPrints("std. dev. of"):
675 677 ip.run_line_magic("timeit", "-n1 f(1)")
676 678 self.assertEqual(called, {(1,)})
677 679 called.clear()
678 680
679 681 with tt.AssertPrints("std. dev. of"):
680 682 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
681 683 self.assertEqual(called, {(2,), (3,)})
682 684
683 685 class ErrorTransformer(ast.NodeTransformer):
684 686 """Throws an error when it sees a number."""
685 687 def visit_Num(self, node):
686 688 raise ValueError("test")
687 689
688 690 class TestAstTransformError(unittest.TestCase):
689 691 def test_unregistering(self):
690 692 err_transformer = ErrorTransformer()
691 693 ip.ast_transformers.append(err_transformer)
692 694
693 695 with tt.AssertPrints("unregister", channel='stderr'):
694 696 ip.run_cell("1 + 2")
695 697
696 698 # This should have been removed.
697 699 nt.assert_not_in(err_transformer, ip.ast_transformers)
698 700
699 701
700 702 class StringRejector(ast.NodeTransformer):
701 703 """Throws an InputRejected when it sees a string literal.
702 704
703 705 Used to verify that NodeTransformers can signal that a piece of code should
704 706 not be executed by throwing an InputRejected.
705 707 """
706 708
707 709 def visit_Str(self, node):
708 710 raise InputRejected("test")
709 711
710 712
711 713 class TestAstTransformInputRejection(unittest.TestCase):
712 714
713 715 def setUp(self):
714 716 self.transformer = StringRejector()
715 717 ip.ast_transformers.append(self.transformer)
716 718
717 719 def tearDown(self):
718 720 ip.ast_transformers.remove(self.transformer)
719 721
720 722 def test_input_rejection(self):
721 723 """Check that NodeTransformers can reject input."""
722 724
723 725 expect_exception_tb = tt.AssertPrints("InputRejected: test")
724 726 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
725 727
726 728 # Run the same check twice to verify that the transformer is not
727 729 # disabled after raising.
728 730 with expect_exception_tb, expect_no_cell_output:
729 731 ip.run_cell("'unsafe'")
730 732
731 733 with expect_exception_tb, expect_no_cell_output:
732 734 res = ip.run_cell("'unsafe'")
733 735
734 736 self.assertIsInstance(res.error_before_exec, InputRejected)
735 737
736 738 def test__IPYTHON__():
737 739 # This shouldn't raise a NameError, that's all
738 740 __IPYTHON__
739 741
740 742
741 743 class DummyRepr(object):
742 744 def __repr__(self):
743 745 return "DummyRepr"
744 746
745 747 def _repr_html_(self):
746 748 return "<b>dummy</b>"
747 749
748 750 def _repr_javascript_(self):
749 751 return "console.log('hi');", {'key': 'value'}
750 752
751 753
752 754 def test_user_variables():
753 755 # enable all formatters
754 756 ip.display_formatter.active_types = ip.display_formatter.format_types
755 757
756 758 ip.user_ns['dummy'] = d = DummyRepr()
757 759 keys = {'dummy', 'doesnotexist'}
758 760 r = ip.user_expressions({ key:key for key in keys})
759 761
760 762 nt.assert_equal(keys, set(r.keys()))
761 763 dummy = r['dummy']
762 764 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
763 765 nt.assert_equal(dummy['status'], 'ok')
764 766 data = dummy['data']
765 767 metadata = dummy['metadata']
766 768 nt.assert_equal(data.get('text/html'), d._repr_html_())
767 769 js, jsmd = d._repr_javascript_()
768 770 nt.assert_equal(data.get('application/javascript'), js)
769 771 nt.assert_equal(metadata.get('application/javascript'), jsmd)
770 772
771 773 dne = r['doesnotexist']
772 774 nt.assert_equal(dne['status'], 'error')
773 775 nt.assert_equal(dne['ename'], 'NameError')
774 776
775 777 # back to text only
776 778 ip.display_formatter.active_types = ['text/plain']
777 779
778 780 def test_user_expression():
779 781 # enable all formatters
780 782 ip.display_formatter.active_types = ip.display_formatter.format_types
781 783 query = {
782 784 'a' : '1 + 2',
783 785 'b' : '1/0',
784 786 }
785 787 r = ip.user_expressions(query)
786 788 import pprint
787 789 pprint.pprint(r)
788 790 nt.assert_equal(set(r.keys()), set(query.keys()))
789 791 a = r['a']
790 792 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
791 793 nt.assert_equal(a['status'], 'ok')
792 794 data = a['data']
793 795 metadata = a['metadata']
794 796 nt.assert_equal(data.get('text/plain'), '3')
795 797
796 798 b = r['b']
797 799 nt.assert_equal(b['status'], 'error')
798 800 nt.assert_equal(b['ename'], 'ZeroDivisionError')
799 801
800 802 # back to text only
801 803 ip.display_formatter.active_types = ['text/plain']
802 804
803 805
804 806
805 807
806 808
807 809 class TestSyntaxErrorTransformer(unittest.TestCase):
808 810 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
809 811
810 812 class SyntaxErrorTransformer(InputTransformer):
811 813
812 814 def push(self, line):
813 815 pos = line.find('syntaxerror')
814 816 if pos >= 0:
815 817 e = SyntaxError('input contains "syntaxerror"')
816 818 e.text = line
817 819 e.offset = pos + 1
818 820 raise e
819 821 return line
820 822
821 823 def reset(self):
822 824 pass
823 825
824 826 def setUp(self):
825 827 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
826 828 ip.input_splitter.python_line_transforms.append(self.transformer)
827 829 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
828 830
829 831 def tearDown(self):
830 832 ip.input_splitter.python_line_transforms.remove(self.transformer)
831 833 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
832 834
833 835 def test_syntaxerror_input_transformer(self):
834 836 with tt.AssertPrints('1234'):
835 837 ip.run_cell('1234')
836 838 with tt.AssertPrints('SyntaxError: invalid syntax'):
837 839 ip.run_cell('1 2 3') # plain python syntax error
838 840 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
839 841 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
840 842 with tt.AssertPrints('3456'):
841 843 ip.run_cell('3456')
842 844
843 845
844 846
845 847 def test_warning_suppression():
846 848 ip.run_cell("import warnings")
847 849 try:
848 850 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
849 851 ip.run_cell("warnings.warn('asdf')")
850 852 # Here's the real test -- if we run that again, we should get the
851 853 # warning again. Traditionally, each warning was only issued once per
852 854 # IPython session (approximately), even if the user typed in new and
853 855 # different code that should have also triggered the warning, leading
854 856 # to much confusion.
855 857 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
856 858 ip.run_cell("warnings.warn('asdf')")
857 859 finally:
858 860 ip.run_cell("del warnings")
859 861
860 862
861 863 def test_deprecation_warning():
862 864 ip.run_cell("""
863 865 import warnings
864 866 def wrn():
865 867 warnings.warn(
866 868 "I AM A WARNING",
867 869 DeprecationWarning
868 870 )
869 871 """)
870 872 try:
871 873 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
872 874 ip.run_cell("wrn()")
873 875 finally:
874 876 ip.run_cell("del warnings")
875 877 ip.run_cell("del wrn")
876 878
877 879
878 880 class TestImportNoDeprecate(tt.TempFileMixin):
879 881
880 882 def setup(self):
881 883 """Make a valid python temp file."""
882 884 self.mktmp("""
883 885 import warnings
884 886 def wrn():
885 887 warnings.warn(
886 888 "I AM A WARNING",
887 889 DeprecationWarning
888 890 )
889 891 """)
890 892
891 893 def test_no_dep(self):
892 894 """
893 895 No deprecation warning should be raised from imported functions
894 896 """
895 897 ip.run_cell("from {} import wrn".format(self.fname))
896 898
897 899 with tt.AssertNotPrints("I AM A WARNING"):
898 900 ip.run_cell("wrn()")
899 901 ip.run_cell("del wrn")
@@ -1,776 +1,776 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4
5 5 Inheritance diagram:
6 6
7 7 .. inheritance-diagram:: IPython.utils.text
8 8 :parts: 3
9 9 """
10 10
11 11 import os
12 12 import re
13 13 import sys
14 14 import textwrap
15 15 from string import Formatter
16 16 try:
17 17 from pathlib import Path
18 18 except ImportError:
19 19 # for Python 3.3
20 20 from pathlib2 import Path
21 21
22 22 from IPython.utils import py3compat
23 23
# datetime.strftime date format for ipython; the win32 branch avoids the
# "%-d" directive used everywhere else.
date_format = "%B %d, %Y" if sys.platform == 'win32' else "%B %-d, %Y"
29 29
class LSString(str):
    """A ``str`` subclass exposing shell-friendly views of its value.

    Instances behave like ordinary strings and additionally provide:

        .l (or .list) : value as list (split on newlines).
        .n (or .nlstr): original value (the string itself).
        .s (or .spstr): value as whitespace-separated string.
        .p (or .paths): list of ``Path`` objects, one per line that names an
                        existing filesystem entry.

    Derived views are computed on first access and cached on the instance.

    These strings are convenient when talking to the shell, which usually
    expects whitespace-separated arguments."""

    def get_list(self):
        """Return the value split on newlines (cached after first use)."""
        try:
            lines = self.__list
        except AttributeError:
            lines = self.__list = self.split('\n')
        return lines

    list = property(get_list)
    l = list

    def get_spstr(self):
        """Return the value with newlines replaced by spaces (cached)."""
        try:
            joined = self.__spstr
        except AttributeError:
            joined = self.__spstr = self.replace('\n',' ')
        return joined

    spstr = property(get_spstr)
    s = spstr

    def get_nlstr(self):
        """Return the string itself."""
        return self

    nlstr = property(get_nlstr)
    n = nlstr

    def get_paths(self):
        """Return ``Path`` objects for the lines that exist on disk (cached)."""
        try:
            existing = self.__paths
        except AttributeError:
            existing = self.__paths = [
                Path(line) for line in self.split('\n') if os.path.exists(line)
            ]
        return existing

    paths = property(get_paths)
    p = paths
77 77
78 78 # FIXME: We need to reimplement type specific displayhook and then add this
79 79 # back as a custom printer. This should also be moved outside utils into the
80 80 # core.
81 81
82 82 # def print_lsstring(arg):
83 83 # """ Prettier (non-repr-like) and more informative printer for LSString """
84 84 # print "LSString (.p, .n, .l, .s available). Value:"
85 85 # print arg
86 86 #
87 87 #
88 88 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
89 89
90 90
class SList(list):
    """A ``list`` subclass exposing shell-friendly views of its items.

    Instances behave like ordinary lists and additionally provide:

    * .l (or .list) : value as list (the list itself).
    * .n (or .nlstr): value as a string, joined on newlines.
    * .s (or .spstr): value as a string, joined on spaces.
    * .p (or .paths): list of ``Path`` objects, one per item that names an
      existing filesystem entry.

    Derived views are computed on first access and cached on the instance."""

    def get_list(self):
        """Return the list itself."""
        return self

    list = property(get_list)
    l = list

    def get_spstr(self):
        """Return the items joined with single spaces (cached)."""
        try:
            joined = self.__spstr
        except AttributeError:
            joined = self.__spstr = ' '.join(self)
        return joined

    spstr = property(get_spstr)
    s = spstr

    def get_nlstr(self):
        """Return the items joined with newlines (cached)."""
        try:
            joined = self.__nlstr
        except AttributeError:
            joined = self.__nlstr = '\n'.join(self)
        return joined

    nlstr = property(get_nlstr)
    n = nlstr

    def get_paths(self):
        """Return ``Path`` objects for the items that exist on disk (cached)."""
        try:
            existing = self.__paths
        except AttributeError:
            existing = self.__paths = [
                Path(item) for item in self if os.path.exists(item)
            ]
        return existing

    paths = property(get_paths)
    p = paths

    def grep(self, pattern, prune = False, field = None):
        """ Return all strings matching 'pattern' (a regex or callable)

        String patterns are searched case-insensitively; any other value is
        used directly as a predicate. If prune is true, return all items
        NOT matching the pattern.

        If field is specified, the match must occur in the specified
        whitespace-separated field (a missing field is treated as "").

        Examples::

            a.grep( lambda x: x.startswith('C') )
            a.grep('Cha.*log', prune=1)
            a.grep('chm', field=-1)
        """
        if isinstance(pattern, str):
            def accepts(candidate):
                return re.search(pattern, candidate, re.IGNORECASE)
        else:
            accepts = pattern

        def target_of(line):
            # The text the predicate is applied to for a given item.
            if field is None:
                return line
            words = line.split()
            try:
                return words[field]
            except IndexError:
                return ""

        wanted = not prune
        return SList(
            line for line in self if bool(accepts(target_of(line))) == wanted
        )

    def fields(self, *fields):
        """ Collect whitespace-separated fields from string list

        Allows quick awk-like usage of string lists.

        Example data (in var a, created by 'a = !ls -l')::

            -rwxrwxrwx  1 ville None      18 Dec 14  2006 ChangeLog
            drwxrwxrwx+ 6 ville None       0 Oct 24 18:05 IPython

        * ``a.fields(0)`` is ``['-rwxrwxrwx', 'drwxrwxrwx+']``
        * ``a.fields(1,0)`` is ``['1 -rwxrwxrwx', '6 drwxrwxrwx+']``
          (note the joining by space).
        * ``a.fields(-1)`` is ``['ChangeLog', 'IPython']``

        IndexErrors are ignored; items yielding no field at all are dropped.

        Without args, fields() just split()'s the strings.
        """
        if not fields:
            return [line.split() for line in self]

        collected = SList()
        for words in (line.split() for line in self):
            picked = []
            for idx in fields:
                try:
                    picked.append(words[idx])
                except IndexError:
                    continue
            if picked:
                collected.append(" ".join(picked))
        return collected

    def sort(self,field= None, nums = False):
        """ sort by specified fields (see fields())

        Returns a new SList; the receiver is left untouched.

        Example::

            a.sort(1, nums = True)

        Sorts a by second field, in numerical order (so that 21 > 3)

        """

        def sort_key(line):
            # Decoration used for ordering: the whole line or the chosen field,
            # optionally reduced to the integer formed by its digit parts.
            key = line if field is None else SList([line]).fields(field)
            if not nums:
                return key
            digits = "".join(part for part in key if part.isdigit())
            try:
                return int(digits)
            except ValueError:
                return 0

        # decorate, sort, undecorate (ties fall back to comparing the lines)
        decorated = [[sort_key(line), line] for line in self]
        decorated.sort()
        return SList(line for _, line in decorated)
235 235
236 236
237 237 # FIXME: We need to reimplement type specific displayhook and then add this
238 238 # back as a custom printer. This should also be moved outside utils into the
239 239 # core.
240 240
241 241 # def print_slist(arg):
242 242 # """ Prettier (non-repr-like) and more informative printer for SList """
243 243 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
244 244 # if hasattr(arg, 'hideonce') and arg.hideonce:
245 245 # arg.hideonce = False
246 246 # return
247 247 #
248 248 # nlprint(arg) # This was a nested list printer, now removed.
249 249 #
250 250 # print_slist = result_display.when_type(SList)(print_slist)
251 251
252 252
def indent(instr,nspaces=4, ntabs=0, flatten=False):
    """Indent a string a given number of spaces or tabstops.

    indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.

    Parameters
    ----------

    instr : str or None
        The string to be indented.
    nspaces : int (default: 4)
        The number of spaces to be indented.
    ntabs : int (default: 0)
        The number of tabs to be indented.
    flatten : bool (default: False)
        Whether to scrub existing indentation.  If True, all lines will be
        aligned to the same indentation.  If False, existing indentation will
        be strictly increased.

    Returns
    -------

    str : string indented by ntabs and nspaces, or None when `instr` is None.

    """
    if instr is None:
        return
    ind = '\t'*ntabs+' '*nspaces
    if flatten:
        pat = re.compile(r'^\s*', re.MULTILINE)
    else:
        pat = re.compile(r'^', re.MULTILINE)
    outstr = re.sub(pat, ind, instr)
    # In MULTILINE mode '^' also matches right after a final newline, which
    # leaves a dangling indent at the end of the result; drop it.  The check
    # uses '\n' rather than os.linesep so it also works on platforms whose
    # line separator is '\r\n', and it is skipped for an empty indent because
    # outstr[:-0] would otherwise discard the whole string.
    if ind and outstr.endswith('\n' + ind):
        return outstr[:-len(ind)]
    return outstr
290 290
291 291
def list_strings(arg):
    """Wrap a lone string in a list; return anything else unchanged.

    Examples
    --------
    ::

        In [7]: list_strings('A single string')
        Out[7]: ['A single string']

        In [8]: list_strings(['A single string in a list'])
        Out[8]: ['A single string in a list']

        In [9]: list_strings(['A','list','of','strings'])
        Out[9]: ['A', 'list', 'of', 'strings']
    """
    return [arg] if isinstance(arg, str) else arg
314 314
315 315
def marquee(txt='',width=78,mark='*'):
    """Return the input string centered in a 'marquee'.

    With empty `txt` the result is `mark` repeated and cut to `width`
    characters; otherwise `txt` is surrounded by a space and the same number
    of marks on both sides.

    Examples
    --------
    ::

        In [16]: marquee('A test',40)
        Out[16]: '**************** A test ****************'

        In [17]: marquee('A test',40,'-')
        Out[17]: '---------------- A test ----------------'

        In [18]: marquee('A test',40,' ')
        Out[18]: '                 A test                 '

    """
    if not txt:
        return (mark*width)[:width]
    # Marks per side; never negative when the text is wider than the marquee.
    per_side = max((width - len(txt) - 2) // len(mark) // 2, 0)
    border = mark * per_side
    return '%s %s %s' % (border, txt, border)
339 339
340 340
ini_spaces_re = re.compile(r'^(\s+)')

def num_ini_spaces(strng):
    """Return the length of the leading whitespace run of a string.

    Any whitespace matched by ``\\s`` counts, not only the space character.
    """
    leading = ini_spaces_re.match(strng)
    return leading.end() if leading else 0
351 351
352 352
def format_screen(strng):
    """Format a string for screen printing.

    This removes some latex-type format codes: currently a backslash at the
    end of a line (the paragraph-continue marker)."""
    return re.sub(r'\\$', '', strng, flags=re.MULTILINE)
361 361
362 362
def dedent(text):
    """Equivalent of textwrap.dedent that ignores unindented first line.

    This means it will still dedent strings like:
    '''foo
    is a bar
    '''

    For use in wrap_paragraphs.
    """
    first, newline, rest = text.partition('\n')
    if not newline or not first:
        # A single line, or text starting with a blank line: there is no
        # special first line to protect, so dedent everything.
        return textwrap.dedent(text)

    # Dedent everything but the first line, then glue the pieces back.
    return first + '\n' + textwrap.dedent(rest)
388 388
389 389
def wrap_paragraphs(text, ncols=80):
    """Wrap multiple paragraphs to fit a specified width.

    This is equivalent to textwrap.wrap, but with support for multiple
    paragraphs, as separated by empty lines.

    Returns
    -------

    list of complete paragraphs, wrapped to fill `ncols` columns.
    """
    paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
    indent_re = re.compile(r'\n\s+', re.MULTILINE)

    stripped = dedent(text).strip()
    # The split pattern has one capturing group, so the separators are
    # interleaved with the paragraphs: keep every other entry.
    paragraphs = paragraph_re.split(stripped)[::2]

    # Presume indentation that survives dedent is meaningful formatting,
    # so only fill paragraphs whose text is flush.
    return [
        para if indent_re.search(para) is not None else textwrap.fill(para, ncols)
        for para in paragraphs
    ]
414 414
415 415
def long_substr(data):
    """Return the longest common substring in a list of strings.

    A single-element list yields that element; an empty list yields ''.

    Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
    """
    if len(data) == 1:
        return data[0]
    if len(data) < 2 or len(data[0]) == 0:
        return ''

    reference = data[0]
    best = ''
    for start in range(len(reference)):
        # Only candidates strictly longer than the current best can win.
        for size in range(len(best) + 1, len(reference) - start + 1):
            candidate = reference[start:start + size]
            if all(candidate in other for other in data):
                best = candidate
    return best
430 430
431 431
def strip_email_quotes(text):
    """Strip leading email quotation characters ('>').

    Removes any combination of leading '>' interspersed with whitespace that
    appears *identically* in all lines of the input text.

    Parameters
    ----------
    text : str

    Returns
    -------
    str : the text with the shared quote prefix removed from every line, or
    the unchanged input when some line is unquoted or the lines share no
    common quote prefix.

    Examples
    --------

    Simple uses::

        In [2]: strip_email_quotes('> > text')
        Out[2]: 'text'

        In [3]: strip_email_quotes('> > text\\n> > more')
        Out[3]: 'text\\nmore'

    Note how only the common prefix that appears in all lines is stripped::

        In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
        Out[4]: '> text\\n> more\\nmore...'

    So if any line has no quote marks ('>') , then none are stripped from any
    of them ::

        In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
        Out[5]: '> > text\\n> > more\\nlast different'
    """
    lines = text.splitlines()
    matches = set()
    for line in lines:
        prefix = re.match(r'^(\s*>[ >]*)', line)
        if prefix:
            matches.add(prefix.group(1))
        else:
            # One unquoted line means nothing is stripped anywhere.
            break
    else:
        # Only text that *starts* every line identically may be removed, so
        # take the common prefix of the quote markers.  (A longest common
        # substring can sit in the middle of a marker and would cut
        # unrelated characters off the front of the lines.)
        prefix = os.path.commonprefix(list(matches))
        if prefix:
            strip = len(prefix)
            text = '\n'.join([ ln[strip:] for ln in lines])
    return text
478 478
def strip_ansi(source):
    """
    Remove ansi escape codes from text.

    Only sequences of the form ``ESC [ <digits and semicolons> m`` are
    removed.

    Parameters
    ----------
    source : str
        Source to remove the ansi from
    """
    ansi_sequence = re.compile(r'\033\[(\d|;)+?m')
    return ansi_sequence.sub('', source)
489 489
490 490
class EvalFormatter(Formatter):
    """A String Formatter that allows evaluation of simple expressions.

    Note that this version interprets a : as specifying a format string (as per
    standard string formatting), so if slicing is required, you must explicitly
    create a slice.

    This is to be used in templating cases, such as the parallel batch
    script templates, where simple arithmetic on arguments is useful.

    Field names are passed to ``eval``, so templates must come from a trusted
    source.

    Examples
    --------
    ::

        In [1]: f = EvalFormatter()
        In [2]: f.format('{n//4}', n=8)
        Out[2]: '2'

        In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
        Out[3]: 'll'
    """
    def get_field(self, name, args, kwargs):
        """Evaluate `name` as an expression with `kwargs` as its globals.

        Returns the ``(value, name)`` pair; positional `args` are not used.
        """
        return eval(name, kwargs), name
515 515
516 516 #XXX: As of Python 3.4, the format string parsing no longer splits on a colon
517 517 # inside [], so EvalFormatter can handle slicing. Once we only support 3.4 and
518 518 # above, it should be possible to remove FullEvalFormatter.
519 519
class FullEvalFormatter(Formatter):
    """A String Formatter that allows evaluation of simple expressions.

    Every replacement field is evaluated as an expression in the kwargs
    namespace.

    Note that this version allows slicing using [1:2], so you cannot specify
    a format string. Use :class:`EvalFormatter` to permit format strings.

    Field contents are passed to ``eval``, so templates must come from a
    trusted source.

    Examples
    --------
    ::

        In [1]: f = FullEvalFormatter()
        In [2]: f.format('{n//4}', n=8)
        Out[2]: '2'

        In [3]: f.format('{list(range(5))[2:4]}')
        Out[3]: '[2, 3]'

        In [4]: f.format('{3*2}')
        Out[4]: '6'
    """
    # copied from Formatter._vformat with minor changes to allow eval
    # and replace the format_spec code with slicing
    def vformat(self, format_string, args, kwargs):
        """Render `format_string`, evaluating each field against `kwargs`."""
        pieces = []
        for literal_text, field_name, format_spec, conversion in \
                self.parse(format_string):

            # Literal text is emitted untouched.
            if literal_text:
                pieces.append(literal_text)

            # Nothing more to do when the segment carries no field.
            if field_name is None:
                continue

            # The parser split the field on ':'; put it back so that slice
            # expressions such as ``x[1:2]`` reach eval intact.
            if format_spec:
                field_name = ':'.join([field_name, format_spec])

            # Evaluate, apply any !r / !s conversion, then format with an
            # empty spec (the spec was consumed by the expression above).
            value = self.convert_field(eval(field_name, kwargs), conversion)
            pieces.append(self.format_field(value, ''))

        return ''.join(py3compat.cast_unicode(s) for s in pieces)
574 574
575 575
class DollarFormatter(FullEvalFormatter):
    """Formatter allowing Itpl style $foo replacement, for names and attribute
    access only. Standard {foo} replacement also works, and allows full
    evaluation of its arguments.

    A ``$name`` is only expanded when the rest of its literal segment
    contains an even number of single quotes, i.e. when it does not sit
    inside a single-quoted string. ``$$name`` produces a literal ``$name``.

    Examples
    --------
    ::

        In [1]: f = DollarFormatter()
        In [2]: f.format('{n//4}', n=8)
        Out[2]: '2'

        In [3]: f.format('23 * 76 is $result', result=23*76)
        Out[3]: '23 * 76 is 1748'

        In [4]: f.format('$a or {b}', a=1, b=2)
        Out[4]: '1 or 2'
    """
    # Raw string: the pattern contains \$, \w and \. which are invalid escape
    # sequences in an ordinary string literal.
    # Groups: (1) text before the '$', (2) the name (with a leading '$' for
    # the '$$name' escape). The lookahead rejects matches that are followed
    # by an odd number of single quotes.
    _dollar_pattern_ignore_single_quote = re.compile(
        r"(.*?)\$(\$?[\w\.]+)(?=([^']*'[^']*')*[^']*$)")

    def parse(self, fmt_string):
        """Yield ``(literal, field, format_spec, conversion)`` tuples.

        Extends ``Formatter.parse`` by turning each ``$name`` found in the
        literal text into a field of its own.
        """
        for literal_txt, field_name, format_spec, conversion \
                in Formatter.parse(self, fmt_string):

            # Find $foo patterns in the literal text.
            continue_from = 0
            txt = ""
            for m in self._dollar_pattern_ignore_single_quote.finditer(literal_txt):
                new_txt, new_field = m.group(1,2)
                # $$foo --> $foo
                if new_field.startswith("$"):
                    txt += new_txt + new_field
                else:
                    yield (txt + new_txt, new_field, "", None)
                    txt = ""
                continue_from = m.end()

            # Re-yield the {foo} style pattern
            yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
615 615
616 616 #-----------------------------------------------------------------------------
617 617 # Utils to columnize a list of string
618 618 #-----------------------------------------------------------------------------
619 619
620 620 def _col_chunks(l, max_rows, row_first=False):
621 621 """Yield successive max_rows-sized column chunks from l."""
622 622 if row_first:
623 623 ncols = (len(l) // max_rows) + (len(l) % max_rows > 0)
624 624 for i in range(ncols):
625 625 yield [l[j] for j in range(i, len(l), ncols)]
626 626 else:
627 627 for i in range(0, len(l), max_rows):
628 628 yield l[i:(i + max_rows)]
629 629
630 630
631 631 def _find_optimal(rlist, row_first=False, separator_size=2, displaywidth=80):
632 632 """Calculate optimal info to columnize a list of string"""
633 633 for max_rows in range(1, len(rlist) + 1):
634 634 col_widths = list(map(max, _col_chunks(rlist, max_rows, row_first)))
635 635 sumlength = sum(col_widths)
636 636 ncols = len(col_widths)
637 637 if sumlength + separator_size * (ncols - 1) <= displaywidth:
638 638 break
639 639 return {'num_columns': ncols,
640 640 'optimal_separator_width': (displaywidth - sumlength) // (ncols - 1) if (ncols - 1) else 0,
641 641 'max_rows': max_rows,
642 642 'column_widths': col_widths
643 643 }
644 644
645 645
646 646 def _get_or_default(mylist, i, default=None):
647 647 """return list item number, or default if don't exist"""
648 648 if i >= len(mylist):
649 649 return default
650 650 else :
651 651 return mylist[i]
652 652
653 653
def compute_item_matrix(items, row_first=False, empty=None, *args, **kwargs) :
    """Returns a nested list, and info to columnize items

    Parameters
    ----------

    items
        list of strings to columnize
    row_first : (default False)
        Whether to compute columns for a row-first matrix instead of
        column-first (default).
    empty : (default None)
        default value to fill list if needed
    separator_size : int (default=2)
        How many characters will be used as a separation between each column.
    displaywidth : int (default=80)
        The width of the area into which the columns should fit

    Returns
    -------

    strings_matrix

        nested list of strings, the outermost list contains as many lists as
        rows, the innermost lists each have as many elements as columns. If the
        total number of elements in `items` does not equal the product of
        rows*columns, the missing cells are filled with `empty`.

    dict_info
        some info to make columnize easier:

        num_columns
          number of columns
        max_rows
          maximum number of rows (final number may be less)
        column_widths
          list of the width of each column
        optimal_separator_width
          best separator width between columns

    Examples
    --------
    ::

        In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
        In [2]: list, info = compute_item_matrix(l, displaywidth=12)
        In [3]: list
        Out[3]: [['aaa', 'f', 'k'], ['b', 'g', 'l'], ['cc', 'h', None], ['d', 'i', None], ['eeeee', 'j', None]]
        In [4]: ideal = {'num_columns': 3, 'column_widths': [5, 1, 1], 'optimal_separator_width': 2, 'max_rows': 5}
        In [5]: all((info[k] == ideal[k] for k in ideal.keys()))
        Out[5]: True
    """
    lengths = list(map(len, items))
    info = _find_optimal(lengths, row_first, *args, **kwargs)
    nrow = info['max_rows']
    ncol = info['num_columns']

    def flat_index(row, col):
        # Position in `items` of the cell at (row, col) for the chosen layout.
        if row_first:
            return row * ncol + col
        return col * nrow + row

    matrix = [
        [_get_or_default(items, flat_index(row, col), default=empty)
         for col in range(ncol)]
        for row in range(nrow)
    ]
    return (matrix, info)
712 712
713 713
def columnize(items, row_first=False, separator='  ', displaywidth=80, spread=False):
    """ Transform a list of strings into a single string with columns.

    Parameters
    ----------
    items : sequence of strings
        The strings to process.

    row_first : (default False)
        Whether to compute columns for a row-first matrix instead of
        column-first (default).

    separator : str, optional [default is two spaces]
        The string that separates columns.

    displaywidth : int, optional [default is 80]
        Width of the display in number of characters.

    spread : bool, optional [default is False]
        If true, pad the separator with spaces up to the optimal separator
        width computed for the layout.

    Returns
    -------
    The formatted string.
    """
    if not items:
        return '\n'
    matrix, info = compute_item_matrix(items, row_first=row_first, separator_size=len(separator), displaywidth=displaywidth)
    if spread:
        separator = separator.ljust(int(info['optimal_separator_width']))
    # Drop the filler cells (and any other falsy entries) from each row.
    fmatrix = [filter(None, x) for x in matrix]
    # Pad every cell to its column width, then join the row.
    sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['column_widths'])])
    return '\n'.join(map(sjoin, fmatrix))+'\n'
744 744
745 745
def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
    """
    Return a string with a natural enumeration of items

    All items but the last are joined with `sep`; the last one is attached
    with `last_sep`. When `wrap_item_with` is given, every item is wrapped in
    it on both sides.

    >>> get_text_list(['a', 'b', 'c', 'd'])
    'a, b, c and d'
    >>> get_text_list(['a', 'b', 'c'], ' or ')
    'a, b or c'
    >>> get_text_list(['a', 'b', 'c'], ', ')
    'a, b, c'
    >>> get_text_list(['a', 'b'], ' or ')
    'a or b'
    >>> get_text_list(['a'])
    'a'
    >>> get_text_list([])
    ''
    >>> get_text_list(['a', 'b'], wrap_item_with="`")
    '`a` and `b`'
    >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
    'a + b + c = d'
    """
    if len(list_) == 0:
        return ''

    entries = list_
    if wrap_item_with:
        entries = [
            '%s%s%s' % (wrap_item_with, entry, wrap_item_with)
            for entry in list_
        ]

    *leading, final = entries
    if not leading:
        # A single item needs no separators at all.
        return final
    return '%s%s%s' % (sep.join(leading), last_sep, final)
General Comments 0
You need to be logged in to leave comments. Login now