##// END OF EJS Templates
Update IPython/core/tests/test_async_helpers.py
Matthias Bussonnier -
Show More
@@ -1,322 +1,323 b''
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6 from itertools import chain, repeat
7 7 import nose.tools as nt
8 8 from textwrap import dedent, indent
9 9 from unittest import TestCase
10 10 from IPython.testing.decorators import skip_without
11 11 import sys
12 12 from typing import TYPE_CHECKING
13 13
14 14 if TYPE_CHECKING:
15 15 from IPython import get_ipython
16
16 17 ip = get_ipython()
17 18
18 19
19 20 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
20 21 iprc_nr = lambda x: ip.run_cell(dedent(x))
21 22
22 23 from IPython.core.async_helpers import _should_be_async
23 24
24 25 class AsyncTest(TestCase):
25 26 def test_should_be_async(self):
26 27 nt.assert_false(_should_be_async("False"))
27 28 nt.assert_true(_should_be_async("await bar()"))
28 29 nt.assert_true(_should_be_async("x = await bar()"))
29 30 nt.assert_false(
30 31 _should_be_async(
31 32 dedent(
32 33 """
33 34 async def awaitable():
34 35 pass
35 36 """
36 37 )
37 38 )
38 39 )
39 40
40 41 def _get_top_level_cases(self):
41 42 # These are test cases that should be valid in a function
42 43 # but invalid outside of a function.
43 44 test_cases = []
44 45 test_cases.append(('basic', "{val}"))
45 46
46 47 # Note, in all conditional cases, I use True instead of
47 48 # False so that the peephole optimizer won't optimize away
48 49 # the return, so CPython will see this as a syntax error:
49 50 #
50 51 # while True:
51 52 # break
52 53 # return
53 54 #
54 55 # But not this:
55 56 #
56 57 # while False:
57 58 # return
58 59 #
59 60 # See https://bugs.python.org/issue1875
60 61
61 62 test_cases.append(('if', dedent("""
62 63 if True:
63 64 {val}
64 65 """)))
65 66
66 67 test_cases.append(('while', dedent("""
67 68 while True:
68 69 {val}
69 70 break
70 71 """)))
71 72
72 73 test_cases.append(('try', dedent("""
73 74 try:
74 75 {val}
75 76 except:
76 77 pass
77 78 """)))
78 79
79 80 test_cases.append(('except', dedent("""
80 81 try:
81 82 pass
82 83 except:
83 84 {val}
84 85 """)))
85 86
86 87 test_cases.append(('finally', dedent("""
87 88 try:
88 89 pass
89 90 except:
90 91 pass
91 92 finally:
92 93 {val}
93 94 """)))
94 95
95 96 test_cases.append(('for', dedent("""
96 97 for _ in range(4):
97 98 {val}
98 99 """)))
99 100
100 101
101 102 test_cases.append(('nested', dedent("""
102 103 if True:
103 104 while True:
104 105 {val}
105 106 break
106 107 """)))
107 108
108 109 test_cases.append(('deep-nested', dedent("""
109 110 if True:
110 111 while True:
111 112 break
112 113 for x in range(3):
113 114 if True:
114 115 while True:
115 116 for x in range(3):
116 117 {val}
117 118 """)))
118 119
119 120 return test_cases
120 121
121 122 def _get_ry_syntax_errors(self):
122 123 # This is a mix of tests that should be a syntax error if
123 124 # return or yield whether or not they are in a function
124 125
125 126 test_cases = []
126 127
127 128 test_cases.append(('class', dedent("""
128 129 class V:
129 130 {val}
130 131 """)))
131 132
132 133 test_cases.append(('nested-class', dedent("""
133 134 class V:
134 135 class C:
135 136 {val}
136 137 """)))
137 138
138 139 return test_cases
139 140
140 141
141 142 def test_top_level_return_error(self):
142 143 tl_err_test_cases = self._get_top_level_cases()
143 144 tl_err_test_cases.extend(self._get_ry_syntax_errors())
144 145
145 146 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
146 147 dedent('''
147 148 def f():
148 149 pass
149 150 return
150 151 '''),
151 152 )
152 153
153 154 for test_name, test_case in tl_err_test_cases:
154 155 # This example should work if 'pass' is used as the value
155 156 with self.subTest((test_name, 'pass')):
156 157 iprc(test_case.format(val='pass'))
157 158
158 159 # It should fail with all the values
159 160 for val in vals:
160 161 with self.subTest((test_name, val)):
161 162 msg = "Syntax error not raised for %s, %s" % (test_name, val)
162 163 with self.assertRaises(SyntaxError, msg=msg):
163 164 iprc(test_case.format(val=val))
164 165
165 166 def test_in_func_no_error(self):
166 167 # Test that the implementation of top-level return/yield
167 168 # detection isn't *too* aggressive, and works inside a function
168 169 func_contexts = []
169 170
170 171 func_contexts.append(('func', False, dedent("""
171 172 def f():""")))
172 173
173 174 func_contexts.append(('method', False, dedent("""
174 175 class MyClass:
175 176 def __init__(self):
176 177 """)))
177 178
178 179 func_contexts.append(('async-func', True, dedent("""
179 180 async def f():""")))
180 181
181 182 func_contexts.append(('async-method', True, dedent("""
182 183 class MyClass:
183 184 async def f(self):""")))
184 185
185 186 func_contexts.append(('closure', False, dedent("""
186 187 def f():
187 188 def g():
188 189 """)))
189 190
190 191 def nest_case(context, case):
191 192 # Detect indentation
192 193 lines = context.strip().splitlines()
193 194 prefix_len = 0
194 195 for c in lines[-1]:
195 196 if c != ' ':
196 197 break
197 198 prefix_len += 1
198 199
199 200 indented_case = indent(case, ' ' * (prefix_len + 4))
200 201 return context + '\n' + indented_case
201 202
202 203 # Gather and run the tests
203 204
204 205 # yield is allowed in async functions, starting in Python 3.6,
205 206 # and yield from is not allowed in any version
206 207 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
207 208 async_safe = (True,
208 209 True,
209 210 False)
210 211 vals = tuple(zip(vals, async_safe))
211 212
212 213 success_tests = zip(self._get_top_level_cases(), repeat(False))
213 214 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
214 215
215 216 tests = chain(success_tests, failure_tests)
216 217
217 218 for context_name, async_func, context in func_contexts:
218 219 for (test_name, test_case), should_fail in tests:
219 220 nested_case = nest_case(context, test_case)
220 221
221 222 for val, async_safe in vals:
222 223 val_should_fail = (should_fail or
223 224 (async_func and not async_safe))
224 225
225 226 test_id = (context_name, test_name, val)
226 227 cell = nested_case.format(val=val)
227 228
228 229 with self.subTest(test_id):
229 230 if val_should_fail:
230 231 msg = ("SyntaxError not raised for %s" %
231 232 str(test_id))
232 233 with self.assertRaises(SyntaxError, msg=msg):
233 234 iprc(cell)
234 235
235 236 print(cell)
236 237 else:
237 238 iprc(cell)
238 239
239 240 def test_nonlocal(self):
240 241 # fails if outer scope is not a function scope or if var not defined
241 242 with self.assertRaises(SyntaxError):
242 243 iprc("nonlocal x")
243 244 iprc("""
244 245 x = 1
245 246 def f():
246 247 nonlocal x
247 248 x = 10000
248 249 yield x
249 250 """)
250 251 iprc("""
251 252 def f():
252 253 def g():
253 254 nonlocal x
254 255 x = 10000
255 256 yield x
256 257 """)
257 258
258 259 # works if outer scope is a function scope and var exists
259 260 iprc("""
260 261 def f():
261 262 x = 20
262 263 def g():
263 264 nonlocal x
264 265 x = 10000
265 266 yield x
266 267 """)
267 268
268 269
269 270 def test_execute(self):
270 271 iprc("""
271 272 import asyncio
272 273 await asyncio.sleep(0.001)
273 274 """
274 275 )
275 276
276 277 def test_autoawait(self):
277 278 iprc("%autoawait False")
278 279 iprc("%autoawait True")
279 280 iprc("""
280 281 from asyncio import sleep
281 282 await sleep(0.1)
282 283 """
283 284 )
284 285
285 286 if sys.version_info < (3,9):
286 287 # new pgen parser in 3.9 does not raise MemoryError on too many nested
287 288 # parens anymore
288 289 def test_memory_error(self):
289 290 with self.assertRaises(MemoryError):
290 291 iprc("(" * 200 + ")" * 200)
291 292
292 293 @skip_without('curio')
293 294 def test_autoawait_curio(self):
294 295 iprc("%autoawait curio")
295 296
296 297 @skip_without('trio')
297 298 def test_autoawait_trio(self):
298 299 iprc("%autoawait trio")
299 300
300 301 @skip_without('trio')
301 302 def test_autoawait_trio_wrong_sleep(self):
302 303 iprc("%autoawait trio")
303 304 res = iprc_nr("""
304 305 import asyncio
305 306 await asyncio.sleep(0)
306 307 """)
307 308 with nt.assert_raises(TypeError):
308 309 res.raise_error()
309 310
310 311 @skip_without('trio')
311 312 def test_autoawait_asyncio_wrong_sleep(self):
312 313 iprc("%autoawait asyncio")
313 314 res = iprc_nr("""
314 315 import trio
315 316 await trio.sleep(0)
316 317 """)
317 318 with nt.assert_raises(RuntimeError):
318 319 res.raise_error()
319 320
320 321
321 322 def tearDown(self):
322 323 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now