##// END OF EJS Templates
Merge pull request #11382 from felixzhuologist/master...
Min RK -
r24668:940dd607 merge
parent child Browse files
Show More
@@ -1,157 +1,165 b''
1 1 """
2 2 Async helper function that are invalid syntax on Python 3.5 and below.
3 3
4 4 This code is best effort, and may have edge cases not behaving as expected. In
5 5 particular it contain a number of heuristics to detect whether code is
6 6 effectively async and need to run in an event loop or not.
7 7
8 8 Some constructs (like top-level `return`, or `yield`) are taken care of
9 9 explicitly to actually raise a SyntaxError and stay as close as possible to
10 10 Python semantics.
11 11 """
12 12
13 13
14 14 import ast
15 15 import sys
16 16 from textwrap import dedent, indent
17 17
18 18
19 19 class _AsyncIORunner:
20 20
21 21 def __call__(self, coro):
22 22 """
23 23 Handler for asyncio autoawait
24 24 """
25 25 import asyncio
26 26
27 27 return asyncio.get_event_loop().run_until_complete(coro)
28 28
29 29 def __str__(self):
30 30 return 'asyncio'
31 31
32 32 _asyncio_runner = _AsyncIORunner()
33 33
34 34
35 35 def _curio_runner(coroutine):
36 36 """
37 37 handler for curio autoawait
38 38 """
39 39 import curio
40 40
41 41 return curio.run(coroutine)
42 42
43 43
44 44 def _trio_runner(async_fn):
45 45 import trio
46 46
47 47 async def loc(coro):
48 48 """
49 49 We need the dummy no-op async def to protect from
50 50 trio's internal. See https://github.com/python-trio/trio/issues/89
51 51 """
52 52 return await coro
53 53
54 54 return trio.run(loc, async_fn)
55 55
56 56
57 57 def _pseudo_sync_runner(coro):
58 58 """
59 59 A runner that does not really allow async execution, and just advance the coroutine.
60 60
61 61 See discussion in https://github.com/python-trio/trio/issues/608,
62 62
63 63 Credit to Nathaniel Smith
64 64
65 65 """
66 66 try:
67 67 coro.send(None)
68 68 except StopIteration as exc:
69 69 return exc.value
70 70 else:
71 71 # TODO: do not raise but return an execution result with the right info.
72 72 raise RuntimeError(
73 73 "{coro_name!r} needs a real async loop".format(coro_name=coro.__name__)
74 74 )
75 75
76 76
77 77 def _asyncify(code: str) -> str:
78 78 """wrap code in async def definition.
79 79
80 80 And setup a bit of context to run it later.
81 81 """
82 82 res = dedent(
83 83 """
84 84 async def __wrapper__():
85 85 try:
86 86 {usercode}
87 87 finally:
88 88 locals()
89 89 """
90 90 ).format(usercode=indent(code, " " * 8))
91 91 return res
92 92
93 93
94 94 class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
95 95 """
96 96 Find syntax errors that would be an error in an async repl, but because
97 97 the implementation involves wrapping the repl in an async function, it
98 98 is erroneously allowed (e.g. yield or return at the top level)
99 99 """
100 def __init__(self):
101 self.depth = 0
102 super().__init__()
100 103
101 104 def generic_visit(self, node):
102 105 func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
103 invalid_types = (ast.Return, ast.Yield, ast.YieldFrom)
104
105 if isinstance(node, func_types):
106 return # Don't recurse into functions
107 elif isinstance(node, invalid_types):
106 invalid_types_by_depth = {
107 0: (ast.Return, ast.Yield, ast.YieldFrom),
108 1: (ast.Nonlocal,)
109 }
110
111 should_traverse = self.depth < max(invalid_types_by_depth.keys())
112 if isinstance(node, func_types) and should_traverse:
113 self.depth += 1
114 super().generic_visit(node)
115 elif isinstance(node, invalid_types_by_depth[self.depth]):
108 116 raise SyntaxError()
109 117 else:
110 118 super().generic_visit(node)
111 119
112 120
113 121 def _async_parse_cell(cell: str) -> ast.AST:
114 122 """
115 123 This is a compatibility shim for pre-3.7 when async outside of a function
116 124 is a syntax error at the parse stage.
117 125
118 126 It will return an abstract syntax tree parsed as if async and await outside
119 127 of a function were not a syntax error.
120 128 """
121 129 if sys.version_info < (3, 7):
122 130 # Prior to 3.7 you need to asyncify before parse
123 131 wrapped_parse_tree = ast.parse(_asyncify(cell))
124 132 return wrapped_parse_tree.body[0].body[0]
125 133 else:
126 134 return ast.parse(cell)
127 135
128 136
129 137 def _should_be_async(cell: str) -> bool:
130 138 """Detect if a block of code need to be wrapped in an `async def`
131 139
132 140 Attempt to parse the block of code, it it compile we're fine.
133 141 Otherwise we wrap if and try to compile.
134 142
135 143 If it works, assume it should be async. Otherwise Return False.
136 144
137 145 Not handled yet: If the block of code has a return statement as the top
138 146 level, it will be seen as async. This is a know limitation.
139 147 """
140 148
141 149 try:
142 150 # we can't limit ourself to ast.parse, as it __accepts__ to parse on
143 151 # 3.7+, but just does not _compile_
144 152 compile(cell, "<>", "exec")
145 153 return False
146 154 except SyntaxError:
147 155 try:
148 156 parse_tree = _async_parse_cell(cell)
149 157
150 158 # Raise a SyntaxError if there are top-level return or yields
151 159 v = _AsyncSyntaxErrorVisitor()
152 160 v.visit(parse_tree)
153 161
154 162 except SyntaxError:
155 163 return False
156 164 return True
157 165 return False
@@ -1,277 +1,306 b''
1 1 """
2 2 Test for async helpers.
3 3
4 4 Should only trigger on python 3.5+ or will have syntax errors.
5 5 """
6 6
7 7 import sys
8 8 from itertools import chain, repeat
9 9 import nose.tools as nt
10 10 from textwrap import dedent, indent
11 11 from unittest import TestCase
12 12 from IPython.testing.decorators import skip_without
13 13
14 14 ip = get_ipython()
15 15 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
16 16 iprc_nr = lambda x: ip.run_cell(dedent(x))
17 17
18 18 if sys.version_info > (3, 5):
19 19 from IPython.core.async_helpers import _should_be_async
20 20
21 21 class AsyncTest(TestCase):
22 22 def test_should_be_async(self):
23 23 nt.assert_false(_should_be_async("False"))
24 24 nt.assert_true(_should_be_async("await bar()"))
25 25 nt.assert_true(_should_be_async("x = await bar()"))
26 26 nt.assert_false(
27 27 _should_be_async(
28 28 dedent(
29 29 """
30 30 async def awaitable():
31 31 pass
32 32 """
33 33 )
34 34 )
35 35 )
36 36
37 37 def _get_top_level_cases(self):
38 38 # These are test cases that should be valid in a function
39 39 # but invalid outside of a function.
40 40 test_cases = []
41 41 test_cases.append(('basic', "{val}"))
42 42
43 43 # Note, in all conditional cases, I use True instead of
44 44 # False so that the peephole optimizer won't optimize away
45 45 # the return, so CPython will see this as a syntax error:
46 46 #
47 47 # while True:
48 48 # break
49 49 # return
50 50 #
51 51 # But not this:
52 52 #
53 53 # while False:
54 54 # return
55 55 #
56 56 # See https://bugs.python.org/issue1875
57 57
58 58 test_cases.append(('if', dedent("""
59 59 if True:
60 60 {val}
61 61 """)))
62 62
63 63 test_cases.append(('while', dedent("""
64 64 while True:
65 65 {val}
66 66 break
67 67 """)))
68 68
69 69 test_cases.append(('try', dedent("""
70 70 try:
71 71 {val}
72 72 except:
73 73 pass
74 74 """)))
75 75
76 76 test_cases.append(('except', dedent("""
77 77 try:
78 78 pass
79 79 except:
80 80 {val}
81 81 """)))
82 82
83 83 test_cases.append(('finally', dedent("""
84 84 try:
85 85 pass
86 86 except:
87 87 pass
88 88 finally:
89 89 {val}
90 90 """)))
91 91
92 92 test_cases.append(('for', dedent("""
93 93 for _ in range(4):
94 94 {val}
95 95 """)))
96 96
97 97
98 98 test_cases.append(('nested', dedent("""
99 99 if True:
100 100 while True:
101 101 {val}
102 102 break
103 103 """)))
104 104
105 105 test_cases.append(('deep-nested', dedent("""
106 106 if True:
107 107 while True:
108 108 break
109 109 for x in range(3):
110 110 if True:
111 111 while True:
112 112 for x in range(3):
113 113 {val}
114 114 """)))
115 115
116 116 return test_cases
117 117
118 118 def _get_ry_syntax_errors(self):
119 119 # This is a mix of tests that should be a syntax error if
120 120 # return or yield whether or not they are in a function
121 121
122 122 test_cases = []
123 123
124 124 test_cases.append(('class', dedent("""
125 125 class V:
126 126 {val}
127 127 """)))
128 128
129 129 test_cases.append(('nested-class', dedent("""
130 130 class V:
131 131 class C:
132 132 {val}
133 133 """)))
134 134
135 135 return test_cases
136 136
137 137
138 138 def test_top_level_return_error(self):
139 139 tl_err_test_cases = self._get_top_level_cases()
140 140 tl_err_test_cases.extend(self._get_ry_syntax_errors())
141 141
142 142 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
143 143
144 144 for test_name, test_case in tl_err_test_cases:
145 145 # This example should work if 'pass' is used as the value
146 146 with self.subTest((test_name, 'pass')):
147 147 iprc(test_case.format(val='pass'))
148 148
149 149 # It should fail with all the values
150 150 for val in vals:
151 151 with self.subTest((test_name, val)):
152 152 msg = "Syntax error not raised for %s, %s" % (test_name, val)
153 153 with self.assertRaises(SyntaxError, msg=msg):
154 154 iprc(test_case.format(val=val))
155 155
156 156 def test_in_func_no_error(self):
157 157 # Test that the implementation of top-level return/yield
158 158 # detection isn't *too* aggressive, and works inside a function
159 159 func_contexts = []
160 160
161 161 func_contexts.append(('func', False, dedent("""
162 162 def f():""")))
163 163
164 164 func_contexts.append(('method', False, dedent("""
165 165 class MyClass:
166 166 def __init__(self):
167 167 """)))
168 168
169 169 func_contexts.append(('async-func', True, dedent("""
170 170 async def f():""")))
171 171
172 172 func_contexts.append(('async-method', True, dedent("""
173 173 class MyClass:
174 174 async def f(self):""")))
175 175
176 176 func_contexts.append(('closure', False, dedent("""
177 177 def f():
178 178 def g():
179 179 """)))
180 180
181 181 def nest_case(context, case):
182 182 # Detect indentation
183 183 lines = context.strip().splitlines()
184 184 prefix_len = 0
185 185 for c in lines[-1]:
186 186 if c != ' ':
187 187 break
188 188 prefix_len += 1
189 189
190 190 indented_case = indent(case, ' ' * (prefix_len + 4))
191 191 return context + '\n' + indented_case
192 192
193 193 # Gather and run the tests
194 194
195 195 # yield is allowed in async functions, starting in Python 3.6,
196 196 # and yield from is not allowed in any version
197 197 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
198 198 async_safe = (True,
199 199 sys.version_info >= (3, 6),
200 200 False)
201 201 vals = tuple(zip(vals, async_safe))
202 202
203 203 success_tests = zip(self._get_top_level_cases(), repeat(False))
204 204 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
205 205
206 206 tests = chain(success_tests, failure_tests)
207 207
208 208 for context_name, async_func, context in func_contexts:
209 209 for (test_name, test_case), should_fail in tests:
210 210 nested_case = nest_case(context, test_case)
211 211
212 212 for val, async_safe in vals:
213 213 val_should_fail = (should_fail or
214 214 (async_func and not async_safe))
215 215
216 216 test_id = (context_name, test_name, val)
217 217 cell = nested_case.format(val=val)
218 218
219 219 with self.subTest(test_id):
220 220 if val_should_fail:
221 221 msg = ("SyntaxError not raised for %s" %
222 222 str(test_id))
223 223 with self.assertRaises(SyntaxError, msg=msg):
224 224 iprc(cell)
225 225
226 226 print(cell)
227 227 else:
228 228 iprc(cell)
229 229
230 def test_nonlocal(self):
231 # fails if outer scope is not a function scope or if var not defined
232 with self.assertRaises(SyntaxError):
233 iprc("nonlocal x")
234 iprc("""
235 x = 1
236 def f():
237 nonlocal x
238 x = 10000
239 yield x
240 """)
241 iprc("""
242 def f():
243 def g():
244 nonlocal x
245 x = 10000
246 yield x
247 """)
248
249 # works if outer scope is a function scope and var exists
250 iprc("""
251 def f():
252 x = 20
253 def g():
254 nonlocal x
255 x = 10000
256 yield x
257 """)
258
230 259
231 260 def test_execute(self):
232 261 iprc("""
233 262 import asyncio
234 263 await asyncio.sleep(0.001)
235 264 """
236 265 )
237 266
238 267 def test_autoawait(self):
239 268 iprc("%autoawait False")
240 269 iprc("%autoawait True")
241 270 iprc("""
242 271 from asyncio import sleep
243 272 await sleep(0.1)
244 273 """
245 274 )
246 275
247 276 @skip_without('curio')
248 277 def test_autoawait_curio(self):
249 278 iprc("%autoawait curio")
250 279
251 280 @skip_without('trio')
252 281 def test_autoawait_trio(self):
253 282 iprc("%autoawait trio")
254 283
255 284 @skip_without('trio')
256 285 def test_autoawait_trio_wrong_sleep(self):
257 286 iprc("%autoawait trio")
258 287 res = iprc_nr("""
259 288 import asyncio
260 289 await asyncio.sleep(0)
261 290 """)
262 291 with nt.assert_raises(TypeError):
263 292 res.raise_error()
264 293
265 294 @skip_without('trio')
266 295 def test_autoawait_asyncio_wrong_sleep(self):
267 296 iprc("%autoawait asyncio")
268 297 res = iprc_nr("""
269 298 import trio
270 299 await trio.sleep(0)
271 300 """)
272 301 with nt.assert_raises(RuntimeError):
273 302 res.raise_error()
274 303
275 304
276 305 def tearDown(self):
277 306 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now