##// END OF EJS Templates
Attemt to fix ast transformer now that Num and Str atr the same Constant ast node...
Matthias Bussonnier -
Show More
@@ -1,973 +1,981 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import asyncio
12 import asyncio
13 import ast
13 import ast
14 import os
14 import os
15 import signal
15 import signal
16 import shutil
16 import shutil
17 import sys
17 import sys
18 import tempfile
18 import tempfile
19 import unittest
19 import unittest
20 from unittest import mock
20 from unittest import mock
21
21
22 from os.path import join
22 from os.path import join
23
23
24 import nose.tools as nt
24 import nose.tools as nt
25
25
26 from IPython.core.error import InputRejected
26 from IPython.core.error import InputRejected
27 from IPython.core.inputtransformer import InputTransformer
27 from IPython.core.inputtransformer import InputTransformer
28 from IPython.core import interactiveshell
28 from IPython.core import interactiveshell
29 from IPython.testing.decorators import (
29 from IPython.testing.decorators import (
30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
31 )
31 )
32 from IPython.testing import tools as tt
32 from IPython.testing import tools as tt
33 from IPython.utils.process import find_cmd
33 from IPython.utils.process import find_cmd
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Globals
36 # Globals
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38 # This is used by every single test, no point repeating it ad nauseam
38 # This is used by every single test, no point repeating it ad nauseam
39 ip = get_ipython()
39 ip = get_ipython()
40
40
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42 # Tests
42 # Tests
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45 class DerivedInterrupt(KeyboardInterrupt):
45 class DerivedInterrupt(KeyboardInterrupt):
46 pass
46 pass
47
47
48 class InteractiveShellTestCase(unittest.TestCase):
48 class InteractiveShellTestCase(unittest.TestCase):
49 def test_naked_string_cells(self):
49 def test_naked_string_cells(self):
50 """Test that cells with only naked strings are fully executed"""
50 """Test that cells with only naked strings are fully executed"""
51 # First, single-line inputs
51 # First, single-line inputs
52 ip.run_cell('"a"\n')
52 ip.run_cell('"a"\n')
53 self.assertEqual(ip.user_ns['_'], 'a')
53 self.assertEqual(ip.user_ns['_'], 'a')
54 # And also multi-line cells
54 # And also multi-line cells
55 ip.run_cell('"""a\nb"""\n')
55 ip.run_cell('"""a\nb"""\n')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
57
57
58 def test_run_empty_cell(self):
58 def test_run_empty_cell(self):
59 """Just make sure we don't get a horrible error with a blank
59 """Just make sure we don't get a horrible error with a blank
60 cell of input. Yes, I did overlook that."""
60 cell of input. Yes, I did overlook that."""
61 old_xc = ip.execution_count
61 old_xc = ip.execution_count
62 res = ip.run_cell('')
62 res = ip.run_cell('')
63 self.assertEqual(ip.execution_count, old_xc)
63 self.assertEqual(ip.execution_count, old_xc)
64 self.assertEqual(res.execution_count, None)
64 self.assertEqual(res.execution_count, None)
65
65
66 def test_run_cell_multiline(self):
66 def test_run_cell_multiline(self):
67 """Multi-block, multi-line cells must execute correctly.
67 """Multi-block, multi-line cells must execute correctly.
68 """
68 """
69 src = '\n'.join(["x=1",
69 src = '\n'.join(["x=1",
70 "y=2",
70 "y=2",
71 "if 1:",
71 "if 1:",
72 " x += 1",
72 " x += 1",
73 " y += 1",])
73 " y += 1",])
74 res = ip.run_cell(src)
74 res = ip.run_cell(src)
75 self.assertEqual(ip.user_ns['x'], 2)
75 self.assertEqual(ip.user_ns['x'], 2)
76 self.assertEqual(ip.user_ns['y'], 3)
76 self.assertEqual(ip.user_ns['y'], 3)
77 self.assertEqual(res.success, True)
77 self.assertEqual(res.success, True)
78 self.assertEqual(res.result, None)
78 self.assertEqual(res.result, None)
79
79
80 def test_multiline_string_cells(self):
80 def test_multiline_string_cells(self):
81 "Code sprinkled with multiline strings should execute (GH-306)"
81 "Code sprinkled with multiline strings should execute (GH-306)"
82 ip.run_cell('tmp=0')
82 ip.run_cell('tmp=0')
83 self.assertEqual(ip.user_ns['tmp'], 0)
83 self.assertEqual(ip.user_ns['tmp'], 0)
84 res = ip.run_cell('tmp=1;"""a\nb"""\n')
84 res = ip.run_cell('tmp=1;"""a\nb"""\n')
85 self.assertEqual(ip.user_ns['tmp'], 1)
85 self.assertEqual(ip.user_ns['tmp'], 1)
86 self.assertEqual(res.success, True)
86 self.assertEqual(res.success, True)
87 self.assertEqual(res.result, "a\nb")
87 self.assertEqual(res.result, "a\nb")
88
88
89 def test_dont_cache_with_semicolon(self):
89 def test_dont_cache_with_semicolon(self):
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
91 oldlen = len(ip.user_ns['Out'])
91 oldlen = len(ip.user_ns['Out'])
92 for cell in ['1;', '1;1;']:
92 for cell in ['1;', '1;1;']:
93 res = ip.run_cell(cell, store_history=True)
93 res = ip.run_cell(cell, store_history=True)
94 newlen = len(ip.user_ns['Out'])
94 newlen = len(ip.user_ns['Out'])
95 self.assertEqual(oldlen, newlen)
95 self.assertEqual(oldlen, newlen)
96 self.assertIsNone(res.result)
96 self.assertIsNone(res.result)
97 i = 0
97 i = 0
98 #also test the default caching behavior
98 #also test the default caching behavior
99 for cell in ['1', '1;1']:
99 for cell in ['1', '1;1']:
100 ip.run_cell(cell, store_history=True)
100 ip.run_cell(cell, store_history=True)
101 newlen = len(ip.user_ns['Out'])
101 newlen = len(ip.user_ns['Out'])
102 i += 1
102 i += 1
103 self.assertEqual(oldlen+i, newlen)
103 self.assertEqual(oldlen+i, newlen)
104
104
105 def test_syntax_error(self):
105 def test_syntax_error(self):
106 res = ip.run_cell("raise = 3")
106 res = ip.run_cell("raise = 3")
107 self.assertIsInstance(res.error_before_exec, SyntaxError)
107 self.assertIsInstance(res.error_before_exec, SyntaxError)
108
108
109 def test_In_variable(self):
109 def test_In_variable(self):
110 "Verify that In variable grows with user input (GH-284)"
110 "Verify that In variable grows with user input (GH-284)"
111 oldlen = len(ip.user_ns['In'])
111 oldlen = len(ip.user_ns['In'])
112 ip.run_cell('1;', store_history=True)
112 ip.run_cell('1;', store_history=True)
113 newlen = len(ip.user_ns['In'])
113 newlen = len(ip.user_ns['In'])
114 self.assertEqual(oldlen+1, newlen)
114 self.assertEqual(oldlen+1, newlen)
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
116
116
117 def test_magic_names_in_string(self):
117 def test_magic_names_in_string(self):
118 ip.run_cell('a = """\n%exit\n"""')
118 ip.run_cell('a = """\n%exit\n"""')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
120
120
121 def test_trailing_newline(self):
121 def test_trailing_newline(self):
122 """test that running !(command) does not raise a SyntaxError"""
122 """test that running !(command) does not raise a SyntaxError"""
123 ip.run_cell('!(true)\n', False)
123 ip.run_cell('!(true)\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
125
125
126 def test_gh_597(self):
126 def test_gh_597(self):
127 """Pretty-printing lists of objects with non-ascii reprs may cause
127 """Pretty-printing lists of objects with non-ascii reprs may cause
128 problems."""
128 problems."""
129 class Spam(object):
129 class Spam(object):
130 def __repr__(self):
130 def __repr__(self):
131 return "\xe9"*50
131 return "\xe9"*50
132 import IPython.core.formatters
132 import IPython.core.formatters
133 f = IPython.core.formatters.PlainTextFormatter()
133 f = IPython.core.formatters.PlainTextFormatter()
134 f([Spam(),Spam()])
134 f([Spam(),Spam()])
135
135
136
136
137 def test_future_flags(self):
137 def test_future_flags(self):
138 """Check that future flags are used for parsing code (gh-777)"""
138 """Check that future flags are used for parsing code (gh-777)"""
139 ip.run_cell('from __future__ import barry_as_FLUFL')
139 ip.run_cell('from __future__ import barry_as_FLUFL')
140 try:
140 try:
141 ip.run_cell('prfunc_return_val = 1 <> 2')
141 ip.run_cell('prfunc_return_val = 1 <> 2')
142 assert 'prfunc_return_val' in ip.user_ns
142 assert 'prfunc_return_val' in ip.user_ns
143 finally:
143 finally:
144 # Reset compiler flags so we don't mess up other tests.
144 # Reset compiler flags so we don't mess up other tests.
145 ip.compile.reset_compiler_flags()
145 ip.compile.reset_compiler_flags()
146
146
147 def test_can_pickle(self):
147 def test_can_pickle(self):
148 "Can we pickle objects defined interactively (GH-29)"
148 "Can we pickle objects defined interactively (GH-29)"
149 ip = get_ipython()
149 ip = get_ipython()
150 ip.reset()
150 ip.reset()
151 ip.run_cell(("class Mylist(list):\n"
151 ip.run_cell(("class Mylist(list):\n"
152 " def __init__(self,x=[]):\n"
152 " def __init__(self,x=[]):\n"
153 " list.__init__(self,x)"))
153 " list.__init__(self,x)"))
154 ip.run_cell("w=Mylist([1,2,3])")
154 ip.run_cell("w=Mylist([1,2,3])")
155
155
156 from pickle import dumps
156 from pickle import dumps
157
157
158 # We need to swap in our main module - this is only necessary
158 # We need to swap in our main module - this is only necessary
159 # inside the test framework, because IPython puts the interactive module
159 # inside the test framework, because IPython puts the interactive module
160 # in place (but the test framework undoes this).
160 # in place (but the test framework undoes this).
161 _main = sys.modules['__main__']
161 _main = sys.modules['__main__']
162 sys.modules['__main__'] = ip.user_module
162 sys.modules['__main__'] = ip.user_module
163 try:
163 try:
164 res = dumps(ip.user_ns["w"])
164 res = dumps(ip.user_ns["w"])
165 finally:
165 finally:
166 sys.modules['__main__'] = _main
166 sys.modules['__main__'] = _main
167 self.assertTrue(isinstance(res, bytes))
167 self.assertTrue(isinstance(res, bytes))
168
168
169 def test_global_ns(self):
169 def test_global_ns(self):
170 "Code in functions must be able to access variables outside them."
170 "Code in functions must be able to access variables outside them."
171 ip = get_ipython()
171 ip = get_ipython()
172 ip.run_cell("a = 10")
172 ip.run_cell("a = 10")
173 ip.run_cell(("def f(x):\n"
173 ip.run_cell(("def f(x):\n"
174 " return x + a"))
174 " return x + a"))
175 ip.run_cell("b = f(12)")
175 ip.run_cell("b = f(12)")
176 self.assertEqual(ip.user_ns["b"], 22)
176 self.assertEqual(ip.user_ns["b"], 22)
177
177
178 def test_bad_custom_tb(self):
178 def test_bad_custom_tb(self):
179 """Check that InteractiveShell is protected from bad custom exception handlers"""
179 """Check that InteractiveShell is protected from bad custom exception handlers"""
180 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
180 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
181 self.assertEqual(ip.custom_exceptions, (IOError,))
181 self.assertEqual(ip.custom_exceptions, (IOError,))
182 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
182 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
183 ip.run_cell(u'raise IOError("foo")')
183 ip.run_cell(u'raise IOError("foo")')
184 self.assertEqual(ip.custom_exceptions, ())
184 self.assertEqual(ip.custom_exceptions, ())
185
185
186 def test_bad_custom_tb_return(self):
186 def test_bad_custom_tb_return(self):
187 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
187 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
188 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
188 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
189 self.assertEqual(ip.custom_exceptions, (NameError,))
189 self.assertEqual(ip.custom_exceptions, (NameError,))
190 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
190 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
191 ip.run_cell(u'a=abracadabra')
191 ip.run_cell(u'a=abracadabra')
192 self.assertEqual(ip.custom_exceptions, ())
192 self.assertEqual(ip.custom_exceptions, ())
193
193
194 def test_drop_by_id(self):
194 def test_drop_by_id(self):
195 myvars = {"a":object(), "b":object(), "c": object()}
195 myvars = {"a":object(), "b":object(), "c": object()}
196 ip.push(myvars, interactive=False)
196 ip.push(myvars, interactive=False)
197 for name in myvars:
197 for name in myvars:
198 assert name in ip.user_ns, name
198 assert name in ip.user_ns, name
199 assert name in ip.user_ns_hidden, name
199 assert name in ip.user_ns_hidden, name
200 ip.user_ns['b'] = 12
200 ip.user_ns['b'] = 12
201 ip.drop_by_id(myvars)
201 ip.drop_by_id(myvars)
202 for name in ["a", "c"]:
202 for name in ["a", "c"]:
203 assert name not in ip.user_ns, name
203 assert name not in ip.user_ns, name
204 assert name not in ip.user_ns_hidden, name
204 assert name not in ip.user_ns_hidden, name
205 assert ip.user_ns['b'] == 12
205 assert ip.user_ns['b'] == 12
206 ip.reset()
206 ip.reset()
207
207
208 def test_var_expand(self):
208 def test_var_expand(self):
209 ip.user_ns['f'] = u'Ca\xf1o'
209 ip.user_ns['f'] = u'Ca\xf1o'
210 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
210 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
211 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
211 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
212 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
212 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
213 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
213 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
214
214
215 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
215 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
216
216
217 ip.user_ns['f'] = b'Ca\xc3\xb1o'
217 ip.user_ns['f'] = b'Ca\xc3\xb1o'
218 # This should not raise any exception:
218 # This should not raise any exception:
219 ip.var_expand(u'echo $f')
219 ip.var_expand(u'echo $f')
220
220
221 def test_var_expand_local(self):
221 def test_var_expand_local(self):
222 """Test local variable expansion in !system and %magic calls"""
222 """Test local variable expansion in !system and %magic calls"""
223 # !system
223 # !system
224 ip.run_cell('def test():\n'
224 ip.run_cell('def test():\n'
225 ' lvar = "ttt"\n'
225 ' lvar = "ttt"\n'
226 ' ret = !echo {lvar}\n'
226 ' ret = !echo {lvar}\n'
227 ' return ret[0]\n')
227 ' return ret[0]\n')
228 res = ip.user_ns['test']()
228 res = ip.user_ns['test']()
229 nt.assert_in('ttt', res)
229 nt.assert_in('ttt', res)
230
230
231 # %magic
231 # %magic
232 ip.run_cell('def makemacro():\n'
232 ip.run_cell('def makemacro():\n'
233 ' macroname = "macro_var_expand_locals"\n'
233 ' macroname = "macro_var_expand_locals"\n'
234 ' %macro {macroname} codestr\n')
234 ' %macro {macroname} codestr\n')
235 ip.user_ns['codestr'] = "str(12)"
235 ip.user_ns['codestr'] = "str(12)"
236 ip.run_cell('makemacro()')
236 ip.run_cell('makemacro()')
237 nt.assert_in('macro_var_expand_locals', ip.user_ns)
237 nt.assert_in('macro_var_expand_locals', ip.user_ns)
238
238
239 def test_var_expand_self(self):
239 def test_var_expand_self(self):
240 """Test variable expansion with the name 'self', which was failing.
240 """Test variable expansion with the name 'self', which was failing.
241
241
242 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
242 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
243 """
243 """
244 ip.run_cell('class cTest:\n'
244 ip.run_cell('class cTest:\n'
245 ' classvar="see me"\n'
245 ' classvar="see me"\n'
246 ' def test(self):\n'
246 ' def test(self):\n'
247 ' res = !echo Variable: {self.classvar}\n'
247 ' res = !echo Variable: {self.classvar}\n'
248 ' return res[0]\n')
248 ' return res[0]\n')
249 nt.assert_in('see me', ip.user_ns['cTest']().test())
249 nt.assert_in('see me', ip.user_ns['cTest']().test())
250
250
251 def test_bad_var_expand(self):
251 def test_bad_var_expand(self):
252 """var_expand on invalid formats shouldn't raise"""
252 """var_expand on invalid formats shouldn't raise"""
253 # SyntaxError
253 # SyntaxError
254 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
254 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
255 # NameError
255 # NameError
256 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
256 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
257 # ZeroDivisionError
257 # ZeroDivisionError
258 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
258 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
259
259
260 def test_silent_postexec(self):
260 def test_silent_postexec(self):
261 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
261 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
262 pre_explicit = mock.Mock()
262 pre_explicit = mock.Mock()
263 pre_always = mock.Mock()
263 pre_always = mock.Mock()
264 post_explicit = mock.Mock()
264 post_explicit = mock.Mock()
265 post_always = mock.Mock()
265 post_always = mock.Mock()
266 all_mocks = [pre_explicit, pre_always, post_explicit, post_always]
266 all_mocks = [pre_explicit, pre_always, post_explicit, post_always]
267
267
268 ip.events.register('pre_run_cell', pre_explicit)
268 ip.events.register('pre_run_cell', pre_explicit)
269 ip.events.register('pre_execute', pre_always)
269 ip.events.register('pre_execute', pre_always)
270 ip.events.register('post_run_cell', post_explicit)
270 ip.events.register('post_run_cell', post_explicit)
271 ip.events.register('post_execute', post_always)
271 ip.events.register('post_execute', post_always)
272
272
273 try:
273 try:
274 ip.run_cell("1", silent=True)
274 ip.run_cell("1", silent=True)
275 assert pre_always.called
275 assert pre_always.called
276 assert not pre_explicit.called
276 assert not pre_explicit.called
277 assert post_always.called
277 assert post_always.called
278 assert not post_explicit.called
278 assert not post_explicit.called
279 # double-check that non-silent exec did what we expected
279 # double-check that non-silent exec did what we expected
280 # silent to avoid
280 # silent to avoid
281 ip.run_cell("1")
281 ip.run_cell("1")
282 assert pre_explicit.called
282 assert pre_explicit.called
283 assert post_explicit.called
283 assert post_explicit.called
284 info, = pre_explicit.call_args[0]
284 info, = pre_explicit.call_args[0]
285 result, = post_explicit.call_args[0]
285 result, = post_explicit.call_args[0]
286 self.assertEqual(info, result.info)
286 self.assertEqual(info, result.info)
287 # check that post hooks are always called
287 # check that post hooks are always called
288 [m.reset_mock() for m in all_mocks]
288 [m.reset_mock() for m in all_mocks]
289 ip.run_cell("syntax error")
289 ip.run_cell("syntax error")
290 assert pre_always.called
290 assert pre_always.called
291 assert pre_explicit.called
291 assert pre_explicit.called
292 assert post_always.called
292 assert post_always.called
293 assert post_explicit.called
293 assert post_explicit.called
294 info, = pre_explicit.call_args[0]
294 info, = pre_explicit.call_args[0]
295 result, = post_explicit.call_args[0]
295 result, = post_explicit.call_args[0]
296 self.assertEqual(info, result.info)
296 self.assertEqual(info, result.info)
297 finally:
297 finally:
298 # remove post-exec
298 # remove post-exec
299 ip.events.unregister('pre_run_cell', pre_explicit)
299 ip.events.unregister('pre_run_cell', pre_explicit)
300 ip.events.unregister('pre_execute', pre_always)
300 ip.events.unregister('pre_execute', pre_always)
301 ip.events.unregister('post_run_cell', post_explicit)
301 ip.events.unregister('post_run_cell', post_explicit)
302 ip.events.unregister('post_execute', post_always)
302 ip.events.unregister('post_execute', post_always)
303
303
304 def test_silent_noadvance(self):
304 def test_silent_noadvance(self):
305 """run_cell(silent=True) doesn't advance execution_count"""
305 """run_cell(silent=True) doesn't advance execution_count"""
306 ec = ip.execution_count
306 ec = ip.execution_count
307 # silent should force store_history=False
307 # silent should force store_history=False
308 ip.run_cell("1", store_history=True, silent=True)
308 ip.run_cell("1", store_history=True, silent=True)
309
309
310 self.assertEqual(ec, ip.execution_count)
310 self.assertEqual(ec, ip.execution_count)
311 # double-check that non-silent exec did what we expected
311 # double-check that non-silent exec did what we expected
312 # silent to avoid
312 # silent to avoid
313 ip.run_cell("1", store_history=True)
313 ip.run_cell("1", store_history=True)
314 self.assertEqual(ec+1, ip.execution_count)
314 self.assertEqual(ec+1, ip.execution_count)
315
315
316 def test_silent_nodisplayhook(self):
316 def test_silent_nodisplayhook(self):
317 """run_cell(silent=True) doesn't trigger displayhook"""
317 """run_cell(silent=True) doesn't trigger displayhook"""
318 d = dict(called=False)
318 d = dict(called=False)
319
319
320 trap = ip.display_trap
320 trap = ip.display_trap
321 save_hook = trap.hook
321 save_hook = trap.hook
322
322
323 def failing_hook(*args, **kwargs):
323 def failing_hook(*args, **kwargs):
324 d['called'] = True
324 d['called'] = True
325
325
326 try:
326 try:
327 trap.hook = failing_hook
327 trap.hook = failing_hook
328 res = ip.run_cell("1", silent=True)
328 res = ip.run_cell("1", silent=True)
329 self.assertFalse(d['called'])
329 self.assertFalse(d['called'])
330 self.assertIsNone(res.result)
330 self.assertIsNone(res.result)
331 # double-check that non-silent exec did what we expected
331 # double-check that non-silent exec did what we expected
332 # silent to avoid
332 # silent to avoid
333 ip.run_cell("1")
333 ip.run_cell("1")
334 self.assertTrue(d['called'])
334 self.assertTrue(d['called'])
335 finally:
335 finally:
336 trap.hook = save_hook
336 trap.hook = save_hook
337
337
338 def test_ofind_line_magic(self):
338 def test_ofind_line_magic(self):
339 from IPython.core.magic import register_line_magic
339 from IPython.core.magic import register_line_magic
340
340
341 @register_line_magic
341 @register_line_magic
342 def lmagic(line):
342 def lmagic(line):
343 "A line magic"
343 "A line magic"
344
344
345 # Get info on line magic
345 # Get info on line magic
346 lfind = ip._ofind('lmagic')
346 lfind = ip._ofind('lmagic')
347 info = dict(found=True, isalias=False, ismagic=True,
347 info = dict(found=True, isalias=False, ismagic=True,
348 namespace = 'IPython internal', obj= lmagic.__wrapped__,
348 namespace = 'IPython internal', obj= lmagic.__wrapped__,
349 parent = None)
349 parent = None)
350 nt.assert_equal(lfind, info)
350 nt.assert_equal(lfind, info)
351
351
352 def test_ofind_cell_magic(self):
352 def test_ofind_cell_magic(self):
353 from IPython.core.magic import register_cell_magic
353 from IPython.core.magic import register_cell_magic
354
354
355 @register_cell_magic
355 @register_cell_magic
356 def cmagic(line, cell):
356 def cmagic(line, cell):
357 "A cell magic"
357 "A cell magic"
358
358
359 # Get info on cell magic
359 # Get info on cell magic
360 find = ip._ofind('cmagic')
360 find = ip._ofind('cmagic')
361 info = dict(found=True, isalias=False, ismagic=True,
361 info = dict(found=True, isalias=False, ismagic=True,
362 namespace = 'IPython internal', obj= cmagic.__wrapped__,
362 namespace = 'IPython internal', obj= cmagic.__wrapped__,
363 parent = None)
363 parent = None)
364 nt.assert_equal(find, info)
364 nt.assert_equal(find, info)
365
365
366 def test_ofind_property_with_error(self):
366 def test_ofind_property_with_error(self):
367 class A(object):
367 class A(object):
368 @property
368 @property
369 def foo(self):
369 def foo(self):
370 raise NotImplementedError()
370 raise NotImplementedError()
371 a = A()
371 a = A()
372
372
373 found = ip._ofind('a.foo', [('locals', locals())])
373 found = ip._ofind('a.foo', [('locals', locals())])
374 info = dict(found=True, isalias=False, ismagic=False,
374 info = dict(found=True, isalias=False, ismagic=False,
375 namespace='locals', obj=A.foo, parent=a)
375 namespace='locals', obj=A.foo, parent=a)
376 nt.assert_equal(found, info)
376 nt.assert_equal(found, info)
377
377
378 def test_ofind_multiple_attribute_lookups(self):
378 def test_ofind_multiple_attribute_lookups(self):
379 class A(object):
379 class A(object):
380 @property
380 @property
381 def foo(self):
381 def foo(self):
382 raise NotImplementedError()
382 raise NotImplementedError()
383
383
384 a = A()
384 a = A()
385 a.a = A()
385 a.a = A()
386 a.a.a = A()
386 a.a.a = A()
387
387
388 found = ip._ofind('a.a.a.foo', [('locals', locals())])
388 found = ip._ofind('a.a.a.foo', [('locals', locals())])
389 info = dict(found=True, isalias=False, ismagic=False,
389 info = dict(found=True, isalias=False, ismagic=False,
390 namespace='locals', obj=A.foo, parent=a.a.a)
390 namespace='locals', obj=A.foo, parent=a.a.a)
391 nt.assert_equal(found, info)
391 nt.assert_equal(found, info)
392
392
393 def test_ofind_slotted_attributes(self):
393 def test_ofind_slotted_attributes(self):
394 class A(object):
394 class A(object):
395 __slots__ = ['foo']
395 __slots__ = ['foo']
396 def __init__(self):
396 def __init__(self):
397 self.foo = 'bar'
397 self.foo = 'bar'
398
398
399 a = A()
399 a = A()
400 found = ip._ofind('a.foo', [('locals', locals())])
400 found = ip._ofind('a.foo', [('locals', locals())])
401 info = dict(found=True, isalias=False, ismagic=False,
401 info = dict(found=True, isalias=False, ismagic=False,
402 namespace='locals', obj=a.foo, parent=a)
402 namespace='locals', obj=a.foo, parent=a)
403 nt.assert_equal(found, info)
403 nt.assert_equal(found, info)
404
404
405 found = ip._ofind('a.bar', [('locals', locals())])
405 found = ip._ofind('a.bar', [('locals', locals())])
406 info = dict(found=False, isalias=False, ismagic=False,
406 info = dict(found=False, isalias=False, ismagic=False,
407 namespace=None, obj=None, parent=a)
407 namespace=None, obj=None, parent=a)
408 nt.assert_equal(found, info)
408 nt.assert_equal(found, info)
409
409
410 def test_ofind_prefers_property_to_instance_level_attribute(self):
410 def test_ofind_prefers_property_to_instance_level_attribute(self):
411 class A(object):
411 class A(object):
412 @property
412 @property
413 def foo(self):
413 def foo(self):
414 return 'bar'
414 return 'bar'
415 a = A()
415 a = A()
416 a.__dict__['foo'] = 'baz'
416 a.__dict__['foo'] = 'baz'
417 nt.assert_equal(a.foo, 'bar')
417 nt.assert_equal(a.foo, 'bar')
418 found = ip._ofind('a.foo', [('locals', locals())])
418 found = ip._ofind('a.foo', [('locals', locals())])
419 nt.assert_is(found['obj'], A.foo)
419 nt.assert_is(found['obj'], A.foo)
420
420
421 def test_custom_syntaxerror_exception(self):
421 def test_custom_syntaxerror_exception(self):
422 called = []
422 called = []
423 def my_handler(shell, etype, value, tb, tb_offset=None):
423 def my_handler(shell, etype, value, tb, tb_offset=None):
424 called.append(etype)
424 called.append(etype)
425 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
425 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
426
426
427 ip.set_custom_exc((SyntaxError,), my_handler)
427 ip.set_custom_exc((SyntaxError,), my_handler)
428 try:
428 try:
429 ip.run_cell("1f")
429 ip.run_cell("1f")
430 # Check that this was called, and only once.
430 # Check that this was called, and only once.
431 self.assertEqual(called, [SyntaxError])
431 self.assertEqual(called, [SyntaxError])
432 finally:
432 finally:
433 # Reset the custom exception hook
433 # Reset the custom exception hook
434 ip.set_custom_exc((), None)
434 ip.set_custom_exc((), None)
435
435
436 def test_custom_exception(self):
436 def test_custom_exception(self):
437 called = []
437 called = []
438 def my_handler(shell, etype, value, tb, tb_offset=None):
438 def my_handler(shell, etype, value, tb, tb_offset=None):
439 called.append(etype)
439 called.append(etype)
440 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
440 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
441
441
442 ip.set_custom_exc((ValueError,), my_handler)
442 ip.set_custom_exc((ValueError,), my_handler)
443 try:
443 try:
444 res = ip.run_cell("raise ValueError('test')")
444 res = ip.run_cell("raise ValueError('test')")
445 # Check that this was called, and only once.
445 # Check that this was called, and only once.
446 self.assertEqual(called, [ValueError])
446 self.assertEqual(called, [ValueError])
447 # Check that the error is on the result object
447 # Check that the error is on the result object
448 self.assertIsInstance(res.error_in_exec, ValueError)
448 self.assertIsInstance(res.error_in_exec, ValueError)
449 finally:
449 finally:
450 # Reset the custom exception hook
450 # Reset the custom exception hook
451 ip.set_custom_exc((), None)
451 ip.set_custom_exc((), None)
452
452
453 def test_mktempfile(self):
453 def test_mktempfile(self):
454 filename = ip.mktempfile()
454 filename = ip.mktempfile()
455 # Check that we can open the file again on Windows
455 # Check that we can open the file again on Windows
456 with open(filename, 'w') as f:
456 with open(filename, 'w') as f:
457 f.write('abc')
457 f.write('abc')
458
458
459 filename = ip.mktempfile(data='blah')
459 filename = ip.mktempfile(data='blah')
460 with open(filename, 'r') as f:
460 with open(filename, 'r') as f:
461 self.assertEqual(f.read(), 'blah')
461 self.assertEqual(f.read(), 'blah')
462
462
463 def test_new_main_mod(self):
463 def test_new_main_mod(self):
464 # Smoketest to check that this accepts a unicode module name
464 # Smoketest to check that this accepts a unicode module name
465 name = u'jiefmw'
465 name = u'jiefmw'
466 mod = ip.new_main_mod(u'%s.py' % name, name)
466 mod = ip.new_main_mod(u'%s.py' % name, name)
467 self.assertEqual(mod.__name__, name)
467 self.assertEqual(mod.__name__, name)
468
468
469 def test_get_exception_only(self):
469 def test_get_exception_only(self):
470 try:
470 try:
471 raise KeyboardInterrupt
471 raise KeyboardInterrupt
472 except KeyboardInterrupt:
472 except KeyboardInterrupt:
473 msg = ip.get_exception_only()
473 msg = ip.get_exception_only()
474 self.assertEqual(msg, 'KeyboardInterrupt\n')
474 self.assertEqual(msg, 'KeyboardInterrupt\n')
475
475
476 try:
476 try:
477 raise DerivedInterrupt("foo")
477 raise DerivedInterrupt("foo")
478 except KeyboardInterrupt:
478 except KeyboardInterrupt:
479 msg = ip.get_exception_only()
479 msg = ip.get_exception_only()
480 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
480 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
481
481
482 def test_inspect_text(self):
482 def test_inspect_text(self):
483 ip.run_cell('a = 5')
483 ip.run_cell('a = 5')
484 text = ip.object_inspect_text('a')
484 text = ip.object_inspect_text('a')
485 self.assertIsInstance(text, str)
485 self.assertIsInstance(text, str)
486
486
487 def test_last_execution_result(self):
487 def test_last_execution_result(self):
488 """ Check that last execution result gets set correctly (GH-10702) """
488 """ Check that last execution result gets set correctly (GH-10702) """
489 result = ip.run_cell('a = 5; a')
489 result = ip.run_cell('a = 5; a')
490 self.assertTrue(ip.last_execution_succeeded)
490 self.assertTrue(ip.last_execution_succeeded)
491 self.assertEqual(ip.last_execution_result.result, 5)
491 self.assertEqual(ip.last_execution_result.result, 5)
492
492
493 result = ip.run_cell('a = x_invalid_id_x')
493 result = ip.run_cell('a = x_invalid_id_x')
494 self.assertFalse(ip.last_execution_succeeded)
494 self.assertFalse(ip.last_execution_succeeded)
495 self.assertFalse(ip.last_execution_result.success)
495 self.assertFalse(ip.last_execution_result.success)
496 self.assertIsInstance(ip.last_execution_result.error_in_exec, NameError)
496 self.assertIsInstance(ip.last_execution_result.error_in_exec, NameError)
497
497
498 def test_reset_aliasing(self):
498 def test_reset_aliasing(self):
499 """ Check that standard posix aliases work after %reset. """
499 """ Check that standard posix aliases work after %reset. """
500 if os.name != 'posix':
500 if os.name != 'posix':
501 return
501 return
502
502
503 ip.reset()
503 ip.reset()
504 for cmd in ('clear', 'more', 'less', 'man'):
504 for cmd in ('clear', 'more', 'less', 'man'):
505 res = ip.run_cell('%' + cmd)
505 res = ip.run_cell('%' + cmd)
506 self.assertEqual(res.success, True)
506 self.assertEqual(res.success, True)
507
507
508
508
509 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
509 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
510
510
511 @onlyif_unicode_paths
511 @onlyif_unicode_paths
512 def setUp(self):
512 def setUp(self):
513 self.BASETESTDIR = tempfile.mkdtemp()
513 self.BASETESTDIR = tempfile.mkdtemp()
514 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
514 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
515 os.mkdir(self.TESTDIR)
515 os.mkdir(self.TESTDIR)
516 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
516 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
517 sfile.write("pass\n")
517 sfile.write("pass\n")
518 self.oldpath = os.getcwd()
518 self.oldpath = os.getcwd()
519 os.chdir(self.TESTDIR)
519 os.chdir(self.TESTDIR)
520 self.fname = u"Γ₯Àâtestscript.py"
520 self.fname = u"Γ₯Àâtestscript.py"
521
521
522 def tearDown(self):
522 def tearDown(self):
523 os.chdir(self.oldpath)
523 os.chdir(self.oldpath)
524 shutil.rmtree(self.BASETESTDIR)
524 shutil.rmtree(self.BASETESTDIR)
525
525
526 @onlyif_unicode_paths
526 @onlyif_unicode_paths
527 def test_1(self):
527 def test_1(self):
528 """Test safe_execfile with non-ascii path
528 """Test safe_execfile with non-ascii path
529 """
529 """
530 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
530 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
531
531
532 class ExitCodeChecks(tt.TempFileMixin):
532 class ExitCodeChecks(tt.TempFileMixin):
533 def test_exit_code_ok(self):
533 def test_exit_code_ok(self):
534 self.system('exit 0')
534 self.system('exit 0')
535 self.assertEqual(ip.user_ns['_exit_code'], 0)
535 self.assertEqual(ip.user_ns['_exit_code'], 0)
536
536
537 def test_exit_code_error(self):
537 def test_exit_code_error(self):
538 self.system('exit 1')
538 self.system('exit 1')
539 self.assertEqual(ip.user_ns['_exit_code'], 1)
539 self.assertEqual(ip.user_ns['_exit_code'], 1)
540
540
541 @skipif(not hasattr(signal, 'SIGALRM'))
541 @skipif(not hasattr(signal, 'SIGALRM'))
542 def test_exit_code_signal(self):
542 def test_exit_code_signal(self):
543 self.mktmp("import signal, time\n"
543 self.mktmp("import signal, time\n"
544 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
544 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
545 "time.sleep(1)\n")
545 "time.sleep(1)\n")
546 self.system("%s %s" % (sys.executable, self.fname))
546 self.system("%s %s" % (sys.executable, self.fname))
547 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
547 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
548
548
549 @onlyif_cmds_exist("csh")
549 @onlyif_cmds_exist("csh")
550 def test_exit_code_signal_csh(self):
550 def test_exit_code_signal_csh(self):
551 SHELL = os.environ.get('SHELL', None)
551 SHELL = os.environ.get('SHELL', None)
552 os.environ['SHELL'] = find_cmd("csh")
552 os.environ['SHELL'] = find_cmd("csh")
553 try:
553 try:
554 self.test_exit_code_signal()
554 self.test_exit_code_signal()
555 finally:
555 finally:
556 if SHELL is not None:
556 if SHELL is not None:
557 os.environ['SHELL'] = SHELL
557 os.environ['SHELL'] = SHELL
558 else:
558 else:
559 del os.environ['SHELL']
559 del os.environ['SHELL']
560
560
561
561
562 class TestSystemRaw(ExitCodeChecks, unittest.TestCase):
562 class TestSystemRaw(ExitCodeChecks, unittest.TestCase):
563 system = ip.system_raw
563 system = ip.system_raw
564
564
565 @onlyif_unicode_paths
565 @onlyif_unicode_paths
566 def test_1(self):
566 def test_1(self):
567 """Test system_raw with non-ascii cmd
567 """Test system_raw with non-ascii cmd
568 """
568 """
569 cmd = u'''python -c "'Γ₯Àâ'" '''
569 cmd = u'''python -c "'Γ₯Àâ'" '''
570 ip.system_raw(cmd)
570 ip.system_raw(cmd)
571
571
572 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
572 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
573 @mock.patch('os.system', side_effect=KeyboardInterrupt)
573 @mock.patch('os.system', side_effect=KeyboardInterrupt)
574 def test_control_c(self, *mocks):
574 def test_control_c(self, *mocks):
575 try:
575 try:
576 self.system("sleep 1 # wont happen")
576 self.system("sleep 1 # wont happen")
577 except KeyboardInterrupt:
577 except KeyboardInterrupt:
578 self.fail("system call should intercept "
578 self.fail("system call should intercept "
579 "keyboard interrupt from subprocess.call")
579 "keyboard interrupt from subprocess.call")
580 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
580 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
581
581
582 # TODO: Exit codes are currently ignored on Windows.
582 # TODO: Exit codes are currently ignored on Windows.
583 class TestSystemPipedExitCode(ExitCodeChecks, unittest.TestCase):
583 class TestSystemPipedExitCode(ExitCodeChecks, unittest.TestCase):
584 system = ip.system_piped
584 system = ip.system_piped
585
585
586 @skip_win32
586 @skip_win32
587 def test_exit_code_ok(self):
587 def test_exit_code_ok(self):
588 ExitCodeChecks.test_exit_code_ok(self)
588 ExitCodeChecks.test_exit_code_ok(self)
589
589
590 @skip_win32
590 @skip_win32
591 def test_exit_code_error(self):
591 def test_exit_code_error(self):
592 ExitCodeChecks.test_exit_code_error(self)
592 ExitCodeChecks.test_exit_code_error(self)
593
593
594 @skip_win32
594 @skip_win32
595 def test_exit_code_signal(self):
595 def test_exit_code_signal(self):
596 ExitCodeChecks.test_exit_code_signal(self)
596 ExitCodeChecks.test_exit_code_signal(self)
597
597
598 class TestModules(tt.TempFileMixin, unittest.TestCase):
598 class TestModules(tt.TempFileMixin, unittest.TestCase):
599 def test_extraneous_loads(self):
599 def test_extraneous_loads(self):
600 """Test we're not loading modules on startup that we shouldn't.
600 """Test we're not loading modules on startup that we shouldn't.
601 """
601 """
602 self.mktmp("import sys\n"
602 self.mktmp("import sys\n"
603 "print('numpy' in sys.modules)\n"
603 "print('numpy' in sys.modules)\n"
604 "print('ipyparallel' in sys.modules)\n"
604 "print('ipyparallel' in sys.modules)\n"
605 "print('ipykernel' in sys.modules)\n"
605 "print('ipykernel' in sys.modules)\n"
606 )
606 )
607 out = "False\nFalse\nFalse\n"
607 out = "False\nFalse\nFalse\n"
608 tt.ipexec_validate(self.fname, out)
608 tt.ipexec_validate(self.fname, out)
609
609
610 class Negator(ast.NodeTransformer):
610 class Negator(ast.NodeTransformer):
611 """Negates all number literals in an AST."""
611 """Negates all number literals in an AST."""
612
612
613 def visit_Num(self, node):
613 def visit_Num(self, node):
614 node.n = -node.n
614 node.n = -node.n
615 return node
615 return node
616
616
617 if sys.version_info > (3,8):
617 if sys.version_info > (3,8):
618 def visit_Constant(self, node):
618 def visit_Constant(self, node):
619 return self.visit_Num(node)
619 if isinstance(node.value, int):
620 return self.visit_Num(node)
620
621
621 class TestAstTransform(unittest.TestCase):
622 class TestAstTransform(unittest.TestCase):
622 def setUp(self):
623 def setUp(self):
623 self.negator = Negator()
624 self.negator = Negator()
624 ip.ast_transformers.append(self.negator)
625 ip.ast_transformers.append(self.negator)
625
626
626 def tearDown(self):
627 def tearDown(self):
627 ip.ast_transformers.remove(self.negator)
628 ip.ast_transformers.remove(self.negator)
628
629
629 def test_run_cell(self):
630 def test_run_cell(self):
630 with tt.AssertPrints('-34'):
631 with tt.AssertPrints('-34'):
631 ip.run_cell('print (12 + 22)')
632 ip.run_cell('print (12 + 22)')
632
633
633 # A named reference to a number shouldn't be transformed.
634 # A named reference to a number shouldn't be transformed.
634 ip.user_ns['n'] = 55
635 ip.user_ns['n'] = 55
635 with tt.AssertNotPrints('-55'):
636 with tt.AssertNotPrints('-55'):
636 ip.run_cell('print (n)')
637 ip.run_cell('print (n)')
637
638
638 def test_timeit(self):
639 def test_timeit(self):
639 called = set()
640 called = set()
640 def f(x):
641 def f(x):
641 called.add(x)
642 called.add(x)
642 ip.push({'f':f})
643 ip.push({'f':f})
643
644
644 with tt.AssertPrints("std. dev. of"):
645 with tt.AssertPrints("std. dev. of"):
645 ip.run_line_magic("timeit", "-n1 f(1)")
646 ip.run_line_magic("timeit", "-n1 f(1)")
646 self.assertEqual(called, {-1})
647 self.assertEqual(called, {-1})
647 called.clear()
648 called.clear()
648
649
649 with tt.AssertPrints("std. dev. of"):
650 with tt.AssertPrints("std. dev. of"):
650 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
651 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
651 self.assertEqual(called, {-2, -3})
652 self.assertEqual(called, {-2, -3})
652
653
653 def test_time(self):
654 def test_time(self):
654 called = []
655 called = []
655 def f(x):
656 def f(x):
656 called.append(x)
657 called.append(x)
657 ip.push({'f':f})
658 ip.push({'f':f})
658
659
659 # Test with an expression
660 # Test with an expression
660 with tt.AssertPrints("Wall time: "):
661 with tt.AssertPrints("Wall time: "):
661 ip.run_line_magic("time", "f(5+9)")
662 ip.run_line_magic("time", "f(5+9)")
662 self.assertEqual(called, [-14])
663 self.assertEqual(called, [-14])
663 called[:] = []
664 called[:] = []
664
665
665 # Test with a statement (different code path)
666 # Test with a statement (different code path)
666 with tt.AssertPrints("Wall time: "):
667 with tt.AssertPrints("Wall time: "):
667 ip.run_line_magic("time", "a = f(-3 + -2)")
668 ip.run_line_magic("time", "a = f(-3 + -2)")
668 self.assertEqual(called, [5])
669 self.assertEqual(called, [5])
669
670
670 def test_macro(self):
671 def test_macro(self):
671 ip.push({'a':10})
672 ip.push({'a':10})
672 # The AST transformation makes this do a+=-1
673 # The AST transformation makes this do a+=-1
673 ip.define_macro("amacro", "a+=1\nprint(a)")
674 ip.define_macro("amacro", "a+=1\nprint(a)")
674
675
675 with tt.AssertPrints("9"):
676 with tt.AssertPrints("9"):
676 ip.run_cell("amacro")
677 ip.run_cell("amacro")
677 with tt.AssertPrints("8"):
678 with tt.AssertPrints("8"):
678 ip.run_cell("amacro")
679 ip.run_cell("amacro")
679
680
680 class IntegerWrapper(ast.NodeTransformer):
681 class IntegerWrapper(ast.NodeTransformer):
681 """Wraps all integers in a call to Integer()"""
682 """Wraps all integers in a call to Integer()"""
682 def visit_Num(self, node):
683 def visit_Num(self, node):
683 if isinstance(node.n, int):
684 if isinstance(node.n, int):
684 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
685 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
685 args=[node], keywords=[])
686 args=[node], keywords=[])
686 return node
687 return node
687
688
688 if sys.version_info > (3,8):
689 if sys.version_info > (3,8):
689 def visit_Constant(self, node):
690 def visit_Constant(self, node):
690 return self.visit_Num(node)
691 if isinstance(node.value, int):
692 return self.visit_Num(node)
691
693
692
694
693 class TestAstTransform2(unittest.TestCase):
695 class TestAstTransform2(unittest.TestCase):
694 def setUp(self):
696 def setUp(self):
695 self.intwrapper = IntegerWrapper()
697 self.intwrapper = IntegerWrapper()
696 ip.ast_transformers.append(self.intwrapper)
698 ip.ast_transformers.append(self.intwrapper)
697
699
698 self.calls = []
700 self.calls = []
699 def Integer(*args):
701 def Integer(*args):
700 self.calls.append(args)
702 self.calls.append(args)
701 return args
703 return args
702 ip.push({"Integer": Integer})
704 ip.push({"Integer": Integer})
703
705
704 def tearDown(self):
706 def tearDown(self):
705 ip.ast_transformers.remove(self.intwrapper)
707 ip.ast_transformers.remove(self.intwrapper)
706 del ip.user_ns['Integer']
708 del ip.user_ns['Integer']
707
709
708 def test_run_cell(self):
710 def test_run_cell(self):
709 ip.run_cell("n = 2")
711 ip.run_cell("n = 2")
710 self.assertEqual(self.calls, [(2,)])
712 self.assertEqual(self.calls, [(2,)])
711
713
712 # This shouldn't throw an error
714 # This shouldn't throw an error
713 ip.run_cell("o = 2.0")
715 ip.run_cell("o = 2.0")
714 self.assertEqual(ip.user_ns['o'], 2.0)
716 self.assertEqual(ip.user_ns['o'], 2.0)
715
717
716 def test_timeit(self):
718 def test_timeit(self):
717 called = set()
719 called = set()
718 def f(x):
720 def f(x):
719 called.add(x)
721 called.add(x)
720 ip.push({'f':f})
722 ip.push({'f':f})
721
723
722 with tt.AssertPrints("std. dev. of"):
724 with tt.AssertPrints("std. dev. of"):
723 ip.run_line_magic("timeit", "-n1 f(1)")
725 ip.run_line_magic("timeit", "-n1 f(1)")
724 self.assertEqual(called, {(1,)})
726 self.assertEqual(called, {(1,)})
725 called.clear()
727 called.clear()
726
728
727 with tt.AssertPrints("std. dev. of"):
729 with tt.AssertPrints("std. dev. of"):
728 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
730 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
729 self.assertEqual(called, {(2,), (3,)})
731 self.assertEqual(called, {(2,), (3,)})
730
732
731 class ErrorTransformer(ast.NodeTransformer):
733 class ErrorTransformer(ast.NodeTransformer):
732 """Throws an error when it sees a number."""
734 """Throws an error when it sees a number."""
733 def visit_Num(self, node):
735 def visit_Num(self, node):
734 raise ValueError("test")
736 raise ValueError("test")
735
737
736 if sys.version_info > (3,8):
738 if sys.version_info > (3,8):
737 def visit_Constant(self, node):
739 def visit_Constant(self, node):
738 return self.visit_Num(node)
740 if isinstance(node.value, int):
741 return self.visit_Num(node)
739
742
740
743
741 class TestAstTransformError(unittest.TestCase):
744 class TestAstTransformError(unittest.TestCase):
742 def test_unregistering(self):
745 def test_unregistering(self):
743 err_transformer = ErrorTransformer()
746 err_transformer = ErrorTransformer()
744 ip.ast_transformers.append(err_transformer)
747 ip.ast_transformers.append(err_transformer)
745
748
746 with tt.AssertPrints("unregister", channel='stderr'):
749 with tt.AssertPrints("unregister", channel='stderr'):
747 ip.run_cell("1 + 2")
750 ip.run_cell("1 + 2")
748
751
749 # This should have been removed.
752 # This should have been removed.
750 nt.assert_not_in(err_transformer, ip.ast_transformers)
753 nt.assert_not_in(err_transformer, ip.ast_transformers)
751
754
752
755
753 class StringRejector(ast.NodeTransformer):
756 class StringRejector(ast.NodeTransformer):
754 """Throws an InputRejected when it sees a string literal.
757 """Throws an InputRejected when it sees a string literal.
755
758
756 Used to verify that NodeTransformers can signal that a piece of code should
759 Used to verify that NodeTransformers can signal that a piece of code should
757 not be executed by throwing an InputRejected.
760 not be executed by throwing an InputRejected.
758 """
761 """
759
762
760 def visit_Str(self, node):
763 def visit_Str(self, node):
761 raise InputRejected("test")
764 raise InputRejected("test")
762
765
766 # 3.8 only
767 def visit_Constant(self, node):
768 if isinstance(node.value, str):
769 raise InputRejected("test")
770
763
771
764 class TestAstTransformInputRejection(unittest.TestCase):
772 class TestAstTransformInputRejection(unittest.TestCase):
765
773
766 def setUp(self):
774 def setUp(self):
767 self.transformer = StringRejector()
775 self.transformer = StringRejector()
768 ip.ast_transformers.append(self.transformer)
776 ip.ast_transformers.append(self.transformer)
769
777
770 def tearDown(self):
778 def tearDown(self):
771 ip.ast_transformers.remove(self.transformer)
779 ip.ast_transformers.remove(self.transformer)
772
780
773 def test_input_rejection(self):
781 def test_input_rejection(self):
774 """Check that NodeTransformers can reject input."""
782 """Check that NodeTransformers can reject input."""
775
783
776 expect_exception_tb = tt.AssertPrints("InputRejected: test")
784 expect_exception_tb = tt.AssertPrints("InputRejected: test")
777 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
785 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
778
786
779 # Run the same check twice to verify that the transformer is not
787 # Run the same check twice to verify that the transformer is not
780 # disabled after raising.
788 # disabled after raising.
781 with expect_exception_tb, expect_no_cell_output:
789 with expect_exception_tb, expect_no_cell_output:
782 ip.run_cell("'unsafe'")
790 ip.run_cell("'unsafe'")
783
791
784 with expect_exception_tb, expect_no_cell_output:
792 with expect_exception_tb, expect_no_cell_output:
785 res = ip.run_cell("'unsafe'")
793 res = ip.run_cell("'unsafe'")
786
794
787 self.assertIsInstance(res.error_before_exec, InputRejected)
795 self.assertIsInstance(res.error_before_exec, InputRejected)
788
796
789 def test__IPYTHON__():
797 def test__IPYTHON__():
790 # This shouldn't raise a NameError, that's all
798 # This shouldn't raise a NameError, that's all
791 __IPYTHON__
799 __IPYTHON__
792
800
793
801
794 class DummyRepr(object):
802 class DummyRepr(object):
795 def __repr__(self):
803 def __repr__(self):
796 return "DummyRepr"
804 return "DummyRepr"
797
805
798 def _repr_html_(self):
806 def _repr_html_(self):
799 return "<b>dummy</b>"
807 return "<b>dummy</b>"
800
808
801 def _repr_javascript_(self):
809 def _repr_javascript_(self):
802 return "console.log('hi');", {'key': 'value'}
810 return "console.log('hi');", {'key': 'value'}
803
811
804
812
805 def test_user_variables():
813 def test_user_variables():
806 # enable all formatters
814 # enable all formatters
807 ip.display_formatter.active_types = ip.display_formatter.format_types
815 ip.display_formatter.active_types = ip.display_formatter.format_types
808
816
809 ip.user_ns['dummy'] = d = DummyRepr()
817 ip.user_ns['dummy'] = d = DummyRepr()
810 keys = {'dummy', 'doesnotexist'}
818 keys = {'dummy', 'doesnotexist'}
811 r = ip.user_expressions({ key:key for key in keys})
819 r = ip.user_expressions({ key:key for key in keys})
812
820
813 nt.assert_equal(keys, set(r.keys()))
821 nt.assert_equal(keys, set(r.keys()))
814 dummy = r['dummy']
822 dummy = r['dummy']
815 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
823 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
816 nt.assert_equal(dummy['status'], 'ok')
824 nt.assert_equal(dummy['status'], 'ok')
817 data = dummy['data']
825 data = dummy['data']
818 metadata = dummy['metadata']
826 metadata = dummy['metadata']
819 nt.assert_equal(data.get('text/html'), d._repr_html_())
827 nt.assert_equal(data.get('text/html'), d._repr_html_())
820 js, jsmd = d._repr_javascript_()
828 js, jsmd = d._repr_javascript_()
821 nt.assert_equal(data.get('application/javascript'), js)
829 nt.assert_equal(data.get('application/javascript'), js)
822 nt.assert_equal(metadata.get('application/javascript'), jsmd)
830 nt.assert_equal(metadata.get('application/javascript'), jsmd)
823
831
824 dne = r['doesnotexist']
832 dne = r['doesnotexist']
825 nt.assert_equal(dne['status'], 'error')
833 nt.assert_equal(dne['status'], 'error')
826 nt.assert_equal(dne['ename'], 'NameError')
834 nt.assert_equal(dne['ename'], 'NameError')
827
835
828 # back to text only
836 # back to text only
829 ip.display_formatter.active_types = ['text/plain']
837 ip.display_formatter.active_types = ['text/plain']
830
838
831 def test_user_expression():
839 def test_user_expression():
832 # enable all formatters
840 # enable all formatters
833 ip.display_formatter.active_types = ip.display_formatter.format_types
841 ip.display_formatter.active_types = ip.display_formatter.format_types
834 query = {
842 query = {
835 'a' : '1 + 2',
843 'a' : '1 + 2',
836 'b' : '1/0',
844 'b' : '1/0',
837 }
845 }
838 r = ip.user_expressions(query)
846 r = ip.user_expressions(query)
839 import pprint
847 import pprint
840 pprint.pprint(r)
848 pprint.pprint(r)
841 nt.assert_equal(set(r.keys()), set(query.keys()))
849 nt.assert_equal(set(r.keys()), set(query.keys()))
842 a = r['a']
850 a = r['a']
843 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
851 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
844 nt.assert_equal(a['status'], 'ok')
852 nt.assert_equal(a['status'], 'ok')
845 data = a['data']
853 data = a['data']
846 metadata = a['metadata']
854 metadata = a['metadata']
847 nt.assert_equal(data.get('text/plain'), '3')
855 nt.assert_equal(data.get('text/plain'), '3')
848
856
849 b = r['b']
857 b = r['b']
850 nt.assert_equal(b['status'], 'error')
858 nt.assert_equal(b['status'], 'error')
851 nt.assert_equal(b['ename'], 'ZeroDivisionError')
859 nt.assert_equal(b['ename'], 'ZeroDivisionError')
852
860
853 # back to text only
861 # back to text only
854 ip.display_formatter.active_types = ['text/plain']
862 ip.display_formatter.active_types = ['text/plain']
855
863
856
864
857
865
858
866
859
867
860 class TestSyntaxErrorTransformer(unittest.TestCase):
868 class TestSyntaxErrorTransformer(unittest.TestCase):
861 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
869 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
862
870
863 @staticmethod
871 @staticmethod
864 def transformer(lines):
872 def transformer(lines):
865 for line in lines:
873 for line in lines:
866 pos = line.find('syntaxerror')
874 pos = line.find('syntaxerror')
867 if pos >= 0:
875 if pos >= 0:
868 e = SyntaxError('input contains "syntaxerror"')
876 e = SyntaxError('input contains "syntaxerror"')
869 e.text = line
877 e.text = line
870 e.offset = pos + 1
878 e.offset = pos + 1
871 raise e
879 raise e
872 return lines
880 return lines
873
881
874 def setUp(self):
882 def setUp(self):
875 ip.input_transformers_post.append(self.transformer)
883 ip.input_transformers_post.append(self.transformer)
876
884
877 def tearDown(self):
885 def tearDown(self):
878 ip.input_transformers_post.remove(self.transformer)
886 ip.input_transformers_post.remove(self.transformer)
879
887
880 def test_syntaxerror_input_transformer(self):
888 def test_syntaxerror_input_transformer(self):
881 with tt.AssertPrints('1234'):
889 with tt.AssertPrints('1234'):
882 ip.run_cell('1234')
890 ip.run_cell('1234')
883 with tt.AssertPrints('SyntaxError: invalid syntax'):
891 with tt.AssertPrints('SyntaxError: invalid syntax'):
884 ip.run_cell('1 2 3') # plain python syntax error
892 ip.run_cell('1 2 3') # plain python syntax error
885 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
893 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
886 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
894 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
887 with tt.AssertPrints('3456'):
895 with tt.AssertPrints('3456'):
888 ip.run_cell('3456')
896 ip.run_cell('3456')
889
897
890
898
891
899
892 def test_warning_suppression():
900 def test_warning_suppression():
893 ip.run_cell("import warnings")
901 ip.run_cell("import warnings")
894 try:
902 try:
895 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
903 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
896 ip.run_cell("warnings.warn('asdf')")
904 ip.run_cell("warnings.warn('asdf')")
897 # Here's the real test -- if we run that again, we should get the
905 # Here's the real test -- if we run that again, we should get the
898 # warning again. Traditionally, each warning was only issued once per
906 # warning again. Traditionally, each warning was only issued once per
899 # IPython session (approximately), even if the user typed in new and
907 # IPython session (approximately), even if the user typed in new and
900 # different code that should have also triggered the warning, leading
908 # different code that should have also triggered the warning, leading
901 # to much confusion.
909 # to much confusion.
902 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
910 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
903 ip.run_cell("warnings.warn('asdf')")
911 ip.run_cell("warnings.warn('asdf')")
904 finally:
912 finally:
905 ip.run_cell("del warnings")
913 ip.run_cell("del warnings")
906
914
907
915
908 def test_deprecation_warning():
916 def test_deprecation_warning():
909 ip.run_cell("""
917 ip.run_cell("""
910 import warnings
918 import warnings
911 def wrn():
919 def wrn():
912 warnings.warn(
920 warnings.warn(
913 "I AM A WARNING",
921 "I AM A WARNING",
914 DeprecationWarning
922 DeprecationWarning
915 )
923 )
916 """)
924 """)
917 try:
925 try:
918 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
926 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
919 ip.run_cell("wrn()")
927 ip.run_cell("wrn()")
920 finally:
928 finally:
921 ip.run_cell("del warnings")
929 ip.run_cell("del warnings")
922 ip.run_cell("del wrn")
930 ip.run_cell("del wrn")
923
931
924
932
925 class TestImportNoDeprecate(tt.TempFileMixin):
933 class TestImportNoDeprecate(tt.TempFileMixin):
926
934
927 def setup(self):
935 def setup(self):
928 """Make a valid python temp file."""
936 """Make a valid python temp file."""
929 self.mktmp("""
937 self.mktmp("""
930 import warnings
938 import warnings
931 def wrn():
939 def wrn():
932 warnings.warn(
940 warnings.warn(
933 "I AM A WARNING",
941 "I AM A WARNING",
934 DeprecationWarning
942 DeprecationWarning
935 )
943 )
936 """)
944 """)
937
945
938 def test_no_dep(self):
946 def test_no_dep(self):
939 """
947 """
940 No deprecation warning should be raised from imported functions
948 No deprecation warning should be raised from imported functions
941 """
949 """
942 ip.run_cell("from {} import wrn".format(self.fname))
950 ip.run_cell("from {} import wrn".format(self.fname))
943
951
944 with tt.AssertNotPrints("I AM A WARNING"):
952 with tt.AssertNotPrints("I AM A WARNING"):
945 ip.run_cell("wrn()")
953 ip.run_cell("wrn()")
946 ip.run_cell("del wrn")
954 ip.run_cell("del wrn")
947
955
948
956
949 def test_custom_exc_count():
957 def test_custom_exc_count():
950 hook = mock.Mock(return_value=None)
958 hook = mock.Mock(return_value=None)
951 ip.set_custom_exc((SyntaxError,), hook)
959 ip.set_custom_exc((SyntaxError,), hook)
952 before = ip.execution_count
960 before = ip.execution_count
953 ip.run_cell("def foo()", store_history=True)
961 ip.run_cell("def foo()", store_history=True)
954 # restore default excepthook
962 # restore default excepthook
955 ip.set_custom_exc((), None)
963 ip.set_custom_exc((), None)
956 nt.assert_equal(hook.call_count, 1)
964 nt.assert_equal(hook.call_count, 1)
957 nt.assert_equal(ip.execution_count, before + 1)
965 nt.assert_equal(ip.execution_count, before + 1)
958
966
959
967
960 def test_run_cell_async():
968 def test_run_cell_async():
961 loop = asyncio.get_event_loop()
969 loop = asyncio.get_event_loop()
962 ip.run_cell("import asyncio")
970 ip.run_cell("import asyncio")
963 coro = ip.run_cell_async("await asyncio.sleep(0.01)\n5")
971 coro = ip.run_cell_async("await asyncio.sleep(0.01)\n5")
964 assert asyncio.iscoroutine(coro)
972 assert asyncio.iscoroutine(coro)
965 result = loop.run_until_complete(coro)
973 result = loop.run_until_complete(coro)
966 assert isinstance(result, interactiveshell.ExecutionResult)
974 assert isinstance(result, interactiveshell.ExecutionResult)
967 assert result.result == 5
975 assert result.result == 5
968
976
969
977
970 def test_should_run_async():
978 def test_should_run_async():
971 assert not ip.should_run_async("a = 5")
979 assert not ip.should_run_async("a = 5")
972 assert ip.should_run_async("await x")
980 assert ip.should_run_async("await x")
973 assert ip.should_run_async("import asyncio; await asyncio.sleep(1)")
981 assert ip.should_run_async("import asyncio; await asyncio.sleep(1)")
@@ -1,166 +1,166 b''
1 """prompt-toolkit utilities
1 """prompt-toolkit utilities
2
2
3 Everything in this module is a private API,
3 Everything in this module is a private API,
4 not to be used outside IPython.
4 not to be used outside IPython.
5 """
5 """
6
6
7 # Copyright (c) IPython Development Team.
7 # Copyright (c) IPython Development Team.
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9
9
10 import unicodedata
10 import unicodedata
11 from wcwidth import wcwidth
11 from wcwidth import wcwidth
12
12
13 from IPython.core.completer import (
13 from IPython.core.completer import (
14 provisionalcompleter, cursor_to_position,
14 provisionalcompleter, cursor_to_position,
15 _deduplicate_completions)
15 _deduplicate_completions)
16 from prompt_toolkit.completion import Completer, Completion
16 from prompt_toolkit.completion import Completer, Completion
17 from prompt_toolkit.lexers import Lexer
17 from prompt_toolkit.lexers import Lexer
18 from prompt_toolkit.lexers import PygmentsLexer
18 from prompt_toolkit.lexers import PygmentsLexer
19 from prompt_toolkit.patch_stdout import patch_stdout
19 from prompt_toolkit.patch_stdout import patch_stdout
20
20
21 import pygments.lexers as pygments_lexers
21 import pygments.lexers as pygments_lexers
22 import os
22 import os
23
23
24 _completion_sentinel = object()
24 _completion_sentinel = object()
25
25
26 def _elide(string, *, min_elide=30):
26 def _elide(string, *, min_elide=30):
27 """
27 """
28 If a string is long enough, and has at least 3 dots,
28 If a string is long enough, and has at least 3 dots,
29 replace the middle part with ellipses.
29 replace the middle part with ellipses.
30
30
31 If a string naming a file is long enough, and has at least 3 slashes,
31 If a string naming a file is long enough, and has at least 3 slashes,
32 replace the middle part with ellipses.
32 replace the middle part with ellipses.
33
33
34 If three consecutive dots, or two consecutive dots are encountered these are
34 If three consecutive dots, or two consecutive dots are encountered these are
35 replaced by the equivalents HORIZONTAL ELLIPSIS or TWO DOT LEADER unicode
35 replaced by the equivalents HORIZONTAL ELLIPSIS or TWO DOT LEADER unicode
36 equivalents
36 equivalents
37 """
37 """
38 string = string.replace('...','\N{HORIZONTAL ELLIPSIS}')
38 string = string.replace('...','\N{HORIZONTAL ELLIPSIS}')
39 string = string.replace('..','\N{TWO DOT LEADER}')
39 string = string.replace('..','\N{TWO DOT LEADER}')
40 if len(string) < min_elide:
40 if len(string) < min_elide:
41 return string
41 return string
42
42
43 object_parts = string.split('.')
43 object_parts = string.split('.')
44 file_parts = string.split(os.sep)
44 file_parts = string.split(os.sep)
45
45
46 if len(object_parts) > 3:
46 if len(object_parts) > 3:
47 return '{}.{}\N{HORIZONTAL ELLIPSIS}{}.{}'.format(object_parts[0], object_parts[1][0], object_parts[-2][-1], object_parts[-1])
47 return '{}.{}\N{HORIZONTAL ELLIPSIS}{}.{}'.format(object_parts[0], object_parts[1][0], object_parts[-2][-1], object_parts[-1])
48
48
49 elif len(file_parts) > 3:
49 elif len(file_parts) > 3:
50 return ('{}' + os.sep + '{}\N{HORIZONTAL ELLIPSIS}{}' + os.sep + '{}').format(file_parts[0], file_parts[1][0], file_parts[-2][-1], file_parts[-1])
50 return ('{}' + os.sep + '{}\N{HORIZONTAL ELLIPSIS}{}' + os.sep + '{}').format(file_parts[0], file_parts[1][0], file_parts[-2][-1], file_parts[-1])
51
51
52 return string
52 return string
53
53
54
54
55 def _adjust_completion_text_based_on_context(text, body, offset):
55 def _adjust_completion_text_based_on_context(text, body, offset):
56 if text.endswith('=') and len(body) > offset and body[offset] is '=':
56 if text.endswith('=') and len(body) > offset and body[offset] == '=':
57 return text[:-1]
57 return text[:-1]
58 else:
58 else:
59 return text
59 return text
60
60
61
61
62 class IPythonPTCompleter(Completer):
62 class IPythonPTCompleter(Completer):
63 """Adaptor to provide IPython completions to prompt_toolkit"""
63 """Adaptor to provide IPython completions to prompt_toolkit"""
64 def __init__(self, ipy_completer=None, shell=None):
64 def __init__(self, ipy_completer=None, shell=None):
65 if shell is None and ipy_completer is None:
65 if shell is None and ipy_completer is None:
66 raise TypeError("Please pass shell=an InteractiveShell instance.")
66 raise TypeError("Please pass shell=an InteractiveShell instance.")
67 self._ipy_completer = ipy_completer
67 self._ipy_completer = ipy_completer
68 self.shell = shell
68 self.shell = shell
69
69
70 @property
70 @property
71 def ipy_completer(self):
71 def ipy_completer(self):
72 if self._ipy_completer:
72 if self._ipy_completer:
73 return self._ipy_completer
73 return self._ipy_completer
74 else:
74 else:
75 return self.shell.Completer
75 return self.shell.Completer
76
76
77 def get_completions(self, document, complete_event):
77 def get_completions(self, document, complete_event):
78 if not document.current_line.strip():
78 if not document.current_line.strip():
79 return
79 return
80 # Some bits of our completion system may print stuff (e.g. if a module
80 # Some bits of our completion system may print stuff (e.g. if a module
81 # is imported). This context manager ensures that doesn't interfere with
81 # is imported). This context manager ensures that doesn't interfere with
82 # the prompt.
82 # the prompt.
83
83
84 with patch_stdout(), provisionalcompleter():
84 with patch_stdout(), provisionalcompleter():
85 body = document.text
85 body = document.text
86 cursor_row = document.cursor_position_row
86 cursor_row = document.cursor_position_row
87 cursor_col = document.cursor_position_col
87 cursor_col = document.cursor_position_col
88 cursor_position = document.cursor_position
88 cursor_position = document.cursor_position
89 offset = cursor_to_position(body, cursor_row, cursor_col)
89 offset = cursor_to_position(body, cursor_row, cursor_col)
90 yield from self._get_completions(body, offset, cursor_position, self.ipy_completer)
90 yield from self._get_completions(body, offset, cursor_position, self.ipy_completer)
91
91
92 @staticmethod
92 @staticmethod
93 def _get_completions(body, offset, cursor_position, ipyc):
93 def _get_completions(body, offset, cursor_position, ipyc):
94 """
94 """
95 Private equivalent of get_completions() use only for unit_testing.
95 Private equivalent of get_completions() use only for unit_testing.
96 """
96 """
97 debug = getattr(ipyc, 'debug', False)
97 debug = getattr(ipyc, 'debug', False)
98 completions = _deduplicate_completions(
98 completions = _deduplicate_completions(
99 body, ipyc.completions(body, offset))
99 body, ipyc.completions(body, offset))
100 for c in completions:
100 for c in completions:
101 if not c.text:
101 if not c.text:
102 # Guard against completion machinery giving us an empty string.
102 # Guard against completion machinery giving us an empty string.
103 continue
103 continue
104 text = unicodedata.normalize('NFC', c.text)
104 text = unicodedata.normalize('NFC', c.text)
105 # When the first character of the completion has a zero length,
105 # When the first character of the completion has a zero length,
106 # then it's probably a decomposed unicode character. E.g. caused by
106 # then it's probably a decomposed unicode character. E.g. caused by
107 # the "\dot" completion. Try to compose again with the previous
107 # the "\dot" completion. Try to compose again with the previous
108 # character.
108 # character.
109 if wcwidth(text[0]) == 0:
109 if wcwidth(text[0]) == 0:
110 if cursor_position + c.start > 0:
110 if cursor_position + c.start > 0:
111 char_before = body[c.start - 1]
111 char_before = body[c.start - 1]
112 fixed_text = unicodedata.normalize(
112 fixed_text = unicodedata.normalize(
113 'NFC', char_before + text)
113 'NFC', char_before + text)
114
114
115 # Yield the modified completion instead, if this worked.
115 # Yield the modified completion instead, if this worked.
116 if wcwidth(text[0:1]) == 1:
116 if wcwidth(text[0:1]) == 1:
117 yield Completion(fixed_text, start_position=c.start - offset - 1)
117 yield Completion(fixed_text, start_position=c.start - offset - 1)
118 continue
118 continue
119
119
120 # TODO: Use Jedi to determine meta_text
120 # TODO: Use Jedi to determine meta_text
121 # (Jedi currently has a bug that results in incorrect information.)
121 # (Jedi currently has a bug that results in incorrect information.)
122 # meta_text = ''
122 # meta_text = ''
123 # yield Completion(m, start_position=start_pos,
123 # yield Completion(m, start_position=start_pos,
124 # display_meta=meta_text)
124 # display_meta=meta_text)
125 display_text = c.text
125 display_text = c.text
126
126
127 adjusted_text = _adjust_completion_text_based_on_context(c.text, body, offset)
127 adjusted_text = _adjust_completion_text_based_on_context(c.text, body, offset)
128 if c.type == 'function':
128 if c.type == 'function':
129 yield Completion(adjusted_text, start_position=c.start - offset, display=_elide(display_text+'()'), display_meta=c.type+c.signature)
129 yield Completion(adjusted_text, start_position=c.start - offset, display=_elide(display_text+'()'), display_meta=c.type+c.signature)
130 else:
130 else:
131 yield Completion(adjusted_text, start_position=c.start - offset, display=_elide(display_text), display_meta=c.type)
131 yield Completion(adjusted_text, start_position=c.start - offset, display=_elide(display_text), display_meta=c.type)
132
132
133 class IPythonPTLexer(Lexer):
133 class IPythonPTLexer(Lexer):
134 """
134 """
135 Wrapper around PythonLexer and BashLexer.
135 Wrapper around PythonLexer and BashLexer.
136 """
136 """
137 def __init__(self):
137 def __init__(self):
138 l = pygments_lexers
138 l = pygments_lexers
139 self.python_lexer = PygmentsLexer(l.Python3Lexer)
139 self.python_lexer = PygmentsLexer(l.Python3Lexer)
140 self.shell_lexer = PygmentsLexer(l.BashLexer)
140 self.shell_lexer = PygmentsLexer(l.BashLexer)
141
141
142 self.magic_lexers = {
142 self.magic_lexers = {
143 'HTML': PygmentsLexer(l.HtmlLexer),
143 'HTML': PygmentsLexer(l.HtmlLexer),
144 'html': PygmentsLexer(l.HtmlLexer),
144 'html': PygmentsLexer(l.HtmlLexer),
145 'javascript': PygmentsLexer(l.JavascriptLexer),
145 'javascript': PygmentsLexer(l.JavascriptLexer),
146 'js': PygmentsLexer(l.JavascriptLexer),
146 'js': PygmentsLexer(l.JavascriptLexer),
147 'perl': PygmentsLexer(l.PerlLexer),
147 'perl': PygmentsLexer(l.PerlLexer),
148 'ruby': PygmentsLexer(l.RubyLexer),
148 'ruby': PygmentsLexer(l.RubyLexer),
149 'latex': PygmentsLexer(l.TexLexer),
149 'latex': PygmentsLexer(l.TexLexer),
150 }
150 }
151
151
152 def lex_document(self, document):
152 def lex_document(self, document):
153 text = document.text.lstrip()
153 text = document.text.lstrip()
154
154
155 lexer = self.python_lexer
155 lexer = self.python_lexer
156
156
157 if text.startswith('!') or text.startswith('%%bash'):
157 if text.startswith('!') or text.startswith('%%bash'):
158 lexer = self.shell_lexer
158 lexer = self.shell_lexer
159
159
160 elif text.startswith('%%'):
160 elif text.startswith('%%'):
161 for magic, l in self.magic_lexers.items():
161 for magic, l in self.magic_lexers.items():
162 if text.startswith('%%' + magic):
162 if text.startswith('%%' + magic):
163 lexer = l
163 lexer = l
164 break
164 break
165
165
166 return lexer.lex_document(document)
166 return lexer.lex_document(document)
General Comments 0
You need to be logged in to leave comments. Login now