##// END OF EJS Templates
Backport PR #13349: get_running_loop is only valid in coroutines...
Min RK -
Show More
@@ -1,173 +1,183 b''
1 1 """
2 2 Async helper function that are invalid syntax on Python 3.5 and below.
3 3
4 4 This code is best effort, and may have edge cases not behaving as expected. In
5 5 particular it contain a number of heuristics to detect whether code is
6 6 effectively async and need to run in an event loop or not.
7 7
8 8 Some constructs (like top-level `return`, or `yield`) are taken care of
9 9 explicitly to actually raise a SyntaxError and stay as close as possible to
10 10 Python semantics.
11 11 """
12 12
13 13
14 14 import ast
15 15 import sys
16 import asyncio
16 17 import inspect
17 18 from textwrap import dedent, indent
18 19
19 20
20 21 class _AsyncIORunner:
22 def __init__(self):
23 self._loop = None
24
25 @property
26 def loop(self):
27 """Always returns a non-closed event loop"""
28 if self._loop is None or self._loop.is_closed():
29 policy = asyncio.get_event_loop_policy()
30 self._loop = policy.new_event_loop()
31 policy.set_event_loop(self._loop)
32 return self._loop
21 33
22 34 def __call__(self, coro):
23 35 """
24 36 Handler for asyncio autoawait
25 37 """
26 import asyncio
27
28 return asyncio.get_event_loop().run_until_complete(coro)
38 return self.loop.run_until_complete(coro)
29 39
30 40 def __str__(self):
31 41 return 'asyncio'
32 42
33 43 _asyncio_runner = _AsyncIORunner()
34 44
35 45
36 46 def _curio_runner(coroutine):
37 47 """
38 48 handler for curio autoawait
39 49 """
40 50 import curio
41 51
42 52 return curio.run(coroutine)
43 53
44 54
45 55 def _trio_runner(async_fn):
46 56 import trio
47 57
48 58 async def loc(coro):
49 59 """
50 60 We need the dummy no-op async def to protect from
51 61 trio's internal. See https://github.com/python-trio/trio/issues/89
52 62 """
53 63 return await coro
54 64
55 65 return trio.run(loc, async_fn)
56 66
57 67
58 68 def _pseudo_sync_runner(coro):
59 69 """
60 70 A runner that does not really allow async execution, and just advance the coroutine.
61 71
62 72 See discussion in https://github.com/python-trio/trio/issues/608,
63 73
64 74 Credit to Nathaniel Smith
65 75
66 76 """
67 77 try:
68 78 coro.send(None)
69 79 except StopIteration as exc:
70 80 return exc.value
71 81 else:
72 82 # TODO: do not raise but return an execution result with the right info.
73 83 raise RuntimeError(
74 84 "{coro_name!r} needs a real async loop".format(coro_name=coro.__name__)
75 85 )
76 86
77 87
78 88 def _asyncify(code: str) -> str:
79 89 """wrap code in async def definition.
80 90
81 91 And setup a bit of context to run it later.
82 92 """
83 93 res = dedent(
84 94 """
85 95 async def __wrapper__():
86 96 try:
87 97 {usercode}
88 98 finally:
89 99 locals()
90 100 """
91 101 ).format(usercode=indent(code, " " * 8))
92 102 return res
93 103
94 104
95 105 class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
96 106 """
97 107 Find syntax errors that would be an error in an async repl, but because
98 108 the implementation involves wrapping the repl in an async function, it
99 109 is erroneously allowed (e.g. yield or return at the top level)
100 110 """
101 111 def __init__(self):
102 112 if sys.version_info >= (3,8):
103 113 raise ValueError('DEPRECATED in Python 3.8+')
104 114 self.depth = 0
105 115 super().__init__()
106 116
107 117 def generic_visit(self, node):
108 118 func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
109 119 invalid_types_by_depth = {
110 120 0: (ast.Return, ast.Yield, ast.YieldFrom),
111 121 1: (ast.Nonlocal,)
112 122 }
113 123
114 124 should_traverse = self.depth < max(invalid_types_by_depth.keys())
115 125 if isinstance(node, func_types) and should_traverse:
116 126 self.depth += 1
117 127 super().generic_visit(node)
118 128 self.depth -= 1
119 129 elif isinstance(node, invalid_types_by_depth[self.depth]):
120 130 raise SyntaxError()
121 131 else:
122 132 super().generic_visit(node)
123 133
124 134
125 135 def _async_parse_cell(cell: str) -> ast.AST:
126 136 """
127 137 This is a compatibility shim for pre-3.7 when async outside of a function
128 138 is a syntax error at the parse stage.
129 139
130 140 It will return an abstract syntax tree parsed as if async and await outside
131 141 of a function were not a syntax error.
132 142 """
133 143 if sys.version_info < (3, 7):
134 144 # Prior to 3.7 you need to asyncify before parse
135 145 wrapped_parse_tree = ast.parse(_asyncify(cell))
136 146 return wrapped_parse_tree.body[0].body[0]
137 147 else:
138 148 return ast.parse(cell)
139 149
140 150
141 151 def _should_be_async(cell: str) -> bool:
142 152 """Detect if a block of code need to be wrapped in an `async def`
143 153
144 154 Attempt to parse the block of code, it it compile we're fine.
145 155 Otherwise we wrap if and try to compile.
146 156
147 157 If it works, assume it should be async. Otherwise Return False.
148 158
149 159 Not handled yet: If the block of code has a return statement as the top
150 160 level, it will be seen as async. This is a know limitation.
151 161 """
152 162 if sys.version_info > (3, 8):
153 163 try:
154 164 code = compile(cell, "<>", "exec", flags=getattr(ast,'PyCF_ALLOW_TOP_LEVEL_AWAIT', 0x0))
155 165 return inspect.CO_COROUTINE & code.co_flags == inspect.CO_COROUTINE
156 166 except (SyntaxError, MemoryError):
157 167 return False
158 168 try:
159 169 # we can't limit ourself to ast.parse, as it __accepts__ to parse on
160 170 # 3.7+, but just does not _compile_
161 171 code = compile(cell, "<>", "exec")
162 172 except (SyntaxError, MemoryError):
163 173 try:
164 174 parse_tree = _async_parse_cell(cell)
165 175
166 176 # Raise a SyntaxError if there are top-level return or yields
167 177 v = _AsyncSyntaxErrorVisitor()
168 178 v.visit(parse_tree)
169 179
170 180 except (SyntaxError, MemoryError):
171 181 return False
172 182 return True
173 183 return False
@@ -1,1061 +1,1077 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the key interactiveshell module.
3 3
4 4 Historically the main classes in interactiveshell have been under-tested. This
5 5 module should grow as many single-method tests as possible to trap many of the
6 6 recurring bugs we seem to encounter with high-level interaction.
7 7 """
8 8
9 9 # Copyright (c) IPython Development Team.
10 10 # Distributed under the terms of the Modified BSD License.
11 11
12 12 import asyncio
13 13 import ast
14 14 import os
15 15 import signal
16 16 import shutil
17 17 import sys
18 18 import tempfile
19 19 import unittest
20 20 from unittest import mock
21 21
22 22 from os.path import join
23 23
24 24 import nose.tools as nt
25 25
26 26 from IPython.core.error import InputRejected
27 27 from IPython.core.inputtransformer import InputTransformer
28 28 from IPython.core import interactiveshell
29 29 from IPython.testing.decorators import (
30 30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
31 31 )
32 32 from IPython.testing import tools as tt
33 33 from IPython.utils.process import find_cmd
34 34
35 35 #-----------------------------------------------------------------------------
36 36 # Globals
37 37 #-----------------------------------------------------------------------------
38 38 # This is used by every single test, no point repeating it ad nauseam
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Tests
42 42 #-----------------------------------------------------------------------------
43 43
44 44 class DerivedInterrupt(KeyboardInterrupt):
45 45 pass
46 46
47 47 class InteractiveShellTestCase(unittest.TestCase):
48 48 def test_naked_string_cells(self):
49 49 """Test that cells with only naked strings are fully executed"""
50 50 # First, single-line inputs
51 51 ip.run_cell('"a"\n')
52 52 self.assertEqual(ip.user_ns['_'], 'a')
53 53 # And also multi-line cells
54 54 ip.run_cell('"""a\nb"""\n')
55 55 self.assertEqual(ip.user_ns['_'], 'a\nb')
56 56
57 57 def test_run_empty_cell(self):
58 58 """Just make sure we don't get a horrible error with a blank
59 59 cell of input. Yes, I did overlook that."""
60 60 old_xc = ip.execution_count
61 61 res = ip.run_cell('')
62 62 self.assertEqual(ip.execution_count, old_xc)
63 63 self.assertEqual(res.execution_count, None)
64 64
65 65 def test_run_cell_multiline(self):
66 66 """Multi-block, multi-line cells must execute correctly.
67 67 """
68 68 src = '\n'.join(["x=1",
69 69 "y=2",
70 70 "if 1:",
71 71 " x += 1",
72 72 " y += 1",])
73 73 res = ip.run_cell(src)
74 74 self.assertEqual(ip.user_ns['x'], 2)
75 75 self.assertEqual(ip.user_ns['y'], 3)
76 76 self.assertEqual(res.success, True)
77 77 self.assertEqual(res.result, None)
78 78
79 79 def test_multiline_string_cells(self):
80 80 "Code sprinkled with multiline strings should execute (GH-306)"
81 81 ip.run_cell('tmp=0')
82 82 self.assertEqual(ip.user_ns['tmp'], 0)
83 83 res = ip.run_cell('tmp=1;"""a\nb"""\n')
84 84 self.assertEqual(ip.user_ns['tmp'], 1)
85 85 self.assertEqual(res.success, True)
86 86 self.assertEqual(res.result, "a\nb")
87 87
88 88 def test_dont_cache_with_semicolon(self):
89 89 "Ending a line with semicolon should not cache the returned object (GH-307)"
90 90 oldlen = len(ip.user_ns['Out'])
91 91 for cell in ['1;', '1;1;']:
92 92 res = ip.run_cell(cell, store_history=True)
93 93 newlen = len(ip.user_ns['Out'])
94 94 self.assertEqual(oldlen, newlen)
95 95 self.assertIsNone(res.result)
96 96 i = 0
97 97 #also test the default caching behavior
98 98 for cell in ['1', '1;1']:
99 99 ip.run_cell(cell, store_history=True)
100 100 newlen = len(ip.user_ns['Out'])
101 101 i += 1
102 102 self.assertEqual(oldlen+i, newlen)
103 103
104 104 def test_syntax_error(self):
105 105 res = ip.run_cell("raise = 3")
106 106 self.assertIsInstance(res.error_before_exec, SyntaxError)
107 107
108 108 def test_In_variable(self):
109 109 "Verify that In variable grows with user input (GH-284)"
110 110 oldlen = len(ip.user_ns['In'])
111 111 ip.run_cell('1;', store_history=True)
112 112 newlen = len(ip.user_ns['In'])
113 113 self.assertEqual(oldlen+1, newlen)
114 114 self.assertEqual(ip.user_ns['In'][-1],'1;')
115 115
116 116 def test_magic_names_in_string(self):
117 117 ip.run_cell('a = """\n%exit\n"""')
118 118 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
119 119
120 120 def test_trailing_newline(self):
121 121 """test that running !(command) does not raise a SyntaxError"""
122 122 ip.run_cell('!(true)\n', False)
123 123 ip.run_cell('!(true)\n\n\n', False)
124 124
125 125 def test_gh_597(self):
126 126 """Pretty-printing lists of objects with non-ascii reprs may cause
127 127 problems."""
128 128 class Spam(object):
129 129 def __repr__(self):
130 130 return "\xe9"*50
131 131 import IPython.core.formatters
132 132 f = IPython.core.formatters.PlainTextFormatter()
133 133 f([Spam(),Spam()])
134 134
135 135
136 136 def test_future_flags(self):
137 137 """Check that future flags are used for parsing code (gh-777)"""
138 138 ip.run_cell('from __future__ import barry_as_FLUFL')
139 139 try:
140 140 ip.run_cell('prfunc_return_val = 1 <> 2')
141 141 assert 'prfunc_return_val' in ip.user_ns
142 142 finally:
143 143 # Reset compiler flags so we don't mess up other tests.
144 144 ip.compile.reset_compiler_flags()
145 145
146 146 def test_can_pickle(self):
147 147 "Can we pickle objects defined interactively (GH-29)"
148 148 ip = get_ipython()
149 149 ip.reset()
150 150 ip.run_cell(("class Mylist(list):\n"
151 151 " def __init__(self,x=[]):\n"
152 152 " list.__init__(self,x)"))
153 153 ip.run_cell("w=Mylist([1,2,3])")
154 154
155 155 from pickle import dumps
156 156
157 157 # We need to swap in our main module - this is only necessary
158 158 # inside the test framework, because IPython puts the interactive module
159 159 # in place (but the test framework undoes this).
160 160 _main = sys.modules['__main__']
161 161 sys.modules['__main__'] = ip.user_module
162 162 try:
163 163 res = dumps(ip.user_ns["w"])
164 164 finally:
165 165 sys.modules['__main__'] = _main
166 166 self.assertTrue(isinstance(res, bytes))
167 167
168 168 def test_global_ns(self):
169 169 "Code in functions must be able to access variables outside them."
170 170 ip = get_ipython()
171 171 ip.run_cell("a = 10")
172 172 ip.run_cell(("def f(x):\n"
173 173 " return x + a"))
174 174 ip.run_cell("b = f(12)")
175 175 self.assertEqual(ip.user_ns["b"], 22)
176 176
177 177 def test_bad_custom_tb(self):
178 178 """Check that InteractiveShell is protected from bad custom exception handlers"""
179 179 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
180 180 self.assertEqual(ip.custom_exceptions, (IOError,))
181 181 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
182 182 ip.run_cell(u'raise IOError("foo")')
183 183 self.assertEqual(ip.custom_exceptions, ())
184 184
185 185 def test_bad_custom_tb_return(self):
186 186 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
187 187 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
188 188 self.assertEqual(ip.custom_exceptions, (NameError,))
189 189 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
190 190 ip.run_cell(u'a=abracadabra')
191 191 self.assertEqual(ip.custom_exceptions, ())
192 192
193 193 def test_drop_by_id(self):
194 194 myvars = {"a":object(), "b":object(), "c": object()}
195 195 ip.push(myvars, interactive=False)
196 196 for name in myvars:
197 197 assert name in ip.user_ns, name
198 198 assert name in ip.user_ns_hidden, name
199 199 ip.user_ns['b'] = 12
200 200 ip.drop_by_id(myvars)
201 201 for name in ["a", "c"]:
202 202 assert name not in ip.user_ns, name
203 203 assert name not in ip.user_ns_hidden, name
204 204 assert ip.user_ns['b'] == 12
205 205 ip.reset()
206 206
207 207 def test_var_expand(self):
208 208 ip.user_ns['f'] = u'Ca\xf1o'
209 209 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
210 210 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
211 211 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
212 212 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
213 213
214 214 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
215 215
216 216 ip.user_ns['f'] = b'Ca\xc3\xb1o'
217 217 # This should not raise any exception:
218 218 ip.var_expand(u'echo $f')
219 219
220 220 def test_var_expand_local(self):
221 221 """Test local variable expansion in !system and %magic calls"""
222 222 # !system
223 223 ip.run_cell('def test():\n'
224 224 ' lvar = "ttt"\n'
225 225 ' ret = !echo {lvar}\n'
226 226 ' return ret[0]\n')
227 227 res = ip.user_ns['test']()
228 228 nt.assert_in('ttt', res)
229 229
230 230 # %magic
231 231 ip.run_cell('def makemacro():\n'
232 232 ' macroname = "macro_var_expand_locals"\n'
233 233 ' %macro {macroname} codestr\n')
234 234 ip.user_ns['codestr'] = "str(12)"
235 235 ip.run_cell('makemacro()')
236 236 nt.assert_in('macro_var_expand_locals', ip.user_ns)
237 237
238 238 def test_var_expand_self(self):
239 239 """Test variable expansion with the name 'self', which was failing.
240 240
241 241 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
242 242 """
243 243 ip.run_cell('class cTest:\n'
244 244 ' classvar="see me"\n'
245 245 ' def test(self):\n'
246 246 ' res = !echo Variable: {self.classvar}\n'
247 247 ' return res[0]\n')
248 248 nt.assert_in('see me', ip.user_ns['cTest']().test())
249 249
250 250 def test_bad_var_expand(self):
251 251 """var_expand on invalid formats shouldn't raise"""
252 252 # SyntaxError
253 253 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
254 254 # NameError
255 255 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
256 256 # ZeroDivisionError
257 257 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
258 258
259 259 def test_silent_postexec(self):
260 260 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
261 261 pre_explicit = mock.Mock()
262 262 pre_always = mock.Mock()
263 263 post_explicit = mock.Mock()
264 264 post_always = mock.Mock()
265 265 all_mocks = [pre_explicit, pre_always, post_explicit, post_always]
266 266
267 267 ip.events.register('pre_run_cell', pre_explicit)
268 268 ip.events.register('pre_execute', pre_always)
269 269 ip.events.register('post_run_cell', post_explicit)
270 270 ip.events.register('post_execute', post_always)
271 271
272 272 try:
273 273 ip.run_cell("1", silent=True)
274 274 assert pre_always.called
275 275 assert not pre_explicit.called
276 276 assert post_always.called
277 277 assert not post_explicit.called
278 278 # double-check that non-silent exec did what we expected
279 279 # silent to avoid
280 280 ip.run_cell("1")
281 281 assert pre_explicit.called
282 282 assert post_explicit.called
283 283 info, = pre_explicit.call_args[0]
284 284 result, = post_explicit.call_args[0]
285 285 self.assertEqual(info, result.info)
286 286 # check that post hooks are always called
287 287 [m.reset_mock() for m in all_mocks]
288 288 ip.run_cell("syntax error")
289 289 assert pre_always.called
290 290 assert pre_explicit.called
291 291 assert post_always.called
292 292 assert post_explicit.called
293 293 info, = pre_explicit.call_args[0]
294 294 result, = post_explicit.call_args[0]
295 295 self.assertEqual(info, result.info)
296 296 finally:
297 297 # remove post-exec
298 298 ip.events.unregister('pre_run_cell', pre_explicit)
299 299 ip.events.unregister('pre_execute', pre_always)
300 300 ip.events.unregister('post_run_cell', post_explicit)
301 301 ip.events.unregister('post_execute', post_always)
302 302
303 303 def test_silent_noadvance(self):
304 304 """run_cell(silent=True) doesn't advance execution_count"""
305 305 ec = ip.execution_count
306 306 # silent should force store_history=False
307 307 ip.run_cell("1", store_history=True, silent=True)
308 308
309 309 self.assertEqual(ec, ip.execution_count)
310 310 # double-check that non-silent exec did what we expected
311 311 # silent to avoid
312 312 ip.run_cell("1", store_history=True)
313 313 self.assertEqual(ec+1, ip.execution_count)
314 314
315 315 def test_silent_nodisplayhook(self):
316 316 """run_cell(silent=True) doesn't trigger displayhook"""
317 317 d = dict(called=False)
318 318
319 319 trap = ip.display_trap
320 320 save_hook = trap.hook
321 321
322 322 def failing_hook(*args, **kwargs):
323 323 d['called'] = True
324 324
325 325 try:
326 326 trap.hook = failing_hook
327 327 res = ip.run_cell("1", silent=True)
328 328 self.assertFalse(d['called'])
329 329 self.assertIsNone(res.result)
330 330 # double-check that non-silent exec did what we expected
331 331 # silent to avoid
332 332 ip.run_cell("1")
333 333 self.assertTrue(d['called'])
334 334 finally:
335 335 trap.hook = save_hook
336 336
337 337 def test_ofind_line_magic(self):
338 338 from IPython.core.magic import register_line_magic
339 339
340 340 @register_line_magic
341 341 def lmagic(line):
342 342 "A line magic"
343 343
344 344 # Get info on line magic
345 345 lfind = ip._ofind('lmagic')
346 346 info = dict(found=True, isalias=False, ismagic=True,
347 347 namespace = 'IPython internal', obj= lmagic.__wrapped__,
348 348 parent = None)
349 349 nt.assert_equal(lfind, info)
350 350
351 351 def test_ofind_cell_magic(self):
352 352 from IPython.core.magic import register_cell_magic
353 353
354 354 @register_cell_magic
355 355 def cmagic(line, cell):
356 356 "A cell magic"
357 357
358 358 # Get info on cell magic
359 359 find = ip._ofind('cmagic')
360 360 info = dict(found=True, isalias=False, ismagic=True,
361 361 namespace = 'IPython internal', obj= cmagic.__wrapped__,
362 362 parent = None)
363 363 nt.assert_equal(find, info)
364 364
365 365 def test_ofind_property_with_error(self):
366 366 class A(object):
367 367 @property
368 368 def foo(self):
369 369 raise NotImplementedError()
370 370 a = A()
371 371
372 372 found = ip._ofind('a.foo', [('locals', locals())])
373 373 info = dict(found=True, isalias=False, ismagic=False,
374 374 namespace='locals', obj=A.foo, parent=a)
375 375 nt.assert_equal(found, info)
376 376
377 377 def test_ofind_multiple_attribute_lookups(self):
378 378 class A(object):
379 379 @property
380 380 def foo(self):
381 381 raise NotImplementedError()
382 382
383 383 a = A()
384 384 a.a = A()
385 385 a.a.a = A()
386 386
387 387 found = ip._ofind('a.a.a.foo', [('locals', locals())])
388 388 info = dict(found=True, isalias=False, ismagic=False,
389 389 namespace='locals', obj=A.foo, parent=a.a.a)
390 390 nt.assert_equal(found, info)
391 391
392 392 def test_ofind_slotted_attributes(self):
393 393 class A(object):
394 394 __slots__ = ['foo']
395 395 def __init__(self):
396 396 self.foo = 'bar'
397 397
398 398 a = A()
399 399 found = ip._ofind('a.foo', [('locals', locals())])
400 400 info = dict(found=True, isalias=False, ismagic=False,
401 401 namespace='locals', obj=a.foo, parent=a)
402 402 nt.assert_equal(found, info)
403 403
404 404 found = ip._ofind('a.bar', [('locals', locals())])
405 405 info = dict(found=False, isalias=False, ismagic=False,
406 406 namespace=None, obj=None, parent=a)
407 407 nt.assert_equal(found, info)
408 408
409 409 def test_ofind_prefers_property_to_instance_level_attribute(self):
410 410 class A(object):
411 411 @property
412 412 def foo(self):
413 413 return 'bar'
414 414 a = A()
415 415 a.__dict__['foo'] = 'baz'
416 416 nt.assert_equal(a.foo, 'bar')
417 417 found = ip._ofind('a.foo', [('locals', locals())])
418 418 nt.assert_is(found['obj'], A.foo)
419 419
420 420 def test_custom_syntaxerror_exception(self):
421 421 called = []
422 422 def my_handler(shell, etype, value, tb, tb_offset=None):
423 423 called.append(etype)
424 424 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
425 425
426 426 ip.set_custom_exc((SyntaxError,), my_handler)
427 427 try:
428 428 ip.run_cell("1f")
429 429 # Check that this was called, and only once.
430 430 self.assertEqual(called, [SyntaxError])
431 431 finally:
432 432 # Reset the custom exception hook
433 433 ip.set_custom_exc((), None)
434 434
435 435 def test_custom_exception(self):
436 436 called = []
437 437 def my_handler(shell, etype, value, tb, tb_offset=None):
438 438 called.append(etype)
439 439 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
440 440
441 441 ip.set_custom_exc((ValueError,), my_handler)
442 442 try:
443 443 res = ip.run_cell("raise ValueError('test')")
444 444 # Check that this was called, and only once.
445 445 self.assertEqual(called, [ValueError])
446 446 # Check that the error is on the result object
447 447 self.assertIsInstance(res.error_in_exec, ValueError)
448 448 finally:
449 449 # Reset the custom exception hook
450 450 ip.set_custom_exc((), None)
451 451
452 452 @mock.patch("builtins.print")
453 453 def test_showtraceback_with_surrogates(self, mocked_print):
454 454 values = []
455 455
456 456 def mock_print_func(value, sep=" ", end="\n", file=sys.stdout, flush=False):
457 457 values.append(value)
458 458 if value == chr(0xD8FF):
459 459 raise UnicodeEncodeError("utf-8", chr(0xD8FF), 0, 1, "")
460 460
461 461 # mock builtins.print
462 462 mocked_print.side_effect = mock_print_func
463 463
464 464 # ip._showtraceback() is replaced in globalipapp.py.
465 465 # Call original method to test.
466 466 interactiveshell.InteractiveShell._showtraceback(ip, None, None, chr(0xD8FF))
467 467
468 468 self.assertEqual(mocked_print.call_count, 2)
469 469 self.assertEqual(values, [chr(0xD8FF), "\\ud8ff"])
470 470
471 471 def test_mktempfile(self):
472 472 filename = ip.mktempfile()
473 473 # Check that we can open the file again on Windows
474 474 with open(filename, 'w') as f:
475 475 f.write('abc')
476 476
477 477 filename = ip.mktempfile(data='blah')
478 478 with open(filename, 'r') as f:
479 479 self.assertEqual(f.read(), 'blah')
480 480
481 481 def test_new_main_mod(self):
482 482 # Smoketest to check that this accepts a unicode module name
483 483 name = u'jiefmw'
484 484 mod = ip.new_main_mod(u'%s.py' % name, name)
485 485 self.assertEqual(mod.__name__, name)
486 486
487 487 def test_get_exception_only(self):
488 488 try:
489 489 raise KeyboardInterrupt
490 490 except KeyboardInterrupt:
491 491 msg = ip.get_exception_only()
492 492 self.assertEqual(msg, 'KeyboardInterrupt\n')
493 493
494 494 try:
495 495 raise DerivedInterrupt("foo")
496 496 except KeyboardInterrupt:
497 497 msg = ip.get_exception_only()
498 498 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
499 499
500 500 def test_inspect_text(self):
501 501 ip.run_cell('a = 5')
502 502 text = ip.object_inspect_text('a')
503 503 self.assertIsInstance(text, str)
504 504
505 505 def test_last_execution_result(self):
506 506 """ Check that last execution result gets set correctly (GH-10702) """
507 507 result = ip.run_cell('a = 5; a')
508 508 self.assertTrue(ip.last_execution_succeeded)
509 509 self.assertEqual(ip.last_execution_result.result, 5)
510 510
511 511 result = ip.run_cell('a = x_invalid_id_x')
512 512 self.assertFalse(ip.last_execution_succeeded)
513 513 self.assertFalse(ip.last_execution_result.success)
514 514 self.assertIsInstance(ip.last_execution_result.error_in_exec, NameError)
515 515
516 516 def test_reset_aliasing(self):
517 517 """ Check that standard posix aliases work after %reset. """
518 518 if os.name != 'posix':
519 519 return
520 520
521 521 ip.reset()
522 522 for cmd in ('clear', 'more', 'less', 'man'):
523 523 res = ip.run_cell('%' + cmd)
524 524 self.assertEqual(res.success, True)
525 525
526 526
527 527 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
528 528
529 529 @onlyif_unicode_paths
530 530 def setUp(self):
531 531 self.BASETESTDIR = tempfile.mkdtemp()
532 532 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
533 533 os.mkdir(self.TESTDIR)
534 534 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
535 535 sfile.write("pass\n")
536 536 self.oldpath = os.getcwd()
537 537 os.chdir(self.TESTDIR)
538 538 self.fname = u"Γ₯Àâtestscript.py"
539 539
540 540 def tearDown(self):
541 541 os.chdir(self.oldpath)
542 542 shutil.rmtree(self.BASETESTDIR)
543 543
544 544 @onlyif_unicode_paths
545 545 def test_1(self):
546 546 """Test safe_execfile with non-ascii path
547 547 """
548 548 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
549 549
550 550 class ExitCodeChecks(tt.TempFileMixin):
551 551
552 552 def setUp(self):
553 553 self.system = ip.system_raw
554 554
555 555 def test_exit_code_ok(self):
556 556 self.system('exit 0')
557 557 self.assertEqual(ip.user_ns['_exit_code'], 0)
558 558
559 559 def test_exit_code_error(self):
560 560 self.system('exit 1')
561 561 self.assertEqual(ip.user_ns['_exit_code'], 1)
562 562
563 563 @skipif(not hasattr(signal, 'SIGALRM'))
564 564 def test_exit_code_signal(self):
565 565 self.mktmp("import signal, time\n"
566 566 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
567 567 "time.sleep(1)\n")
568 568 self.system("%s %s" % (sys.executable, self.fname))
569 569 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
570 570
571 571 @onlyif_cmds_exist("csh")
572 572 def test_exit_code_signal_csh(self):
573 573 SHELL = os.environ.get('SHELL', None)
574 574 os.environ['SHELL'] = find_cmd("csh")
575 575 try:
576 576 self.test_exit_code_signal()
577 577 finally:
578 578 if SHELL is not None:
579 579 os.environ['SHELL'] = SHELL
580 580 else:
581 581 del os.environ['SHELL']
582 582
583 583
584 584 class TestSystemRaw(ExitCodeChecks):
585 585
586 586 def setUp(self):
587 587 super().setUp()
588 588 self.system = ip.system_raw
589 589
590 590 @onlyif_unicode_paths
591 591 def test_1(self):
592 592 """Test system_raw with non-ascii cmd
593 593 """
594 594 cmd = u'''python -c "'Γ₯Àâ'" '''
595 595 ip.system_raw(cmd)
596 596
597 597 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
598 598 @mock.patch('os.system', side_effect=KeyboardInterrupt)
599 599 def test_control_c(self, *mocks):
600 600 try:
601 601 self.system("sleep 1 # wont happen")
602 602 except KeyboardInterrupt:
603 603 self.fail("system call should intercept "
604 604 "keyboard interrupt from subprocess.call")
605 605 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
606 606
607 607 # TODO: Exit codes are currently ignored on Windows.
608 608 class TestSystemPipedExitCode(ExitCodeChecks):
609 609
610 610 def setUp(self):
611 611 super().setUp()
612 612 self.system = ip.system_piped
613 613
614 614 @skip_win32
615 615 def test_exit_code_ok(self):
616 616 ExitCodeChecks.test_exit_code_ok(self)
617 617
618 618 @skip_win32
619 619 def test_exit_code_error(self):
620 620 ExitCodeChecks.test_exit_code_error(self)
621 621
622 622 @skip_win32
623 623 def test_exit_code_signal(self):
624 624 ExitCodeChecks.test_exit_code_signal(self)
625 625
626 626 class TestModules(tt.TempFileMixin):
627 627 def test_extraneous_loads(self):
628 628 """Test we're not loading modules on startup that we shouldn't.
629 629 """
630 630 self.mktmp("import sys\n"
631 631 "print('numpy' in sys.modules)\n"
632 632 "print('ipyparallel' in sys.modules)\n"
633 633 "print('ipykernel' in sys.modules)\n"
634 634 )
635 635 out = "False\nFalse\nFalse\n"
636 636 tt.ipexec_validate(self.fname, out)
637 637
638 638 class Negator(ast.NodeTransformer):
639 639 """Negates all number literals in an AST."""
640 640
641 641 # for python 3.7 and earlier
642 642 def visit_Num(self, node):
643 643 node.n = -node.n
644 644 return node
645 645
646 646 # for python 3.8+
647 647 def visit_Constant(self, node):
648 648 if isinstance(node.value, int):
649 649 return self.visit_Num(node)
650 650 return node
651 651
652 652 class TestAstTransform(unittest.TestCase):
653 653 def setUp(self):
654 654 self.negator = Negator()
655 655 ip.ast_transformers.append(self.negator)
656 656
657 657 def tearDown(self):
658 658 ip.ast_transformers.remove(self.negator)
659 659
660 660 def test_run_cell(self):
661 661 with tt.AssertPrints('-34'):
662 662 ip.run_cell('print (12 + 22)')
663 663
664 664 # A named reference to a number shouldn't be transformed.
665 665 ip.user_ns['n'] = 55
666 666 with tt.AssertNotPrints('-55'):
667 667 ip.run_cell('print (n)')
668 668
669 669 def test_timeit(self):
670 670 called = set()
671 671 def f(x):
672 672 called.add(x)
673 673 ip.push({'f':f})
674 674
675 675 with tt.AssertPrints("std. dev. of"):
676 676 ip.run_line_magic("timeit", "-n1 f(1)")
677 677 self.assertEqual(called, {-1})
678 678 called.clear()
679 679
680 680 with tt.AssertPrints("std. dev. of"):
681 681 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
682 682 self.assertEqual(called, {-2, -3})
683 683
684 684 def test_time(self):
685 685 called = []
686 686 def f(x):
687 687 called.append(x)
688 688 ip.push({'f':f})
689 689
690 690 # Test with an expression
691 691 with tt.AssertPrints("Wall time: "):
692 692 ip.run_line_magic("time", "f(5+9)")
693 693 self.assertEqual(called, [-14])
694 694 called[:] = []
695 695
696 696 # Test with a statement (different code path)
697 697 with tt.AssertPrints("Wall time: "):
698 698 ip.run_line_magic("time", "a = f(-3 + -2)")
699 699 self.assertEqual(called, [5])
700 700
701 701 def test_macro(self):
702 702 ip.push({'a':10})
703 703 # The AST transformation makes this do a+=-1
704 704 ip.define_macro("amacro", "a+=1\nprint(a)")
705 705
706 706 with tt.AssertPrints("9"):
707 707 ip.run_cell("amacro")
708 708 with tt.AssertPrints("8"):
709 709 ip.run_cell("amacro")
710 710
711 711 class TestMiscTransform(unittest.TestCase):
712 712
713 713
714 714 def test_transform_only_once(self):
715 715 cleanup = 0
716 716 line_t = 0
717 717 def count_cleanup(lines):
718 718 nonlocal cleanup
719 719 cleanup += 1
720 720 return lines
721 721
722 722 def count_line_t(lines):
723 723 nonlocal line_t
724 724 line_t += 1
725 725 return lines
726 726
727 727 ip.input_transformer_manager.cleanup_transforms.append(count_cleanup)
728 728 ip.input_transformer_manager.line_transforms.append(count_line_t)
729 729
730 730 ip.run_cell('1')
731 731
732 732 assert cleanup == 1
733 733 assert line_t == 1
734 734
735 735 class IntegerWrapper(ast.NodeTransformer):
736 736 """Wraps all integers in a call to Integer()"""
737 737
738 738 # for Python 3.7 and earlier
739 739
740 740 # for Python 3.7 and earlier
741 741 def visit_Num(self, node):
742 742 if isinstance(node.n, int):
743 743 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
744 744 args=[node], keywords=[])
745 745 return node
746 746
747 747 # For Python 3.8+
748 748 def visit_Constant(self, node):
749 749 if isinstance(node.value, int):
750 750 return self.visit_Num(node)
751 751 return node
752 752
753 753
754 754 class TestAstTransform2(unittest.TestCase):
755 755 def setUp(self):
756 756 self.intwrapper = IntegerWrapper()
757 757 ip.ast_transformers.append(self.intwrapper)
758 758
759 759 self.calls = []
760 760 def Integer(*args):
761 761 self.calls.append(args)
762 762 return args
763 763 ip.push({"Integer": Integer})
764 764
765 765 def tearDown(self):
766 766 ip.ast_transformers.remove(self.intwrapper)
767 767 del ip.user_ns['Integer']
768 768
769 769 def test_run_cell(self):
770 770 ip.run_cell("n = 2")
771 771 self.assertEqual(self.calls, [(2,)])
772 772
773 773 # This shouldn't throw an error
774 774 ip.run_cell("o = 2.0")
775 775 self.assertEqual(ip.user_ns['o'], 2.0)
776 776
777 777 def test_timeit(self):
778 778 called = set()
779 779 def f(x):
780 780 called.add(x)
781 781 ip.push({'f':f})
782 782
783 783 with tt.AssertPrints("std. dev. of"):
784 784 ip.run_line_magic("timeit", "-n1 f(1)")
785 785 self.assertEqual(called, {(1,)})
786 786 called.clear()
787 787
788 788 with tt.AssertPrints("std. dev. of"):
789 789 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
790 790 self.assertEqual(called, {(2,), (3,)})
791 791
792 792 class ErrorTransformer(ast.NodeTransformer):
793 793 """Throws an error when it sees a number."""
794 794
795 795 # for Python 3.7 and earlier
796 796 def visit_Num(self, node):
797 797 raise ValueError("test")
798 798
799 799 # for Python 3.8+
800 800 def visit_Constant(self, node):
801 801 if isinstance(node.value, int):
802 802 return self.visit_Num(node)
803 803 return node
804 804
805 805
806 806 class TestAstTransformError(unittest.TestCase):
807 807 def test_unregistering(self):
808 808 err_transformer = ErrorTransformer()
809 809 ip.ast_transformers.append(err_transformer)
810 810
811 811 with self.assertWarnsRegex(UserWarning, "It will be unregistered"):
812 812 ip.run_cell("1 + 2")
813 813
814 814 # This should have been removed.
815 815 nt.assert_not_in(err_transformer, ip.ast_transformers)
816 816
817 817
818 818 class StringRejector(ast.NodeTransformer):
819 819 """Throws an InputRejected when it sees a string literal.
820 820
821 821 Used to verify that NodeTransformers can signal that a piece of code should
822 822 not be executed by throwing an InputRejected.
823 823 """
824 824
825 825 #for python 3.7 and earlier
826 826 def visit_Str(self, node):
827 827 raise InputRejected("test")
828 828
829 829 # 3.8 only
830 830 def visit_Constant(self, node):
831 831 if isinstance(node.value, str):
832 832 raise InputRejected("test")
833 833 return node
834 834
835 835
836 836 class TestAstTransformInputRejection(unittest.TestCase):
837 837
838 838 def setUp(self):
839 839 self.transformer = StringRejector()
840 840 ip.ast_transformers.append(self.transformer)
841 841
842 842 def tearDown(self):
843 843 ip.ast_transformers.remove(self.transformer)
844 844
845 845 def test_input_rejection(self):
846 846 """Check that NodeTransformers can reject input."""
847 847
848 848 expect_exception_tb = tt.AssertPrints("InputRejected: test")
849 849 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
850 850
851 851 # Run the same check twice to verify that the transformer is not
852 852 # disabled after raising.
853 853 with expect_exception_tb, expect_no_cell_output:
854 854 ip.run_cell("'unsafe'")
855 855
856 856 with expect_exception_tb, expect_no_cell_output:
857 857 res = ip.run_cell("'unsafe'")
858 858
859 859 self.assertIsInstance(res.error_before_exec, InputRejected)
860 860
861 861 def test__IPYTHON__():
862 862 # This shouldn't raise a NameError, that's all
863 863 __IPYTHON__
864 864
865 865
866 866 class DummyRepr(object):
867 867 def __repr__(self):
868 868 return "DummyRepr"
869 869
870 870 def _repr_html_(self):
871 871 return "<b>dummy</b>"
872 872
873 873 def _repr_javascript_(self):
874 874 return "console.log('hi');", {'key': 'value'}
875 875
876 876
877 877 def test_user_variables():
878 878 # enable all formatters
879 879 ip.display_formatter.active_types = ip.display_formatter.format_types
880 880
881 881 ip.user_ns['dummy'] = d = DummyRepr()
882 882 keys = {'dummy', 'doesnotexist'}
883 883 r = ip.user_expressions({ key:key for key in keys})
884 884
885 885 nt.assert_equal(keys, set(r.keys()))
886 886 dummy = r['dummy']
887 887 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
888 888 nt.assert_equal(dummy['status'], 'ok')
889 889 data = dummy['data']
890 890 metadata = dummy['metadata']
891 891 nt.assert_equal(data.get('text/html'), d._repr_html_())
892 892 js, jsmd = d._repr_javascript_()
893 893 nt.assert_equal(data.get('application/javascript'), js)
894 894 nt.assert_equal(metadata.get('application/javascript'), jsmd)
895 895
896 896 dne = r['doesnotexist']
897 897 nt.assert_equal(dne['status'], 'error')
898 898 nt.assert_equal(dne['ename'], 'NameError')
899 899
900 900 # back to text only
901 901 ip.display_formatter.active_types = ['text/plain']
902 902
903 903 def test_user_expression():
904 904 # enable all formatters
905 905 ip.display_formatter.active_types = ip.display_formatter.format_types
906 906 query = {
907 907 'a' : '1 + 2',
908 908 'b' : '1/0',
909 909 }
910 910 r = ip.user_expressions(query)
911 911 import pprint
912 912 pprint.pprint(r)
913 913 nt.assert_equal(set(r.keys()), set(query.keys()))
914 914 a = r['a']
915 915 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
916 916 nt.assert_equal(a['status'], 'ok')
917 917 data = a['data']
918 918 metadata = a['metadata']
919 919 nt.assert_equal(data.get('text/plain'), '3')
920 920
921 921 b = r['b']
922 922 nt.assert_equal(b['status'], 'error')
923 923 nt.assert_equal(b['ename'], 'ZeroDivisionError')
924 924
925 925 # back to text only
926 926 ip.display_formatter.active_types = ['text/plain']
927 927
928 928
929 929 class TestSyntaxErrorTransformer(unittest.TestCase):
930 930 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
931 931
932 932 @staticmethod
933 933 def transformer(lines):
934 934 for line in lines:
935 935 pos = line.find('syntaxerror')
936 936 if pos >= 0:
937 937 e = SyntaxError('input contains "syntaxerror"')
938 938 e.text = line
939 939 e.offset = pos + 1
940 940 raise e
941 941 return lines
942 942
943 943 def setUp(self):
944 944 ip.input_transformers_post.append(self.transformer)
945 945
946 946 def tearDown(self):
947 947 ip.input_transformers_post.remove(self.transformer)
948 948
949 949 def test_syntaxerror_input_transformer(self):
950 950 with tt.AssertPrints('1234'):
951 951 ip.run_cell('1234')
952 952 with tt.AssertPrints('SyntaxError: invalid syntax'):
953 953 ip.run_cell('1 2 3') # plain python syntax error
954 954 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
955 955 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
956 956 with tt.AssertPrints('3456'):
957 957 ip.run_cell('3456')
958 958
959 959
960 960 class TestWarningSuppression(unittest.TestCase):
961 961 def test_warning_suppression(self):
962 962 ip.run_cell("import warnings")
963 963 try:
964 964 with self.assertWarnsRegex(UserWarning, "asdf"):
965 965 ip.run_cell("warnings.warn('asdf')")
966 966 # Here's the real test -- if we run that again, we should get the
967 967 # warning again. Traditionally, each warning was only issued once per
968 968 # IPython session (approximately), even if the user typed in new and
969 969 # different code that should have also triggered the warning, leading
970 970 # to much confusion.
971 971 with self.assertWarnsRegex(UserWarning, "asdf"):
972 972 ip.run_cell("warnings.warn('asdf')")
973 973 finally:
974 974 ip.run_cell("del warnings")
975 975
976 976
977 977 def test_deprecation_warning(self):
978 978 ip.run_cell("""
979 979 import warnings
980 980 def wrn():
981 981 warnings.warn(
982 982 "I AM A WARNING",
983 983 DeprecationWarning
984 984 )
985 985 """)
986 986 try:
987 987 with self.assertWarnsRegex(DeprecationWarning, "I AM A WARNING"):
988 988 ip.run_cell("wrn()")
989 989 finally:
990 990 ip.run_cell("del warnings")
991 991 ip.run_cell("del wrn")
992 992
993 993
994 994 class TestImportNoDeprecate(tt.TempFileMixin):
995 995
996 996 def setUp(self):
997 997 """Make a valid python temp file."""
998 998 self.mktmp("""
999 999 import warnings
1000 1000 def wrn():
1001 1001 warnings.warn(
1002 1002 "I AM A WARNING",
1003 1003 DeprecationWarning
1004 1004 )
1005 1005 """)
1006 1006 super().setUp()
1007 1007
1008 1008 def test_no_dep(self):
1009 1009 """
1010 1010 No deprecation warning should be raised from imported functions
1011 1011 """
1012 1012 ip.run_cell("from {} import wrn".format(self.fname))
1013 1013
1014 1014 with tt.AssertNotPrints("I AM A WARNING"):
1015 1015 ip.run_cell("wrn()")
1016 1016 ip.run_cell("del wrn")
1017 1017
1018 1018
1019 1019 def test_custom_exc_count():
1020 1020 hook = mock.Mock(return_value=None)
1021 1021 ip.set_custom_exc((SyntaxError,), hook)
1022 1022 before = ip.execution_count
1023 1023 ip.run_cell("def foo()", store_history=True)
1024 1024 # restore default excepthook
1025 1025 ip.set_custom_exc((), None)
1026 1026 nt.assert_equal(hook.call_count, 1)
1027 1027 nt.assert_equal(ip.execution_count, before + 1)
1028 1028
1029 1029
1030 1030 def test_run_cell_async():
1031 1031 loop = asyncio.get_event_loop()
1032 1032 ip.run_cell("import asyncio")
1033 1033 coro = ip.run_cell_async("await asyncio.sleep(0.01)\n5")
1034 1034 assert asyncio.iscoroutine(coro)
1035 1035 result = loop.run_until_complete(coro)
1036 1036 assert isinstance(result, interactiveshell.ExecutionResult)
1037 1037 assert result.result == 5
1038 1038
1039 1039
1040 def test_run_cell_await():
1041 ip.run_cell("import asyncio")
1042 result = ip.run_cell("await asyncio.sleep(0.01); 10")
1043 assert ip.user_ns["_"] == 10
1044
1045
1046 def test_run_cell_asyncio_run():
1047 ip.run_cell("import asyncio")
1048 result = ip.run_cell("await asyncio.sleep(0.01); 1")
1049 assert ip.user_ns["_"] == 1
1050 result = ip.run_cell("asyncio.run(asyncio.sleep(0.01)); 2")
1051 assert ip.user_ns["_"] == 2
1052 result = ip.run_cell("await asyncio.sleep(0.01); 3")
1053 assert ip.user_ns["_"] == 3
1054
1055
1040 1056 def test_should_run_async():
1041 1057 assert not ip.should_run_async("a = 5")
1042 1058 assert ip.should_run_async("await x")
1043 1059 assert ip.should_run_async("import asyncio; await asyncio.sleep(1)")
1044 1060
1045 1061
1046 1062 def test_set_custom_completer():
1047 1063 num_completers = len(ip.Completer.matchers)
1048 1064
1049 1065 def foo(*args, **kwargs):
1050 1066 return "I'm a completer!"
1051 1067
1052 1068 ip.set_custom_completer(foo, 0)
1053 1069
1054 1070 # check that we've really added a new completer
1055 1071 assert len(ip.Completer.matchers) == num_completers + 1
1056 1072
1057 1073 # check that the first completer is the function we defined
1058 1074 assert ip.Completer.matchers[0]() == "I'm a completer!"
1059 1075
1060 1076 # clean up
1061 1077 ip.Completer.custom_matchers.pop()
@@ -1,645 +1,648 b''
1 1 """IPython terminal interface using prompt_toolkit"""
2 2
3 3 import asyncio
4 4 import os
5 5 import sys
6 6 import warnings
7 7 from warnings import warn
8 8
9 9 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
10 10 from IPython.utils import io
11 11 from IPython.utils.py3compat import input
12 12 from IPython.utils.terminal import toggle_set_term_title, set_term_title, restore_term_title
13 13 from IPython.utils.process import abbrev_cwd
14 14 from traitlets import (
15 15 Bool, Unicode, Dict, Integer, observe, Instance, Type, default, Enum, Union,
16 16 Any, validate
17 17 )
18 18
19 19 from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
20 20 from prompt_toolkit.filters import (HasFocus, Condition, IsDone)
21 21 from prompt_toolkit.formatted_text import PygmentsTokens
22 22 from prompt_toolkit.history import InMemoryHistory
23 23 from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
24 24 from prompt_toolkit.output import ColorDepth
25 25 from prompt_toolkit.patch_stdout import patch_stdout
26 26 from prompt_toolkit.shortcuts import PromptSession, CompleteStyle, print_formatted_text
27 27 from prompt_toolkit.styles import DynamicStyle, merge_styles
28 28 from prompt_toolkit.styles.pygments import style_from_pygments_cls, style_from_pygments_dict
29 29 from prompt_toolkit import __version__ as ptk_version
30 30
31 31 from pygments.styles import get_style_by_name
32 32 from pygments.style import Style
33 33 from pygments.token import Token
34 34
35 35 from .debugger import TerminalPdb, Pdb
36 36 from .magics import TerminalMagics
37 37 from .pt_inputhooks import get_inputhook_name_and_func
38 38 from .prompts import Prompts, ClassicPrompts, RichPromptDisplayHook
39 39 from .ptutils import IPythonPTCompleter, IPythonPTLexer
40 40 from .shortcuts import create_ipython_shortcuts
41 41
42 42 DISPLAY_BANNER_DEPRECATED = object()
43 43 PTK3 = ptk_version.startswith('3.')
44 44
45 45
46 46 class _NoStyle(Style): pass
47 47
48 48
49 49
50 50 _style_overrides_light_bg = {
51 51 Token.Prompt: '#0000ff',
52 52 Token.PromptNum: '#0000ee bold',
53 53 Token.OutPrompt: '#cc0000',
54 54 Token.OutPromptNum: '#bb0000 bold',
55 55 }
56 56
57 57 _style_overrides_linux = {
58 58 Token.Prompt: '#00cc00',
59 59 Token.PromptNum: '#00bb00 bold',
60 60 Token.OutPrompt: '#cc0000',
61 61 Token.OutPromptNum: '#bb0000 bold',
62 62 }
63 63
64 64 def get_default_editor():
65 65 try:
66 66 return os.environ['EDITOR']
67 67 except KeyError:
68 68 pass
69 69 except UnicodeError:
70 70 warn("$EDITOR environment variable is not pure ASCII. Using platform "
71 71 "default editor.")
72 72
73 73 if os.name == 'posix':
74 74 return 'vi' # the only one guaranteed to be there!
75 75 else:
76 76 return 'notepad' # same in Windows!
77 77
78 78 # conservatively check for tty
79 79 # overridden streams can result in things like:
80 80 # - sys.stdin = None
81 81 # - no isatty method
82 82 for _name in ('stdin', 'stdout', 'stderr'):
83 83 _stream = getattr(sys, _name)
84 84 if not _stream or not hasattr(_stream, 'isatty') or not _stream.isatty():
85 85 _is_tty = False
86 86 break
87 87 else:
88 88 _is_tty = True
89 89
90 90
91 91 _use_simple_prompt = ('IPY_TEST_SIMPLE_PROMPT' in os.environ) or (not _is_tty)
92 92
93 93 def black_reformat_handler(text_before_cursor):
94 94 import black
95 95 formatted_text = black.format_str(text_before_cursor, mode=black.FileMode())
96 96 if not text_before_cursor.endswith('\n') and formatted_text.endswith('\n'):
97 97 formatted_text = formatted_text[:-1]
98 98 return formatted_text
99 99
100 100
101 101 class TerminalInteractiveShell(InteractiveShell):
102 102 mime_renderers = Dict().tag(config=True)
103 103
104 104 space_for_menu = Integer(6, help='Number of line at the bottom of the screen '
105 105 'to reserve for the tab completion menu, '
106 106 'search history, ...etc, the height of '
107 107 'these menus will at most this value. '
108 108 'Increase it is you prefer long and skinny '
109 109 'menus, decrease for short and wide.'
110 110 ).tag(config=True)
111 111
112 112 pt_app = None
113 113 debugger_history = None
114 114
115 115 simple_prompt = Bool(_use_simple_prompt,
116 116 help="""Use `raw_input` for the REPL, without completion and prompt colors.
117 117
118 118 Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
119 119 IPython own testing machinery, and emacs inferior-shell integration through elpy.
120 120
121 121 This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
122 122 environment variable is set, or the current terminal is not a tty."""
123 123 ).tag(config=True)
124 124
125 125 @property
126 126 def debugger_cls(self):
127 127 return Pdb if self.simple_prompt else TerminalPdb
128 128
129 129 confirm_exit = Bool(True,
130 130 help="""
131 131 Set to confirm when you try to exit IPython with an EOF (Control-D
132 132 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
133 133 you can force a direct exit without any confirmation.""",
134 134 ).tag(config=True)
135 135
136 136 editing_mode = Unicode('emacs',
137 137 help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
138 138 ).tag(config=True)
139 139
140 140 autoformatter = Unicode(None,
141 141 help="Autoformatter to reformat Terminal code. Can be `'black'` or `None`",
142 142 allow_none=True
143 143 ).tag(config=True)
144 144
145 145 mouse_support = Bool(False,
146 146 help="Enable mouse support in the prompt\n(Note: prevents selecting text with the mouse)"
147 147 ).tag(config=True)
148 148
149 149 # We don't load the list of styles for the help string, because loading
150 150 # Pygments plugins takes time and can cause unexpected errors.
151 151 highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
152 152 help="""The name or class of a Pygments style to use for syntax
153 153 highlighting. To see available styles, run `pygmentize -L styles`."""
154 154 ).tag(config=True)
155 155
156 156 @validate('editing_mode')
157 157 def _validate_editing_mode(self, proposal):
158 158 if proposal['value'].lower() == 'vim':
159 159 proposal['value']= 'vi'
160 160 elif proposal['value'].lower() == 'default':
161 161 proposal['value']= 'emacs'
162 162
163 163 if hasattr(EditingMode, proposal['value'].upper()):
164 164 return proposal['value'].lower()
165 165
166 166 return self.editing_mode
167 167
168 168
169 169 @observe('editing_mode')
170 170 def _editing_mode(self, change):
171 171 if self.pt_app:
172 172 self.pt_app.editing_mode = getattr(EditingMode, change.new.upper())
173 173
174 174 @observe('autoformatter')
175 175 def _autoformatter_changed(self, change):
176 176 formatter = change.new
177 177 if formatter is None:
178 178 self.reformat_handler = lambda x:x
179 179 elif formatter == 'black':
180 180 self.reformat_handler = black_reformat_handler
181 181 else:
182 182 raise ValueError
183 183
184 184 @observe('highlighting_style')
185 185 @observe('colors')
186 186 def _highlighting_style_changed(self, change):
187 187 self.refresh_style()
188 188
189 189 def refresh_style(self):
190 190 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
191 191
192 192
193 193 highlighting_style_overrides = Dict(
194 194 help="Override highlighting format for specific tokens"
195 195 ).tag(config=True)
196 196
197 197 true_color = Bool(False,
198 198 help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
199 199 "If your terminal supports true color, the following command "
200 200 "should print 'TRUECOLOR' in orange: "
201 201 "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
202 202 ).tag(config=True)
203 203
204 204 editor = Unicode(get_default_editor(),
205 205 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
206 206 ).tag(config=True)
207 207
208 208 prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)
209 209
210 210 prompts = Instance(Prompts)
211 211
212 212 @default('prompts')
213 213 def _prompts_default(self):
214 214 return self.prompts_class(self)
215 215
216 216 # @observe('prompts')
217 217 # def _(self, change):
218 218 # self._update_layout()
219 219
220 220 @default('displayhook_class')
221 221 def _displayhook_class_default(self):
222 222 return RichPromptDisplayHook
223 223
224 224 term_title = Bool(True,
225 225 help="Automatically set the terminal title"
226 226 ).tag(config=True)
227 227
228 228 term_title_format = Unicode("IPython: {cwd}",
229 229 help="Customize the terminal title format. This is a python format string. " +
230 230 "Available substitutions are: {cwd}."
231 231 ).tag(config=True)
232 232
233 233 display_completions = Enum(('column', 'multicolumn','readlinelike'),
234 234 help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
235 235 "'readlinelike'. These options are for `prompt_toolkit`, see "
236 236 "`prompt_toolkit` documentation for more information."
237 237 ),
238 238 default_value='multicolumn').tag(config=True)
239 239
240 240 highlight_matching_brackets = Bool(True,
241 241 help="Highlight matching brackets.",
242 242 ).tag(config=True)
243 243
244 244 extra_open_editor_shortcuts = Bool(False,
245 245 help="Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
246 246 "This is in addition to the F2 binding, which is always enabled."
247 247 ).tag(config=True)
248 248
249 249 handle_return = Any(None,
250 250 help="Provide an alternative handler to be called when the user presses "
251 251 "Return. This is an advanced option intended for debugging, which "
252 252 "may be changed or removed in later releases."
253 253 ).tag(config=True)
254 254
255 255 enable_history_search = Bool(True,
256 256 help="Allows to enable/disable the prompt toolkit history search"
257 257 ).tag(config=True)
258 258
259 259 prompt_includes_vi_mode = Bool(True,
260 260 help="Display the current vi mode (when using vi editing mode)."
261 261 ).tag(config=True)
262 262
263 263 @observe('term_title')
264 264 def init_term_title(self, change=None):
265 265 # Enable or disable the terminal title.
266 266 if self.term_title:
267 267 toggle_set_term_title(True)
268 268 set_term_title(self.term_title_format.format(cwd=abbrev_cwd()))
269 269 else:
270 270 toggle_set_term_title(False)
271 271
272 272 def restore_term_title(self):
273 273 if self.term_title:
274 274 restore_term_title()
275 275
276 276 def init_display_formatter(self):
277 277 super(TerminalInteractiveShell, self).init_display_formatter()
278 278 # terminal only supports plain text
279 279 self.display_formatter.active_types = ['text/plain']
280 280 # disable `_ipython_display_`
281 281 self.display_formatter.ipython_display_formatter.enabled = False
282 282
283 283 def init_prompt_toolkit_cli(self):
284 284 if self.simple_prompt:
285 285 # Fall back to plain non-interactive output for tests.
286 286 # This is very limited.
287 287 def prompt():
288 288 prompt_text = "".join(x[1] for x in self.prompts.in_prompt_tokens())
289 289 lines = [input(prompt_text)]
290 290 prompt_continuation = "".join(x[1] for x in self.prompts.continuation_prompt_tokens())
291 291 while self.check_complete('\n'.join(lines))[0] == 'incomplete':
292 292 lines.append( input(prompt_continuation) )
293 293 return '\n'.join(lines)
294 294 self.prompt_for_code = prompt
295 295 return
296 296
297 297 # Set up keyboard shortcuts
298 298 key_bindings = create_ipython_shortcuts(self)
299 299
300 300 # Pre-populate history from IPython's history database
301 301 history = InMemoryHistory()
302 302 last_cell = u""
303 303 for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
304 304 include_latest=True):
305 305 # Ignore blank lines and consecutive duplicates
306 306 cell = cell.rstrip()
307 307 if cell and (cell != last_cell):
308 308 history.append_string(cell)
309 309 last_cell = cell
310 310
311 311 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
312 312 self.style = DynamicStyle(lambda: self._style)
313 313
314 314 editing_mode = getattr(EditingMode, self.editing_mode.upper())
315 315
316 316 self.pt_loop = asyncio.new_event_loop()
317 317 self.pt_app = PromptSession(
318 318 editing_mode=editing_mode,
319 319 key_bindings=key_bindings,
320 320 history=history,
321 321 completer=IPythonPTCompleter(shell=self),
322 322 enable_history_search = self.enable_history_search,
323 323 style=self.style,
324 324 include_default_pygments_style=False,
325 325 mouse_support=self.mouse_support,
326 326 enable_open_in_editor=self.extra_open_editor_shortcuts,
327 327 color_depth=self.color_depth,
328 328 tempfile_suffix=".py",
329 329 **self._extra_prompt_options())
330 330
331 331 def _make_style_from_name_or_cls(self, name_or_cls):
332 332 """
333 333 Small wrapper that make an IPython compatible style from a style name
334 334
335 335 We need that to add style for prompt ... etc.
336 336 """
337 337 style_overrides = {}
338 338 if name_or_cls == 'legacy':
339 339 legacy = self.colors.lower()
340 340 if legacy == 'linux':
341 341 style_cls = get_style_by_name('monokai')
342 342 style_overrides = _style_overrides_linux
343 343 elif legacy == 'lightbg':
344 344 style_overrides = _style_overrides_light_bg
345 345 style_cls = get_style_by_name('pastie')
346 346 elif legacy == 'neutral':
347 347 # The default theme needs to be visible on both a dark background
348 348 # and a light background, because we can't tell what the terminal
349 349 # looks like. These tweaks to the default theme help with that.
350 350 style_cls = get_style_by_name('default')
351 351 style_overrides.update({
352 352 Token.Number: '#007700',
353 353 Token.Operator: 'noinherit',
354 354 Token.String: '#BB6622',
355 355 Token.Name.Function: '#2080D0',
356 356 Token.Name.Class: 'bold #2080D0',
357 357 Token.Name.Namespace: 'bold #2080D0',
358 358 Token.Name.Variable.Magic: '#ansiblue',
359 359 Token.Prompt: '#009900',
360 360 Token.PromptNum: '#ansibrightgreen bold',
361 361 Token.OutPrompt: '#990000',
362 362 Token.OutPromptNum: '#ansibrightred bold',
363 363 })
364 364
365 365 # Hack: Due to limited color support on the Windows console
366 366 # the prompt colors will be wrong without this
367 367 if os.name == 'nt':
368 368 style_overrides.update({
369 369 Token.Prompt: '#ansidarkgreen',
370 370 Token.PromptNum: '#ansigreen bold',
371 371 Token.OutPrompt: '#ansidarkred',
372 372 Token.OutPromptNum: '#ansired bold',
373 373 })
374 374 elif legacy =='nocolor':
375 375 style_cls=_NoStyle
376 376 style_overrides = {}
377 377 else :
378 378 raise ValueError('Got unknown colors: ', legacy)
379 379 else :
380 380 if isinstance(name_or_cls, str):
381 381 style_cls = get_style_by_name(name_or_cls)
382 382 else:
383 383 style_cls = name_or_cls
384 384 style_overrides = {
385 385 Token.Prompt: '#009900',
386 386 Token.PromptNum: '#ansibrightgreen bold',
387 387 Token.OutPrompt: '#990000',
388 388 Token.OutPromptNum: '#ansibrightred bold',
389 389 }
390 390 style_overrides.update(self.highlighting_style_overrides)
391 391 style = merge_styles([
392 392 style_from_pygments_cls(style_cls),
393 393 style_from_pygments_dict(style_overrides),
394 394 ])
395 395
396 396 return style
397 397
398 398 @property
399 399 def pt_complete_style(self):
400 400 return {
401 401 'multicolumn': CompleteStyle.MULTI_COLUMN,
402 402 'column': CompleteStyle.COLUMN,
403 403 'readlinelike': CompleteStyle.READLINE_LIKE,
404 404 }[self.display_completions]
405 405
406 406 @property
407 407 def color_depth(self):
408 408 return (ColorDepth.TRUE_COLOR if self.true_color else None)
409 409
410 410 def _extra_prompt_options(self):
411 411 """
412 412 Return the current layout option for the current Terminal InteractiveShell
413 413 """
414 414 def get_message():
415 415 return PygmentsTokens(self.prompts.in_prompt_tokens())
416 416
417 417 if self.editing_mode == 'emacs':
418 418 # with emacs mode the prompt is (usually) static, so we call only
419 419 # the function once. With VI mode it can toggle between [ins] and
420 420 # [nor] so we can't precompute.
421 421 # here I'm going to favor the default keybinding which almost
422 422 # everybody uses to decrease CPU usage.
423 423 # if we have issues with users with custom Prompts we can see how to
424 424 # work around this.
425 425 get_message = get_message()
426 426
427 427 options = {
428 428 'complete_in_thread': False,
429 429 'lexer':IPythonPTLexer(),
430 430 'reserve_space_for_menu':self.space_for_menu,
431 431 'message': get_message,
432 432 'prompt_continuation': (
433 433 lambda width, lineno, is_soft_wrap:
434 434 PygmentsTokens(self.prompts.continuation_prompt_tokens(width))),
435 435 'multiline': True,
436 436 'complete_style': self.pt_complete_style,
437 437
438 438 # Highlight matching brackets, but only when this setting is
439 439 # enabled, and only when the DEFAULT_BUFFER has the focus.
440 440 'input_processors': [ConditionalProcessor(
441 441 processor=HighlightMatchingBracketProcessor(chars='[](){}'),
442 442 filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
443 443 Condition(lambda: self.highlight_matching_brackets))],
444 444 }
445 445 if not PTK3:
446 446 options['inputhook'] = self.inputhook
447 447
448 448 return options
449 449
450 450 def prompt_for_code(self):
451 451 if self.rl_next_input:
452 452 default = self.rl_next_input
453 453 self.rl_next_input = None
454 454 else:
455 455 default = ''
456 456
457 457 # In order to make sure that asyncio code written in the
458 458 # interactive shell doesn't interfere with the prompt, we run the
459 459 # prompt in a different event loop.
460 460 # If we don't do this, people could spawn coroutine with a
461 461 # while/true inside which will freeze the prompt.
462 462
463 policy = asyncio.get_event_loop_policy()
463 464 try:
464 old_loop = asyncio.get_running_loop()
465 old_loop = policy.get_event_loop()
465 466 except RuntimeError:
466 # This happens when the user used `asyncio.run()`.
467 # This happens when the the event loop is closed,
468 # e.g. by calling `asyncio.run()`.
467 469 old_loop = None
468 470
469 asyncio.set_event_loop(self.pt_loop)
471 policy.set_event_loop(self.pt_loop)
470 472 try:
471 473 with patch_stdout(raw=True):
472 474 text = self.pt_app.prompt(
473 475 default=default,
474 476 **self._extra_prompt_options())
475 477 finally:
476 478 # Restore the original event loop.
477 asyncio.set_event_loop(old_loop)
479 if old_loop is not None:
480 policy.set_event_loop(old_loop)
478 481
479 482 return text
480 483
481 484 def enable_win_unicode_console(self):
482 485 # Since IPython 7.10 doesn't support python < 3.6 and PEP 528, Python uses the unicode APIs for the Windows
483 486 # console by default, so WUC shouldn't be needed.
484 487 from warnings import warn
485 488 warn("`enable_win_unicode_console` is deprecated since IPython 7.10, does not do anything and will be removed in the future",
486 489 DeprecationWarning,
487 490 stacklevel=2)
488 491
489 492 def init_io(self):
490 493 if sys.platform not in {'win32', 'cli'}:
491 494 return
492 495
493 496 import colorama
494 497 colorama.init()
495 498
496 499 # For some reason we make these wrappers around stdout/stderr.
497 500 # For now, we need to reset them so all output gets coloured.
498 501 # https://github.com/ipython/ipython/issues/8669
499 502 # io.std* are deprecated, but don't show our own deprecation warnings
500 503 # during initialization of the deprecated API.
501 504 with warnings.catch_warnings():
502 505 warnings.simplefilter('ignore', DeprecationWarning)
503 506 io.stdout = io.IOStream(sys.stdout)
504 507 io.stderr = io.IOStream(sys.stderr)
505 508
506 509 def init_magics(self):
507 510 super(TerminalInteractiveShell, self).init_magics()
508 511 self.register_magics(TerminalMagics)
509 512
510 513 def init_alias(self):
511 514 # The parent class defines aliases that can be safely used with any
512 515 # frontend.
513 516 super(TerminalInteractiveShell, self).init_alias()
514 517
515 518 # Now define aliases that only make sense on the terminal, because they
516 519 # need direct access to the console in a way that we can't emulate in
517 520 # GUI or web frontend
518 521 if os.name == 'posix':
519 522 for cmd in ('clear', 'more', 'less', 'man'):
520 523 self.alias_manager.soft_define_alias(cmd, cmd)
521 524
522 525
523 526 def __init__(self, *args, **kwargs):
524 527 super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
525 528 self.init_prompt_toolkit_cli()
526 529 self.init_term_title()
527 530 self.keep_running = True
528 531
529 532 self.debugger_history = InMemoryHistory()
530 533
531 534 def ask_exit(self):
532 535 self.keep_running = False
533 536
534 537 rl_next_input = None
535 538
536 539 def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
537 540
538 541 if display_banner is not DISPLAY_BANNER_DEPRECATED:
539 542 warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
540 543
541 544 self.keep_running = True
542 545 while self.keep_running:
543 546 print(self.separate_in, end='')
544 547
545 548 try:
546 549 code = self.prompt_for_code()
547 550 except EOFError:
548 551 if (not self.confirm_exit) \
549 552 or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
550 553 self.ask_exit()
551 554
552 555 else:
553 556 if code:
554 557 self.run_cell(code, store_history=True)
555 558
556 559 def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
557 560 # An extra layer of protection in case someone mashing Ctrl-C breaks
558 561 # out of our internal code.
559 562 if display_banner is not DISPLAY_BANNER_DEPRECATED:
560 563 warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
561 564 while True:
562 565 try:
563 566 self.interact()
564 567 break
565 568 except KeyboardInterrupt as e:
566 569 print("\n%s escaped interact()\n" % type(e).__name__)
567 570 finally:
568 571 # An interrupt during the eventloop will mess up the
569 572 # internal state of the prompt_toolkit library.
570 573 # Stopping the eventloop fixes this, see
571 574 # https://github.com/ipython/ipython/pull/9867
572 575 if hasattr(self, '_eventloop'):
573 576 self._eventloop.stop()
574 577
575 578 self.restore_term_title()
576 579
577 580
578 581 _inputhook = None
579 582 def inputhook(self, context):
580 583 if self._inputhook is not None:
581 584 self._inputhook(context)
582 585
583 586 active_eventloop = None
584 587 def enable_gui(self, gui=None):
585 588 if gui and (gui != 'inline') :
586 589 self.active_eventloop, self._inputhook =\
587 590 get_inputhook_name_and_func(gui)
588 591 else:
589 592 self.active_eventloop = self._inputhook = None
590 593
591 594 # For prompt_toolkit 3.0. We have to create an asyncio event loop with
592 595 # this inputhook.
593 596 if PTK3:
594 597 import asyncio
595 598 from prompt_toolkit.eventloop import new_eventloop_with_inputhook
596 599
597 600 if gui == 'asyncio':
598 601 # When we integrate the asyncio event loop, run the UI in the
599 602 # same event loop as the rest of the code. don't use an actual
600 603 # input hook. (Asyncio is not made for nesting event loops.)
601 604 self.pt_loop = asyncio.get_event_loop()
602 605
603 606 elif self._inputhook:
604 607 # If an inputhook was set, create a new asyncio event loop with
605 608 # this inputhook for the prompt.
606 609 self.pt_loop = new_eventloop_with_inputhook(self._inputhook)
607 610 else:
608 611 # When there's no inputhook, run the prompt in a separate
609 612 # asyncio event loop.
610 613 self.pt_loop = asyncio.new_event_loop()
611 614
612 615 # Run !system commands directly, not through pipes, so terminal programs
613 616 # work correctly.
614 617 system = InteractiveShell.system_raw
615 618
616 619 def auto_rewrite_input(self, cmd):
617 620 """Overridden from the parent class to use fancy rewriting prompt"""
618 621 if not self.show_rewritten_input:
619 622 return
620 623
621 624 tokens = self.prompts.rewrite_prompt_tokens()
622 625 if self.pt_app:
623 626 print_formatted_text(PygmentsTokens(tokens), end='',
624 627 style=self.pt_app.app.style)
625 628 print(cmd)
626 629 else:
627 630 prompt = ''.join(s for t, s in tokens)
628 631 print(prompt, cmd, sep='')
629 632
630 633 _prompts_before = None
631 634 def switch_doctest_mode(self, mode):
632 635 """Switch prompts to classic for %doctest_mode"""
633 636 if mode:
634 637 self._prompts_before = self.prompts
635 638 self.prompts = ClassicPrompts(self)
636 639 elif self._prompts_before:
637 640 self.prompts = self._prompts_before
638 641 self._prompts_before = None
639 642 # self._update_layout()
640 643
641 644
642 645 InteractiveShellABC.register(TerminalInteractiveShell)
643 646
644 647 if __name__ == '__main__':
645 648 TerminalInteractiveShell.instance().interact()
General Comments 0
You need to be logged in to leave comments. Login now