##// END OF EJS Templates
`async_safe` cleanup
BaoGiang HoangVu -
Show More
@@ -1,322 +1,316 b''
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6 from itertools import chain, repeat
7 7 import nose.tools as nt
8 8 from textwrap import dedent, indent
9 9 from unittest import TestCase
10 10 from IPython.testing.decorators import skip_without
11 11 import sys
12 12 from typing import TYPE_CHECKING
13 13
14 14 if TYPE_CHECKING:
15 15 from IPython import get_ipython
16 16
17 17 ip = get_ipython()
18 18
19 19
20 20 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
21 21 iprc_nr = lambda x: ip.run_cell(dedent(x))
22 22
23 23 from IPython.core.async_helpers import _should_be_async
24 24
25 25 class AsyncTest(TestCase):
26 26 def test_should_be_async(self):
27 27 nt.assert_false(_should_be_async("False"))
28 28 nt.assert_true(_should_be_async("await bar()"))
29 29 nt.assert_true(_should_be_async("x = await bar()"))
30 30 nt.assert_false(
31 31 _should_be_async(
32 32 dedent(
33 33 """
34 34 async def awaitable():
35 35 pass
36 36 """
37 37 )
38 38 )
39 39 )
40 40
41 41 def _get_top_level_cases(self):
42 42 # These are test cases that should be valid in a function
43 43 # but invalid outside of a function.
44 44 test_cases = []
45 45 test_cases.append(('basic', "{val}"))
46 46
47 47 # Note, in all conditional cases, I use True instead of
48 48 # False so that the peephole optimizer won't optimize away
49 49 # the return, so CPython will see this as a syntax error:
50 50 #
51 51 # while True:
52 52 # break
53 53 # return
54 54 #
55 55 # But not this:
56 56 #
57 57 # while False:
58 58 # return
59 59 #
60 60 # See https://bugs.python.org/issue1875
61 61
62 62 test_cases.append(('if', dedent("""
63 63 if True:
64 64 {val}
65 65 """)))
66 66
67 67 test_cases.append(('while', dedent("""
68 68 while True:
69 69 {val}
70 70 break
71 71 """)))
72 72
73 73 test_cases.append(('try', dedent("""
74 74 try:
75 75 {val}
76 76 except:
77 77 pass
78 78 """)))
79 79
80 80 test_cases.append(('except', dedent("""
81 81 try:
82 82 pass
83 83 except:
84 84 {val}
85 85 """)))
86 86
87 87 test_cases.append(('finally', dedent("""
88 88 try:
89 89 pass
90 90 except:
91 91 pass
92 92 finally:
93 93 {val}
94 94 """)))
95 95
96 96 test_cases.append(('for', dedent("""
97 97 for _ in range(4):
98 98 {val}
99 99 """)))
100 100
101 101
102 102 test_cases.append(('nested', dedent("""
103 103 if True:
104 104 while True:
105 105 {val}
106 106 break
107 107 """)))
108 108
109 109 test_cases.append(('deep-nested', dedent("""
110 110 if True:
111 111 while True:
112 112 break
113 113 for x in range(3):
114 114 if True:
115 115 while True:
116 116 for x in range(3):
117 117 {val}
118 118 """)))
119 119
120 120 return test_cases
121 121
122 122 def _get_ry_syntax_errors(self):
123 123 # This is a mix of tests that should be a syntax error if
124 124 # return or yield whether or not they are in a function
125 125
126 126 test_cases = []
127 127
128 128 test_cases.append(('class', dedent("""
129 129 class V:
130 130 {val}
131 131 """)))
132 132
133 133 test_cases.append(('nested-class', dedent("""
134 134 class V:
135 135 class C:
136 136 {val}
137 137 """)))
138 138
139 139 return test_cases
140 140
141 141
142 142 def test_top_level_return_error(self):
143 143 tl_err_test_cases = self._get_top_level_cases()
144 144 tl_err_test_cases.extend(self._get_ry_syntax_errors())
145 145
146 146 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
147 147 dedent('''
148 148 def f():
149 149 pass
150 150 return
151 151 '''),
152 152 )
153 153
154 154 for test_name, test_case in tl_err_test_cases:
155 155 # This example should work if 'pass' is used as the value
156 156 with self.subTest((test_name, 'pass')):
157 157 iprc(test_case.format(val='pass'))
158 158
159 159 # It should fail with all the values
160 160 for val in vals:
161 161 with self.subTest((test_name, val)):
162 162 msg = "Syntax error not raised for %s, %s" % (test_name, val)
163 163 with self.assertRaises(SyntaxError, msg=msg):
164 164 iprc(test_case.format(val=val))
165 165
166 166 def test_in_func_no_error(self):
167 167 # Test that the implementation of top-level return/yield
168 168 # detection isn't *too* aggressive, and works inside a function
169 169 func_contexts = []
170 170
171 171 func_contexts.append(('func', False, dedent("""
172 172 def f():""")))
173 173
174 174 func_contexts.append(('method', False, dedent("""
175 175 class MyClass:
176 176 def __init__(self):
177 177 """)))
178 178
179 179 func_contexts.append(('async-func', True, dedent("""
180 180 async def f():""")))
181 181
182 182 func_contexts.append(('async-method', True, dedent("""
183 183 class MyClass:
184 184 async def f(self):""")))
185 185
186 186 func_contexts.append(('closure', False, dedent("""
187 187 def f():
188 188 def g():
189 189 """)))
190 190
191 191 def nest_case(context, case):
192 192 # Detect indentation
193 193 lines = context.strip().splitlines()
194 194 prefix_len = 0
195 195 for c in lines[-1]:
196 196 if c != ' ':
197 197 break
198 198 prefix_len += 1
199 199
200 200 indented_case = indent(case, ' ' * (prefix_len + 4))
201 201 return context + '\n' + indented_case
202 202
203 203 # Gather and run the tests
204 204
205 205 # yield is allowed in async functions, starting in Python 3.6,
206 206 # and yield from is not allowed in any version
207 207 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
208 async_safe = (True,
209 True,
210 False)
211 vals = tuple(zip(vals, async_safe))
212 208
213 209 success_tests = zip(self._get_top_level_cases(), repeat(False))
214 210 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
215 211
216 212 tests = chain(success_tests, failure_tests)
217 213
218 214 for context_name, async_func, context in func_contexts:
219 215 for (test_name, test_case), should_fail in tests:
220 216 nested_case = nest_case(context, test_case)
221 217
222 for val, async_safe in vals:
223 val_should_fail = should_fail
224
218 for val in vals:
225 219 test_id = (context_name, test_name, val)
226 220 cell = nested_case.format(val=val)
227 221
228 222 with self.subTest(test_id):
229 if val_should_fail:
223 if should_fail:
230 224 msg = ("SyntaxError not raised for %s" %
231 225 str(test_id))
232 226 with self.assertRaises(SyntaxError, msg=msg):
233 227 iprc(cell)
234 228
235 229 print(cell)
236 230 else:
237 231 iprc(cell)
238 232
239 233 def test_nonlocal(self):
240 234 # fails if outer scope is not a function scope or if var not defined
241 235 with self.assertRaises(SyntaxError):
242 236 iprc("nonlocal x")
243 237 iprc("""
244 238 x = 1
245 239 def f():
246 240 nonlocal x
247 241 x = 10000
248 242 yield x
249 243 """)
250 244 iprc("""
251 245 def f():
252 246 def g():
253 247 nonlocal x
254 248 x = 10000
255 249 yield x
256 250 """)
257 251
258 252 # works if outer scope is a function scope and var exists
259 253 iprc("""
260 254 def f():
261 255 x = 20
262 256 def g():
263 257 nonlocal x
264 258 x = 10000
265 259 yield x
266 260 """)
267 261
268 262
269 263 def test_execute(self):
270 264 iprc("""
271 265 import asyncio
272 266 await asyncio.sleep(0.001)
273 267 """
274 268 )
275 269
276 270 def test_autoawait(self):
277 271 iprc("%autoawait False")
278 272 iprc("%autoawait True")
279 273 iprc("""
280 274 from asyncio import sleep
281 275 await sleep(0.1)
282 276 """
283 277 )
284 278
285 279 if sys.version_info < (3,9):
286 280 # new pgen parser in 3.9 does not raise MemoryError on too many nested
287 281 # parens anymore
288 282 def test_memory_error(self):
289 283 with self.assertRaises(MemoryError):
290 284 iprc("(" * 200 + ")" * 200)
291 285
292 286 @skip_without('curio')
293 287 def test_autoawait_curio(self):
294 288 iprc("%autoawait curio")
295 289
296 290 @skip_without('trio')
297 291 def test_autoawait_trio(self):
298 292 iprc("%autoawait trio")
299 293
300 294 @skip_without('trio')
301 295 def test_autoawait_trio_wrong_sleep(self):
302 296 iprc("%autoawait trio")
303 297 res = iprc_nr("""
304 298 import asyncio
305 299 await asyncio.sleep(0)
306 300 """)
307 301 with nt.assert_raises(TypeError):
308 302 res.raise_error()
309 303
310 304 @skip_without('trio')
311 305 def test_autoawait_asyncio_wrong_sleep(self):
312 306 iprc("%autoawait asyncio")
313 307 res = iprc_nr("""
314 308 import trio
315 309 await trio.sleep(0)
316 310 """)
317 311 with nt.assert_raises(RuntimeError):
318 312 res.raise_error()
319 313
320 314
321 315 def tearDown(self):
322 316 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now