##// END OF EJS Templates
BUG: Fix mypy error...
farisachugthai -
Show More
@@ -1,316 +1,322 b''
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6 from itertools import chain, repeat
7 7 import nose.tools as nt
8 8 from textwrap import dedent, indent
9 9 from unittest import TestCase
10 10 from IPython.testing.decorators import skip_without
11 11 import sys
12 from typing import TYPE_CHECKING
13
14 if TYPE_CHECKING:
15 from IPython import get_ipython
16 ip = get_ipython()
17
12 18
13 19 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
14 20 iprc_nr = lambda x: ip.run_cell(dedent(x))
15 21
16 22 from IPython.core.async_helpers import _should_be_async
17 23
18 24 class AsyncTest(TestCase):
19 25 def test_should_be_async(self):
20 26 nt.assert_false(_should_be_async("False"))
21 27 nt.assert_true(_should_be_async("await bar()"))
22 28 nt.assert_true(_should_be_async("x = await bar()"))
23 29 nt.assert_false(
24 30 _should_be_async(
25 31 dedent(
26 32 """
27 33 async def awaitable():
28 34 pass
29 35 """
30 36 )
31 37 )
32 38 )
33 39
34 40 def _get_top_level_cases(self):
35 41 # These are test cases that should be valid in a function
36 42 # but invalid outside of a function.
37 43 test_cases = []
38 44 test_cases.append(('basic', "{val}"))
39 45
40 46 # Note, in all conditional cases, I use True instead of
41 47 # False so that the peephole optimizer won't optimize away
42 48 # the return, so CPython will see this as a syntax error:
43 49 #
44 50 # while True:
45 51 # break
46 52 # return
47 53 #
48 54 # But not this:
49 55 #
50 56 # while False:
51 57 # return
52 58 #
53 59 # See https://bugs.python.org/issue1875
54 60
55 61 test_cases.append(('if', dedent("""
56 62 if True:
57 63 {val}
58 64 """)))
59 65
60 66 test_cases.append(('while', dedent("""
61 67 while True:
62 68 {val}
63 69 break
64 70 """)))
65 71
66 72 test_cases.append(('try', dedent("""
67 73 try:
68 74 {val}
69 75 except:
70 76 pass
71 77 """)))
72 78
73 79 test_cases.append(('except', dedent("""
74 80 try:
75 81 pass
76 82 except:
77 83 {val}
78 84 """)))
79 85
80 86 test_cases.append(('finally', dedent("""
81 87 try:
82 88 pass
83 89 except:
84 90 pass
85 91 finally:
86 92 {val}
87 93 """)))
88 94
89 95 test_cases.append(('for', dedent("""
90 96 for _ in range(4):
91 97 {val}
92 98 """)))
93 99
94 100
95 101 test_cases.append(('nested', dedent("""
96 102 if True:
97 103 while True:
98 104 {val}
99 105 break
100 106 """)))
101 107
102 108 test_cases.append(('deep-nested', dedent("""
103 109 if True:
104 110 while True:
105 111 break
106 112 for x in range(3):
107 113 if True:
108 114 while True:
109 115 for x in range(3):
110 116 {val}
111 117 """)))
112 118
113 119 return test_cases
114 120
115 121 def _get_ry_syntax_errors(self):
116 122 # This is a mix of tests that should be a syntax error if
117 123 # return or yield whether or not they are in a function
118 124
119 125 test_cases = []
120 126
121 127 test_cases.append(('class', dedent("""
122 128 class V:
123 129 {val}
124 130 """)))
125 131
126 132 test_cases.append(('nested-class', dedent("""
127 133 class V:
128 134 class C:
129 135 {val}
130 136 """)))
131 137
132 138 return test_cases
133 139
134 140
135 141 def test_top_level_return_error(self):
136 142 tl_err_test_cases = self._get_top_level_cases()
137 143 tl_err_test_cases.extend(self._get_ry_syntax_errors())
138 144
139 145 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
140 146 dedent('''
141 147 def f():
142 148 pass
143 149 return
144 150 '''),
145 151 )
146 152
147 153 for test_name, test_case in tl_err_test_cases:
148 154 # This example should work if 'pass' is used as the value
149 155 with self.subTest((test_name, 'pass')):
150 156 iprc(test_case.format(val='pass'))
151 157
152 158 # It should fail with all the values
153 159 for val in vals:
154 160 with self.subTest((test_name, val)):
155 161 msg = "Syntax error not raised for %s, %s" % (test_name, val)
156 162 with self.assertRaises(SyntaxError, msg=msg):
157 163 iprc(test_case.format(val=val))
158 164
159 165 def test_in_func_no_error(self):
160 166 # Test that the implementation of top-level return/yield
161 167 # detection isn't *too* aggressive, and works inside a function
162 168 func_contexts = []
163 169
164 170 func_contexts.append(('func', False, dedent("""
165 171 def f():""")))
166 172
167 173 func_contexts.append(('method', False, dedent("""
168 174 class MyClass:
169 175 def __init__(self):
170 176 """)))
171 177
172 178 func_contexts.append(('async-func', True, dedent("""
173 179 async def f():""")))
174 180
175 181 func_contexts.append(('async-method', True, dedent("""
176 182 class MyClass:
177 183 async def f(self):""")))
178 184
179 185 func_contexts.append(('closure', False, dedent("""
180 186 def f():
181 187 def g():
182 188 """)))
183 189
184 190 def nest_case(context, case):
185 191 # Detect indentation
186 192 lines = context.strip().splitlines()
187 193 prefix_len = 0
188 194 for c in lines[-1]:
189 195 if c != ' ':
190 196 break
191 197 prefix_len += 1
192 198
193 199 indented_case = indent(case, ' ' * (prefix_len + 4))
194 200 return context + '\n' + indented_case
195 201
196 202 # Gather and run the tests
197 203
198 204 # yield is allowed in async functions, starting in Python 3.6,
199 205 # and yield from is not allowed in any version
200 206 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
201 207 async_safe = (True,
202 208 True,
203 209 False)
204 210 vals = tuple(zip(vals, async_safe))
205 211
206 212 success_tests = zip(self._get_top_level_cases(), repeat(False))
207 213 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
208 214
209 215 tests = chain(success_tests, failure_tests)
210 216
211 217 for context_name, async_func, context in func_contexts:
212 218 for (test_name, test_case), should_fail in tests:
213 219 nested_case = nest_case(context, test_case)
214 220
215 221 for val, async_safe in vals:
216 222 val_should_fail = (should_fail or
217 223 (async_func and not async_safe))
218 224
219 225 test_id = (context_name, test_name, val)
220 226 cell = nested_case.format(val=val)
221 227
222 228 with self.subTest(test_id):
223 229 if val_should_fail:
224 230 msg = ("SyntaxError not raised for %s" %
225 231 str(test_id))
226 232 with self.assertRaises(SyntaxError, msg=msg):
227 233 iprc(cell)
228 234
229 235 print(cell)
230 236 else:
231 237 iprc(cell)
232 238
233 239 def test_nonlocal(self):
234 240 # fails if outer scope is not a function scope or if var not defined
235 241 with self.assertRaises(SyntaxError):
236 242 iprc("nonlocal x")
237 243 iprc("""
238 244 x = 1
239 245 def f():
240 246 nonlocal x
241 247 x = 10000
242 248 yield x
243 249 """)
244 250 iprc("""
245 251 def f():
246 252 def g():
247 253 nonlocal x
248 254 x = 10000
249 255 yield x
250 256 """)
251 257
252 258 # works if outer scope is a function scope and var exists
253 259 iprc("""
254 260 def f():
255 261 x = 20
256 262 def g():
257 263 nonlocal x
258 264 x = 10000
259 265 yield x
260 266 """)
261 267
262 268
263 269 def test_execute(self):
264 270 iprc("""
265 271 import asyncio
266 272 await asyncio.sleep(0.001)
267 273 """
268 274 )
269 275
270 276 def test_autoawait(self):
271 277 iprc("%autoawait False")
272 278 iprc("%autoawait True")
273 279 iprc("""
274 280 from asyncio import sleep
275 281 await sleep(0.1)
276 282 """
277 283 )
278 284
279 285 if sys.version_info < (3,9):
280 286 # new pgen parser in 3.9 does not raise MemoryError on too many nested
281 287 # parens anymore
282 288 def test_memory_error(self):
283 289 with self.assertRaises(MemoryError):
284 290 iprc("(" * 200 + ")" * 200)
285 291
286 292 @skip_without('curio')
287 293 def test_autoawait_curio(self):
288 294 iprc("%autoawait curio")
289 295
290 296 @skip_without('trio')
291 297 def test_autoawait_trio(self):
292 298 iprc("%autoawait trio")
293 299
294 300 @skip_without('trio')
295 301 def test_autoawait_trio_wrong_sleep(self):
296 302 iprc("%autoawait trio")
297 303 res = iprc_nr("""
298 304 import asyncio
299 305 await asyncio.sleep(0)
300 306 """)
301 307 with nt.assert_raises(TypeError):
302 308 res.raise_error()
303 309
304 310 @skip_without('trio')
305 311 def test_autoawait_asyncio_wrong_sleep(self):
306 312 iprc("%autoawait asyncio")
307 313 res = iprc_nr("""
308 314 import trio
309 315 await trio.sleep(0)
310 316 """)
311 317 with nt.assert_raises(RuntimeError):
312 318 res.raise_error()
313 319
314 320
315 321 def tearDown(self):
316 322 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now