##// END OF EJS Templates
Merge pull request #11641 from Carreau/fix-return-syntax...
Matthias Bussonnier -
r24966:e3f9954e merge
parent child Browse files
Show More
@@ -1,165 +1,166
1 1 """
2 2 Async helper function that are invalid syntax on Python 3.5 and below.
3 3
4 4 This code is best effort, and may have edge cases not behaving as expected. In
5 5 particular it contain a number of heuristics to detect whether code is
6 6 effectively async and need to run in an event loop or not.
7 7
8 8 Some constructs (like top-level `return`, or `yield`) are taken care of
9 9 explicitly to actually raise a SyntaxError and stay as close as possible to
10 10 Python semantics.
11 11 """
12 12
13 13
14 14 import ast
15 15 import sys
16 16 from textwrap import dedent, indent
17 17
18 18
19 19 class _AsyncIORunner:
20 20
21 21 def __call__(self, coro):
22 22 """
23 23 Handler for asyncio autoawait
24 24 """
25 25 import asyncio
26 26
27 27 return asyncio.get_event_loop().run_until_complete(coro)
28 28
29 29 def __str__(self):
30 30 return 'asyncio'
31 31
32 32 _asyncio_runner = _AsyncIORunner()
33 33
34 34
35 35 def _curio_runner(coroutine):
36 36 """
37 37 handler for curio autoawait
38 38 """
39 39 import curio
40 40
41 41 return curio.run(coroutine)
42 42
43 43
44 44 def _trio_runner(async_fn):
45 45 import trio
46 46
47 47 async def loc(coro):
48 48 """
49 49 We need the dummy no-op async def to protect from
50 50 trio's internal. See https://github.com/python-trio/trio/issues/89
51 51 """
52 52 return await coro
53 53
54 54 return trio.run(loc, async_fn)
55 55
56 56
57 57 def _pseudo_sync_runner(coro):
58 58 """
59 59 A runner that does not really allow async execution, and just advance the coroutine.
60 60
61 61 See discussion in https://github.com/python-trio/trio/issues/608,
62 62
63 63 Credit to Nathaniel Smith
64 64
65 65 """
66 66 try:
67 67 coro.send(None)
68 68 except StopIteration as exc:
69 69 return exc.value
70 70 else:
71 71 # TODO: do not raise but return an execution result with the right info.
72 72 raise RuntimeError(
73 73 "{coro_name!r} needs a real async loop".format(coro_name=coro.__name__)
74 74 )
75 75
76 76
77 77 def _asyncify(code: str) -> str:
78 78 """wrap code in async def definition.
79 79
80 80 And setup a bit of context to run it later.
81 81 """
82 82 res = dedent(
83 83 """
84 84 async def __wrapper__():
85 85 try:
86 86 {usercode}
87 87 finally:
88 88 locals()
89 89 """
90 90 ).format(usercode=indent(code, " " * 8))
91 91 return res
92 92
93 93
94 94 class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
95 95 """
96 96 Find syntax errors that would be an error in an async repl, but because
97 97 the implementation involves wrapping the repl in an async function, it
98 98 is erroneously allowed (e.g. yield or return at the top level)
99 99 """
100 100 def __init__(self):
101 101 self.depth = 0
102 102 super().__init__()
103 103
104 104 def generic_visit(self, node):
105 105 func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
106 106 invalid_types_by_depth = {
107 107 0: (ast.Return, ast.Yield, ast.YieldFrom),
108 108 1: (ast.Nonlocal,)
109 109 }
110 110
111 111 should_traverse = self.depth < max(invalid_types_by_depth.keys())
112 112 if isinstance(node, func_types) and should_traverse:
113 113 self.depth += 1
114 114 super().generic_visit(node)
115 self.depth -= 1
115 116 elif isinstance(node, invalid_types_by_depth[self.depth]):
116 117 raise SyntaxError()
117 118 else:
118 119 super().generic_visit(node)
119 120
120 121
121 122 def _async_parse_cell(cell: str) -> ast.AST:
122 123 """
123 124 This is a compatibility shim for pre-3.7 when async outside of a function
124 125 is a syntax error at the parse stage.
125 126
126 127 It will return an abstract syntax tree parsed as if async and await outside
127 128 of a function were not a syntax error.
128 129 """
129 130 if sys.version_info < (3, 7):
130 131 # Prior to 3.7 you need to asyncify before parse
131 132 wrapped_parse_tree = ast.parse(_asyncify(cell))
132 133 return wrapped_parse_tree.body[0].body[0]
133 134 else:
134 135 return ast.parse(cell)
135 136
136 137
137 138 def _should_be_async(cell: str) -> bool:
138 139 """Detect if a block of code need to be wrapped in an `async def`
139 140
140 141 Attempt to parse the block of code, it it compile we're fine.
141 142 Otherwise we wrap if and try to compile.
142 143
143 144 If it works, assume it should be async. Otherwise Return False.
144 145
145 146 Not handled yet: If the block of code has a return statement as the top
146 147 level, it will be seen as async. This is a know limitation.
147 148 """
148 149
149 150 try:
150 151 # we can't limit ourself to ast.parse, as it __accepts__ to parse on
151 152 # 3.7+, but just does not _compile_
152 153 compile(cell, "<>", "exec")
153 154 return False
154 155 except SyntaxError:
155 156 try:
156 157 parse_tree = _async_parse_cell(cell)
157 158
158 159 # Raise a SyntaxError if there are top-level return or yields
159 160 v = _AsyncSyntaxErrorVisitor()
160 161 v.visit(parse_tree)
161 162
162 163 except SyntaxError:
163 164 return False
164 165 return True
165 166 return False
@@ -1,306 +1,312
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6
7 7 import sys
8 8 from itertools import chain, repeat
9 9 import nose.tools as nt
10 10 from textwrap import dedent, indent
11 11 from unittest import TestCase
12 12 from IPython.testing.decorators import skip_without
13 13
14 14 ip = get_ipython()
15 15 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
16 16 iprc_nr = lambda x: ip.run_cell(dedent(x))
17 17
18 18 if sys.version_info > (3, 5):
19 19 from IPython.core.async_helpers import _should_be_async
20 20
21 21 class AsyncTest(TestCase):
22 22 def test_should_be_async(self):
23 23 nt.assert_false(_should_be_async("False"))
24 24 nt.assert_true(_should_be_async("await bar()"))
25 25 nt.assert_true(_should_be_async("x = await bar()"))
26 26 nt.assert_false(
27 27 _should_be_async(
28 28 dedent(
29 29 """
30 30 async def awaitable():
31 31 pass
32 32 """
33 33 )
34 34 )
35 35 )
36 36
37 37 def _get_top_level_cases(self):
38 38 # These are test cases that should be valid in a function
39 39 # but invalid outside of a function.
40 40 test_cases = []
41 41 test_cases.append(('basic', "{val}"))
42 42
43 43 # Note, in all conditional cases, I use True instead of
44 44 # False so that the peephole optimizer won't optimize away
45 45 # the return, so CPython will see this as a syntax error:
46 46 #
47 47 # while True:
48 48 # break
49 49 # return
50 50 #
51 51 # But not this:
52 52 #
53 53 # while False:
54 54 # return
55 55 #
56 56 # See https://bugs.python.org/issue1875
57 57
58 58 test_cases.append(('if', dedent("""
59 59 if True:
60 60 {val}
61 61 """)))
62 62
63 63 test_cases.append(('while', dedent("""
64 64 while True:
65 65 {val}
66 66 break
67 67 """)))
68 68
69 69 test_cases.append(('try', dedent("""
70 70 try:
71 71 {val}
72 72 except:
73 73 pass
74 74 """)))
75 75
76 76 test_cases.append(('except', dedent("""
77 77 try:
78 78 pass
79 79 except:
80 80 {val}
81 81 """)))
82 82
83 83 test_cases.append(('finally', dedent("""
84 84 try:
85 85 pass
86 86 except:
87 87 pass
88 88 finally:
89 89 {val}
90 90 """)))
91 91
92 92 test_cases.append(('for', dedent("""
93 93 for _ in range(4):
94 94 {val}
95 95 """)))
96 96
97 97
98 98 test_cases.append(('nested', dedent("""
99 99 if True:
100 100 while True:
101 101 {val}
102 102 break
103 103 """)))
104 104
105 105 test_cases.append(('deep-nested', dedent("""
106 106 if True:
107 107 while True:
108 108 break
109 109 for x in range(3):
110 110 if True:
111 111 while True:
112 112 for x in range(3):
113 113 {val}
114 114 """)))
115 115
116 116 return test_cases
117 117
118 118 def _get_ry_syntax_errors(self):
119 119 # This is a mix of tests that should be a syntax error if
120 120 # return or yield whether or not they are in a function
121 121
122 122 test_cases = []
123 123
124 124 test_cases.append(('class', dedent("""
125 125 class V:
126 126 {val}
127 127 """)))
128 128
129 129 test_cases.append(('nested-class', dedent("""
130 130 class V:
131 131 class C:
132 132 {val}
133 133 """)))
134 134
135 135 return test_cases
136 136
137 137
138 138 def test_top_level_return_error(self):
139 139 tl_err_test_cases = self._get_top_level_cases()
140 140 tl_err_test_cases.extend(self._get_ry_syntax_errors())
141 141
142 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
142 vals = ('return', 'yield', 'yield from (_ for _ in range(3))',
143 dedent('''
144 def f():
145 pass
146 return
147 '''),
148 )
143 149
144 150 for test_name, test_case in tl_err_test_cases:
145 151 # This example should work if 'pass' is used as the value
146 152 with self.subTest((test_name, 'pass')):
147 153 iprc(test_case.format(val='pass'))
148 154
149 155 # It should fail with all the values
150 156 for val in vals:
151 157 with self.subTest((test_name, val)):
152 158 msg = "Syntax error not raised for %s, %s" % (test_name, val)
153 159 with self.assertRaises(SyntaxError, msg=msg):
154 160 iprc(test_case.format(val=val))
155 161
156 162 def test_in_func_no_error(self):
157 163 # Test that the implementation of top-level return/yield
158 164 # detection isn't *too* aggressive, and works inside a function
159 165 func_contexts = []
160 166
161 167 func_contexts.append(('func', False, dedent("""
162 168 def f():""")))
163 169
164 170 func_contexts.append(('method', False, dedent("""
165 171 class MyClass:
166 172 def __init__(self):
167 173 """)))
168 174
169 175 func_contexts.append(('async-func', True, dedent("""
170 176 async def f():""")))
171 177
172 178 func_contexts.append(('async-method', True, dedent("""
173 179 class MyClass:
174 180 async def f(self):""")))
175 181
176 182 func_contexts.append(('closure', False, dedent("""
177 183 def f():
178 184 def g():
179 185 """)))
180 186
181 187 def nest_case(context, case):
182 188 # Detect indentation
183 189 lines = context.strip().splitlines()
184 190 prefix_len = 0
185 191 for c in lines[-1]:
186 192 if c != ' ':
187 193 break
188 194 prefix_len += 1
189 195
190 196 indented_case = indent(case, ' ' * (prefix_len + 4))
191 197 return context + '\n' + indented_case
192 198
193 199 # Gather and run the tests
194 200
195 201 # yield is allowed in async functions, starting in Python 3.6,
196 202 # and yield from is not allowed in any version
197 203 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
198 204 async_safe = (True,
199 205 sys.version_info >= (3, 6),
200 206 False)
201 207 vals = tuple(zip(vals, async_safe))
202 208
203 209 success_tests = zip(self._get_top_level_cases(), repeat(False))
204 210 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
205 211
206 212 tests = chain(success_tests, failure_tests)
207 213
208 214 for context_name, async_func, context in func_contexts:
209 215 for (test_name, test_case), should_fail in tests:
210 216 nested_case = nest_case(context, test_case)
211 217
212 218 for val, async_safe in vals:
213 219 val_should_fail = (should_fail or
214 220 (async_func and not async_safe))
215 221
216 222 test_id = (context_name, test_name, val)
217 223 cell = nested_case.format(val=val)
218 224
219 225 with self.subTest(test_id):
220 226 if val_should_fail:
221 227 msg = ("SyntaxError not raised for %s" %
222 228 str(test_id))
223 229 with self.assertRaises(SyntaxError, msg=msg):
224 230 iprc(cell)
225 231
226 232 print(cell)
227 233 else:
228 234 iprc(cell)
229 235
230 236 def test_nonlocal(self):
231 237 # fails if outer scope is not a function scope or if var not defined
232 238 with self.assertRaises(SyntaxError):
233 239 iprc("nonlocal x")
234 240 iprc("""
235 241 x = 1
236 242 def f():
237 243 nonlocal x
238 244 x = 10000
239 245 yield x
240 246 """)
241 247 iprc("""
242 248 def f():
243 249 def g():
244 250 nonlocal x
245 251 x = 10000
246 252 yield x
247 253 """)
248 254
249 255 # works if outer scope is a function scope and var exists
250 256 iprc("""
251 257 def f():
252 258 x = 20
253 259 def g():
254 260 nonlocal x
255 261 x = 10000
256 262 yield x
257 263 """)
258 264
259 265
260 266 def test_execute(self):
261 267 iprc("""
262 268 import asyncio
263 269 await asyncio.sleep(0.001)
264 270 """
265 271 )
266 272
267 273 def test_autoawait(self):
268 274 iprc("%autoawait False")
269 275 iprc("%autoawait True")
270 276 iprc("""
271 277 from asyncio import sleep
272 278 await sleep(0.1)
273 279 """
274 280 )
275 281
276 282 @skip_without('curio')
277 283 def test_autoawait_curio(self):
278 284 iprc("%autoawait curio")
279 285
280 286 @skip_without('trio')
281 287 def test_autoawait_trio(self):
282 288 iprc("%autoawait trio")
283 289
284 290 @skip_without('trio')
285 291 def test_autoawait_trio_wrong_sleep(self):
286 292 iprc("%autoawait trio")
287 293 res = iprc_nr("""
288 294 import asyncio
289 295 await asyncio.sleep(0)
290 296 """)
291 297 with nt.assert_raises(TypeError):
292 298 res.raise_error()
293 299
294 300 @skip_without('trio')
295 301 def test_autoawait_asyncio_wrong_sleep(self):
296 302 iprc("%autoawait asyncio")
297 303 res = iprc_nr("""
298 304 import trio
299 305 await trio.sleep(0)
300 306 """)
301 307 with nt.assert_raises(RuntimeError):
302 308 res.raise_error()
303 309
304 310
305 311 def tearDown(self):
306 312 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now