##// END OF EJS Templates
Update IPython/core/tests/test_async_helpers.py
Matthias Bussonnier -
Show More
@@ -1,322 +1,323 b''
1 """
1 """
2 Test for async helpers.
2 Test for async helpers.
3
3
4 Should only trigger on python 3.5+ or will have syntax errors.
4 Should only trigger on python 3.5+ or will have syntax errors.
5 """
5 """
6 from itertools import chain, repeat
6 from itertools import chain, repeat
7 import nose.tools as nt
7 import nose.tools as nt
8 from textwrap import dedent, indent
8 from textwrap import dedent, indent
9 from unittest import TestCase
9 from unittest import TestCase
10 from IPython.testing.decorators import skip_without
10 from IPython.testing.decorators import skip_without
11 import sys
11 import sys
12 from typing import TYPE_CHECKING
12 from typing import TYPE_CHECKING
13
13
14 if TYPE_CHECKING:
14 if TYPE_CHECKING:
15 from IPython import get_ipython
15 from IPython import get_ipython
16
16 ip = get_ipython()
17 ip = get_ipython()
17
18
18
19
19 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
20 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
20 iprc_nr = lambda x: ip.run_cell(dedent(x))
21 iprc_nr = lambda x: ip.run_cell(dedent(x))
21
22
22 from IPython.core.async_helpers import _should_be_async
23 from IPython.core.async_helpers import _should_be_async
23
24
24 class AsyncTest(TestCase):
25 class AsyncTest(TestCase):
25 def test_should_be_async(self):
26 def test_should_be_async(self):
26 nt.assert_false(_should_be_async("False"))
27 nt.assert_false(_should_be_async("False"))
27 nt.assert_true(_should_be_async("await bar()"))
28 nt.assert_true(_should_be_async("await bar()"))
28 nt.assert_true(_should_be_async("x = await bar()"))
29 nt.assert_true(_should_be_async("x = await bar()"))
29 nt.assert_false(
30 nt.assert_false(
30 _should_be_async(
31 _should_be_async(
31 dedent(
32 dedent(
32 """
33 """
33 async def awaitable():
34 async def awaitable():
34 pass
35 pass
35 """
36 """
36 )
37 )
37 )
38 )
38 )
39 )
39
40
40 def _get_top_level_cases(self):
41 def _get_top_level_cases(self):
41 # These are test cases that should be valid in a function
42 # These are test cases that should be valid in a function
42 # but invalid outside of a function.
43 # but invalid outside of a function.
43 test_cases = []
44 test_cases = []
44 test_cases.append(('basic', "{val}"))
45 test_cases.append(('basic', "{val}"))
45
46
46 # Note, in all conditional cases, I use True instead of
47 # Note, in all conditional cases, I use True instead of
47 # False so that the peephole optimizer won't optimize away
48 # False so that the peephole optimizer won't optimize away
48 # the return, so CPython will see this as a syntax error:
49 # the return, so CPython will see this as a syntax error:
49 #
50 #
50 # while True:
51 # while True:
51 # break
52 # break
52 # return
53 # return
53 #
54 #
54 # But not this:
55 # But not this:
55 #
56 #
56 # while False:
57 # while False:
57 # return
58 # return
58 #
59 #
59 # See https://bugs.python.org/issue1875
60 # See https://bugs.python.org/issue1875
60
61
61 test_cases.append(('if', dedent("""
62 test_cases.append(('if', dedent("""
62 if True:
63 if True:
63 {val}
64 {val}
64 """)))
65 """)))
65
66
66 test_cases.append(('while', dedent("""
67 test_cases.append(('while', dedent("""
67 while True:
68 while True:
68 {val}
69 {val}
69 break
70 break
70 """)))
71 """)))
71
72
72 test_cases.append(('try', dedent("""
73 test_cases.append(('try', dedent("""
73 try:
74 try:
74 {val}
75 {val}
75 except:
76 except:
76 pass
77 pass
77 """)))
78 """)))
78
79
79 test_cases.append(('except', dedent("""
80 test_cases.append(('except', dedent("""
80 try:
81 try:
81 pass
82 pass
82 except:
83 except:
83 {val}
84 {val}
84 """)))
85 """)))
85
86
86 test_cases.append(('finally', dedent("""
87 test_cases.append(('finally', dedent("""
87 try:
88 try:
88 pass
89 pass
89 except:
90 except:
90 pass
91 pass
91 finally:
92 finally:
92 {val}
93 {val}
93 """)))
94 """)))
94
95
95 test_cases.append(('for', dedent("""
96 test_cases.append(('for', dedent("""
96 for _ in range(4):
97 for _ in range(4):
97 {val}
98 {val}
98 """)))
99 """)))
99
100
100
101
101 test_cases.append(('nested', dedent("""
102 test_cases.append(('nested', dedent("""
102 if True:
103 if True:
103 while True:
104 while True:
104 {val}
105 {val}
105 break
106 break
106 """)))
107 """)))
107
108
108 test_cases.append(('deep-nested', dedent("""
109 test_cases.append(('deep-nested', dedent("""
109 if True:
110 if True:
110 while True:
111 while True:
111 break
112 break
112 for x in range(3):
113 for x in range(3):
113 if True:
114 if True:
114 while True:
115 while True:
115 for x in range(3):
116 for x in range(3):
116 {val}
117 {val}
117 """)))
118 """)))
118
119
119 return test_cases
120 return test_cases
120
121
121 def _get_ry_syntax_errors(self):
122 def _get_ry_syntax_errors(self):
122 # This is a mix of tests that should be a syntax error if
123 # This is a mix of tests that should be a syntax error if
123 # return or yield whether or not they are in a function
124 # return or yield whether or not they are in a function
124
125
125 test_cases = []
126 test_cases = []
126
127
127 test_cases.append(('class', dedent("""
128 test_cases.append(('class', dedent("""
128 class V:
129 class V:
129 {val}
130 {val}
130 """)))
131 """)))
131
132
132 test_cases.append(('nested-class', dedent("""
133 test_cases.append(('nested-class', dedent("""
133 class V:
134 class V:
134 class C:
135 class C:
135 {val}
136 {val}
136 """)))
137 """)))
137
138
138 return test_cases
139 return test_cases
139
140
140
141
141 def test_top_level_return_error(self):
142 def test_top_level_return_error(self):
142 tl_err_test_cases = self._get_top_level_cases()
143 tl_err_test_cases = self._get_top_level_cases()
143 tl_err_test_cases.extend(self._get_ry_syntax_errors())
144 tl_err_test_cases.extend(self._get_ry_syntax_errors())
144
145
145 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
146 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
146 dedent('''
147 dedent('''
147 def f():
148 def f():
148 pass
149 pass
149 return
150 return
150 '''),
151 '''),
151 )
152 )
152
153
153 for test_name, test_case in tl_err_test_cases:
154 for test_name, test_case in tl_err_test_cases:
154 # This example should work if 'pass' is used as the value
155 # This example should work if 'pass' is used as the value
155 with self.subTest((test_name, 'pass')):
156 with self.subTest((test_name, 'pass')):
156 iprc(test_case.format(val='pass'))
157 iprc(test_case.format(val='pass'))
157
158
158 # It should fail with all the values
159 # It should fail with all the values
159 for val in vals:
160 for val in vals:
160 with self.subTest((test_name, val)):
161 with self.subTest((test_name, val)):
161 msg = "Syntax error not raised for %s, %s" % (test_name, val)
162 msg = "Syntax error not raised for %s, %s" % (test_name, val)
162 with self.assertRaises(SyntaxError, msg=msg):
163 with self.assertRaises(SyntaxError, msg=msg):
163 iprc(test_case.format(val=val))
164 iprc(test_case.format(val=val))
164
165
165 def test_in_func_no_error(self):
166 def test_in_func_no_error(self):
166 # Test that the implementation of top-level return/yield
167 # Test that the implementation of top-level return/yield
167 # detection isn't *too* aggressive, and works inside a function
168 # detection isn't *too* aggressive, and works inside a function
168 func_contexts = []
169 func_contexts = []
169
170
170 func_contexts.append(('func', False, dedent("""
171 func_contexts.append(('func', False, dedent("""
171 def f():""")))
172 def f():""")))
172
173
173 func_contexts.append(('method', False, dedent("""
174 func_contexts.append(('method', False, dedent("""
174 class MyClass:
175 class MyClass:
175 def __init__(self):
176 def __init__(self):
176 """)))
177 """)))
177
178
178 func_contexts.append(('async-func', True, dedent("""
179 func_contexts.append(('async-func', True, dedent("""
179 async def f():""")))
180 async def f():""")))
180
181
181 func_contexts.append(('async-method', True, dedent("""
182 func_contexts.append(('async-method', True, dedent("""
182 class MyClass:
183 class MyClass:
183 async def f(self):""")))
184 async def f(self):""")))
184
185
185 func_contexts.append(('closure', False, dedent("""
186 func_contexts.append(('closure', False, dedent("""
186 def f():
187 def f():
187 def g():
188 def g():
188 """)))
189 """)))
189
190
190 def nest_case(context, case):
191 def nest_case(context, case):
191 # Detect indentation
192 # Detect indentation
192 lines = context.strip().splitlines()
193 lines = context.strip().splitlines()
193 prefix_len = 0
194 prefix_len = 0
194 for c in lines[-1]:
195 for c in lines[-1]:
195 if c != ' ':
196 if c != ' ':
196 break
197 break
197 prefix_len += 1
198 prefix_len += 1
198
199
199 indented_case = indent(case, ' ' * (prefix_len + 4))
200 indented_case = indent(case, ' ' * (prefix_len + 4))
200 return context + '\n' + indented_case
201 return context + '\n' + indented_case
201
202
202 # Gather and run the tests
203 # Gather and run the tests
203
204
204 # yield is allowed in async functions, starting in Python 3.6,
205 # yield is allowed in async functions, starting in Python 3.6,
205 # and yield from is not allowed in any version
206 # and yield from is not allowed in any version
206 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
207 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
207 async_safe = (True,
208 async_safe = (True,
208 True,
209 True,
209 False)
210 False)
210 vals = tuple(zip(vals, async_safe))
211 vals = tuple(zip(vals, async_safe))
211
212
212 success_tests = zip(self._get_top_level_cases(), repeat(False))
213 success_tests = zip(self._get_top_level_cases(), repeat(False))
213 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
214 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
214
215
215 tests = chain(success_tests, failure_tests)
216 tests = chain(success_tests, failure_tests)
216
217
217 for context_name, async_func, context in func_contexts:
218 for context_name, async_func, context in func_contexts:
218 for (test_name, test_case), should_fail in tests:
219 for (test_name, test_case), should_fail in tests:
219 nested_case = nest_case(context, test_case)
220 nested_case = nest_case(context, test_case)
220
221
221 for val, async_safe in vals:
222 for val, async_safe in vals:
222 val_should_fail = (should_fail or
223 val_should_fail = (should_fail or
223 (async_func and not async_safe))
224 (async_func and not async_safe))
224
225
225 test_id = (context_name, test_name, val)
226 test_id = (context_name, test_name, val)
226 cell = nested_case.format(val=val)
227 cell = nested_case.format(val=val)
227
228
228 with self.subTest(test_id):
229 with self.subTest(test_id):
229 if val_should_fail:
230 if val_should_fail:
230 msg = ("SyntaxError not raised for %s" %
231 msg = ("SyntaxError not raised for %s" %
231 str(test_id))
232 str(test_id))
232 with self.assertRaises(SyntaxError, msg=msg):
233 with self.assertRaises(SyntaxError, msg=msg):
233 iprc(cell)
234 iprc(cell)
234
235
235 print(cell)
236 print(cell)
236 else:
237 else:
237 iprc(cell)
238 iprc(cell)
238
239
239 def test_nonlocal(self):
240 def test_nonlocal(self):
240 # fails if outer scope is not a function scope or if var not defined
241 # fails if outer scope is not a function scope or if var not defined
241 with self.assertRaises(SyntaxError):
242 with self.assertRaises(SyntaxError):
242 iprc("nonlocal x")
243 iprc("nonlocal x")
243 iprc("""
244 iprc("""
244 x = 1
245 x = 1
245 def f():
246 def f():
246 nonlocal x
247 nonlocal x
247 x = 10000
248 x = 10000
248 yield x
249 yield x
249 """)
250 """)
250 iprc("""
251 iprc("""
251 def f():
252 def f():
252 def g():
253 def g():
253 nonlocal x
254 nonlocal x
254 x = 10000
255 x = 10000
255 yield x
256 yield x
256 """)
257 """)
257
258
258 # works if outer scope is a function scope and var exists
259 # works if outer scope is a function scope and var exists
259 iprc("""
260 iprc("""
260 def f():
261 def f():
261 x = 20
262 x = 20
262 def g():
263 def g():
263 nonlocal x
264 nonlocal x
264 x = 10000
265 x = 10000
265 yield x
266 yield x
266 """)
267 """)
267
268
268
269
269 def test_execute(self):
270 def test_execute(self):
270 iprc("""
271 iprc("""
271 import asyncio
272 import asyncio
272 await asyncio.sleep(0.001)
273 await asyncio.sleep(0.001)
273 """
274 """
274 )
275 )
275
276
276 def test_autoawait(self):
277 def test_autoawait(self):
277 iprc("%autoawait False")
278 iprc("%autoawait False")
278 iprc("%autoawait True")
279 iprc("%autoawait True")
279 iprc("""
280 iprc("""
280 from asyncio import sleep
281 from asyncio import sleep
281 await sleep(0.1)
282 await sleep(0.1)
282 """
283 """
283 )
284 )
284
285
285 if sys.version_info < (3,9):
286 if sys.version_info < (3,9):
286 # new pgen parser in 3.9 does not raise MemoryError on too many nested
287 # new pgen parser in 3.9 does not raise MemoryError on too many nested
287 # parens anymore
288 # parens anymore
288 def test_memory_error(self):
289 def test_memory_error(self):
289 with self.assertRaises(MemoryError):
290 with self.assertRaises(MemoryError):
290 iprc("(" * 200 + ")" * 200)
291 iprc("(" * 200 + ")" * 200)
291
292
292 @skip_without('curio')
293 @skip_without('curio')
293 def test_autoawait_curio(self):
294 def test_autoawait_curio(self):
294 iprc("%autoawait curio")
295 iprc("%autoawait curio")
295
296
296 @skip_without('trio')
297 @skip_without('trio')
297 def test_autoawait_trio(self):
298 def test_autoawait_trio(self):
298 iprc("%autoawait trio")
299 iprc("%autoawait trio")
299
300
300 @skip_without('trio')
301 @skip_without('trio')
301 def test_autoawait_trio_wrong_sleep(self):
302 def test_autoawait_trio_wrong_sleep(self):
302 iprc("%autoawait trio")
303 iprc("%autoawait trio")
303 res = iprc_nr("""
304 res = iprc_nr("""
304 import asyncio
305 import asyncio
305 await asyncio.sleep(0)
306 await asyncio.sleep(0)
306 """)
307 """)
307 with nt.assert_raises(TypeError):
308 with nt.assert_raises(TypeError):
308 res.raise_error()
309 res.raise_error()
309
310
310 @skip_without('trio')
311 @skip_without('trio')
311 def test_autoawait_asyncio_wrong_sleep(self):
312 def test_autoawait_asyncio_wrong_sleep(self):
312 iprc("%autoawait asyncio")
313 iprc("%autoawait asyncio")
313 res = iprc_nr("""
314 res = iprc_nr("""
314 import trio
315 import trio
315 await trio.sleep(0)
316 await trio.sleep(0)
316 """)
317 """)
317 with nt.assert_raises(RuntimeError):
318 with nt.assert_raises(RuntimeError):
318 res.raise_error()
319 res.raise_error()
319
320
320
321
321 def tearDown(self):
322 def tearDown(self):
322 ip.loop_runner = "asyncio"
323 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now