##// END OF EJS Templates
Simplify assignment in test_async_helpers...
LeafyLi -
Show More
@@ -1,323 +1,322 b''
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6 from itertools import chain, repeat
7 7 import nose.tools as nt
8 8 from textwrap import dedent, indent
9 9 from unittest import TestCase
10 10 from IPython.testing.decorators import skip_without
11 11 import sys
12 12 from typing import TYPE_CHECKING
13 13
14 14 if TYPE_CHECKING:
15 15 from IPython import get_ipython
16 16
17 17 ip = get_ipython()
18 18
19 19
20 20 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
21 21 iprc_nr = lambda x: ip.run_cell(dedent(x))
22 22
23 23 from IPython.core.async_helpers import _should_be_async
24 24
25 25 class AsyncTest(TestCase):
26 26 def test_should_be_async(self):
27 27 nt.assert_false(_should_be_async("False"))
28 28 nt.assert_true(_should_be_async("await bar()"))
29 29 nt.assert_true(_should_be_async("x = await bar()"))
30 30 nt.assert_false(
31 31 _should_be_async(
32 32 dedent(
33 33 """
34 34 async def awaitable():
35 35 pass
36 36 """
37 37 )
38 38 )
39 39 )
40 40
41 41 def _get_top_level_cases(self):
42 42 # These are test cases that should be valid in a function
43 43 # but invalid outside of a function.
44 44 test_cases = []
45 45 test_cases.append(('basic', "{val}"))
46 46
47 47 # Note, in all conditional cases, I use True instead of
48 48 # False so that the peephole optimizer won't optimize away
49 49 # the return, so CPython will see this as a syntax error:
50 50 #
51 51 # while True:
52 52 # break
53 53 # return
54 54 #
55 55 # But not this:
56 56 #
57 57 # while False:
58 58 # return
59 59 #
60 60 # See https://bugs.python.org/issue1875
61 61
62 62 test_cases.append(('if', dedent("""
63 63 if True:
64 64 {val}
65 65 """)))
66 66
67 67 test_cases.append(('while', dedent("""
68 68 while True:
69 69 {val}
70 70 break
71 71 """)))
72 72
73 73 test_cases.append(('try', dedent("""
74 74 try:
75 75 {val}
76 76 except:
77 77 pass
78 78 """)))
79 79
80 80 test_cases.append(('except', dedent("""
81 81 try:
82 82 pass
83 83 except:
84 84 {val}
85 85 """)))
86 86
87 87 test_cases.append(('finally', dedent("""
88 88 try:
89 89 pass
90 90 except:
91 91 pass
92 92 finally:
93 93 {val}
94 94 """)))
95 95
96 96 test_cases.append(('for', dedent("""
97 97 for _ in range(4):
98 98 {val}
99 99 """)))
100 100
101 101
102 102 test_cases.append(('nested', dedent("""
103 103 if True:
104 104 while True:
105 105 {val}
106 106 break
107 107 """)))
108 108
109 109 test_cases.append(('deep-nested', dedent("""
110 110 if True:
111 111 while True:
112 112 break
113 113 for x in range(3):
114 114 if True:
115 115 while True:
116 116 for x in range(3):
117 117 {val}
118 118 """)))
119 119
120 120 return test_cases
121 121
122 122 def _get_ry_syntax_errors(self):
123 123 # This is a mix of tests that should be a syntax error if
124 124 # return or yield whether or not they are in a function
125 125
126 126 test_cases = []
127 127
128 128 test_cases.append(('class', dedent("""
129 129 class V:
130 130 {val}
131 131 """)))
132 132
133 133 test_cases.append(('nested-class', dedent("""
134 134 class V:
135 135 class C:
136 136 {val}
137 137 """)))
138 138
139 139 return test_cases
140 140
141 141
142 142 def test_top_level_return_error(self):
143 143 tl_err_test_cases = self._get_top_level_cases()
144 144 tl_err_test_cases.extend(self._get_ry_syntax_errors())
145 145
146 146 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
147 147 dedent('''
148 148 def f():
149 149 pass
150 150 return
151 151 '''),
152 152 )
153 153
154 154 for test_name, test_case in tl_err_test_cases:
155 155 # This example should work if 'pass' is used as the value
156 156 with self.subTest((test_name, 'pass')):
157 157 iprc(test_case.format(val='pass'))
158 158
159 159 # It should fail with all the values
160 160 for val in vals:
161 161 with self.subTest((test_name, val)):
162 162 msg = "Syntax error not raised for %s, %s" % (test_name, val)
163 163 with self.assertRaises(SyntaxError, msg=msg):
164 164 iprc(test_case.format(val=val))
165 165
166 166 def test_in_func_no_error(self):
167 167 # Test that the implementation of top-level return/yield
168 168 # detection isn't *too* aggressive, and works inside a function
169 169 func_contexts = []
170 170
171 171 func_contexts.append(('func', False, dedent("""
172 172 def f():""")))
173 173
174 174 func_contexts.append(('method', False, dedent("""
175 175 class MyClass:
176 176 def __init__(self):
177 177 """)))
178 178
179 179 func_contexts.append(('async-func', True, dedent("""
180 180 async def f():""")))
181 181
182 182 func_contexts.append(('async-method', True, dedent("""
183 183 class MyClass:
184 184 async def f(self):""")))
185 185
186 186 func_contexts.append(('closure', False, dedent("""
187 187 def f():
188 188 def g():
189 189 """)))
190 190
191 191 def nest_case(context, case):
192 192 # Detect indentation
193 193 lines = context.strip().splitlines()
194 194 prefix_len = 0
195 195 for c in lines[-1]:
196 196 if c != ' ':
197 197 break
198 198 prefix_len += 1
199 199
200 200 indented_case = indent(case, ' ' * (prefix_len + 4))
201 201 return context + '\n' + indented_case
202 202
203 203 # Gather and run the tests
204 204
205 205 # yield is allowed in async functions, starting in Python 3.6,
206 206 # and yield from is not allowed in any version
207 207 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
208 208 async_safe = (True,
209 209 True,
210 210 False)
211 211 vals = tuple(zip(vals, async_safe))
212 212
213 213 success_tests = zip(self._get_top_level_cases(), repeat(False))
214 214 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
215 215
216 216 tests = chain(success_tests, failure_tests)
217 217
218 218 for context_name, async_func, context in func_contexts:
219 219 for (test_name, test_case), should_fail in tests:
220 220 nested_case = nest_case(context, test_case)
221 221
222 222 for val, async_safe in vals:
223 val_should_fail = (should_fail or
224 (async_func and not async_safe))
223 val_should_fail = should_fail
225 224
226 225 test_id = (context_name, test_name, val)
227 226 cell = nested_case.format(val=val)
228 227
229 228 with self.subTest(test_id):
230 229 if val_should_fail:
231 230 msg = ("SyntaxError not raised for %s" %
232 231 str(test_id))
233 232 with self.assertRaises(SyntaxError, msg=msg):
234 233 iprc(cell)
235 234
236 235 print(cell)
237 236 else:
238 237 iprc(cell)
239 238
240 239 def test_nonlocal(self):
241 240 # fails if outer scope is not a function scope or if var not defined
242 241 with self.assertRaises(SyntaxError):
243 242 iprc("nonlocal x")
244 243 iprc("""
245 244 x = 1
246 245 def f():
247 246 nonlocal x
248 247 x = 10000
249 248 yield x
250 249 """)
251 250 iprc("""
252 251 def f():
253 252 def g():
254 253 nonlocal x
255 254 x = 10000
256 255 yield x
257 256 """)
258 257
259 258 # works if outer scope is a function scope and var exists
260 259 iprc("""
261 260 def f():
262 261 x = 20
263 262 def g():
264 263 nonlocal x
265 264 x = 10000
266 265 yield x
267 266 """)
268 267
269 268
270 269 def test_execute(self):
271 270 iprc("""
272 271 import asyncio
273 272 await asyncio.sleep(0.001)
274 273 """
275 274 )
276 275
277 276 def test_autoawait(self):
278 277 iprc("%autoawait False")
279 278 iprc("%autoawait True")
280 279 iprc("""
281 280 from asyncio import sleep
282 281 await sleep(0.1)
283 282 """
284 283 )
285 284
286 285 if sys.version_info < (3,9):
287 286 # new pgen parser in 3.9 does not raise MemoryError on too many nested
288 287 # parens anymore
289 288 def test_memory_error(self):
290 289 with self.assertRaises(MemoryError):
291 290 iprc("(" * 200 + ")" * 200)
292 291
293 292 @skip_without('curio')
294 293 def test_autoawait_curio(self):
295 294 iprc("%autoawait curio")
296 295
297 296 @skip_without('trio')
298 297 def test_autoawait_trio(self):
299 298 iprc("%autoawait trio")
300 299
301 300 @skip_without('trio')
302 301 def test_autoawait_trio_wrong_sleep(self):
303 302 iprc("%autoawait trio")
304 303 res = iprc_nr("""
305 304 import asyncio
306 305 await asyncio.sleep(0)
307 306 """)
308 307 with nt.assert_raises(TypeError):
309 308 res.raise_error()
310 309
311 310 @skip_without('trio')
312 311 def test_autoawait_asyncio_wrong_sleep(self):
313 312 iprc("%autoawait asyncio")
314 313 res = iprc_nr("""
315 314 import trio
316 315 await trio.sleep(0)
317 316 """)
318 317 with nt.assert_raises(RuntimeError):
319 318 res.raise_error()
320 319
321 320
322 321 def tearDown(self):
323 322 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now