##// END OF EJS Templates
Merge pull request #11382 from felixzhuologist/master...
Min RK -
r24668:940dd607 merge
parent child Browse files
Show More
@@ -1,157 +1,165 b''
1 """
1 """
2 Async helper function that are invalid syntax on Python 3.5 and below.
2 Async helper function that are invalid syntax on Python 3.5 and below.
3
3
4 This code is best effort, and may have edge cases not behaving as expected. In
4 This code is best effort, and may have edge cases not behaving as expected. In
5 particular it contain a number of heuristics to detect whether code is
5 particular it contain a number of heuristics to detect whether code is
6 effectively async and need to run in an event loop or not.
6 effectively async and need to run in an event loop or not.
7
7
8 Some constructs (like top-level `return`, or `yield`) are taken care of
8 Some constructs (like top-level `return`, or `yield`) are taken care of
9 explicitly to actually raise a SyntaxError and stay as close as possible to
9 explicitly to actually raise a SyntaxError and stay as close as possible to
10 Python semantics.
10 Python semantics.
11 """
11 """
12
12
13
13
14 import ast
14 import ast
15 import sys
15 import sys
16 from textwrap import dedent, indent
16 from textwrap import dedent, indent
17
17
18
18
19 class _AsyncIORunner:
19 class _AsyncIORunner:
20
20
21 def __call__(self, coro):
21 def __call__(self, coro):
22 """
22 """
23 Handler for asyncio autoawait
23 Handler for asyncio autoawait
24 """
24 """
25 import asyncio
25 import asyncio
26
26
27 return asyncio.get_event_loop().run_until_complete(coro)
27 return asyncio.get_event_loop().run_until_complete(coro)
28
28
29 def __str__(self):
29 def __str__(self):
30 return 'asyncio'
30 return 'asyncio'
31
31
32 _asyncio_runner = _AsyncIORunner()
32 _asyncio_runner = _AsyncIORunner()
33
33
34
34
35 def _curio_runner(coroutine):
35 def _curio_runner(coroutine):
36 """
36 """
37 handler for curio autoawait
37 handler for curio autoawait
38 """
38 """
39 import curio
39 import curio
40
40
41 return curio.run(coroutine)
41 return curio.run(coroutine)
42
42
43
43
44 def _trio_runner(async_fn):
44 def _trio_runner(async_fn):
45 import trio
45 import trio
46
46
47 async def loc(coro):
47 async def loc(coro):
48 """
48 """
49 We need the dummy no-op async def to protect from
49 We need the dummy no-op async def to protect from
50 trio's internal. See https://github.com/python-trio/trio/issues/89
50 trio's internal. See https://github.com/python-trio/trio/issues/89
51 """
51 """
52 return await coro
52 return await coro
53
53
54 return trio.run(loc, async_fn)
54 return trio.run(loc, async_fn)
55
55
56
56
57 def _pseudo_sync_runner(coro):
57 def _pseudo_sync_runner(coro):
58 """
58 """
59 A runner that does not really allow async execution, and just advance the coroutine.
59 A runner that does not really allow async execution, and just advance the coroutine.
60
60
61 See discussion in https://github.com/python-trio/trio/issues/608,
61 See discussion in https://github.com/python-trio/trio/issues/608,
62
62
63 Credit to Nathaniel Smith
63 Credit to Nathaniel Smith
64
64
65 """
65 """
66 try:
66 try:
67 coro.send(None)
67 coro.send(None)
68 except StopIteration as exc:
68 except StopIteration as exc:
69 return exc.value
69 return exc.value
70 else:
70 else:
71 # TODO: do not raise but return an execution result with the right info.
71 # TODO: do not raise but return an execution result with the right info.
72 raise RuntimeError(
72 raise RuntimeError(
73 "{coro_name!r} needs a real async loop".format(coro_name=coro.__name__)
73 "{coro_name!r} needs a real async loop".format(coro_name=coro.__name__)
74 )
74 )
75
75
76
76
77 def _asyncify(code: str) -> str:
77 def _asyncify(code: str) -> str:
78 """wrap code in async def definition.
78 """wrap code in async def definition.
79
79
80 And setup a bit of context to run it later.
80 And setup a bit of context to run it later.
81 """
81 """
82 res = dedent(
82 res = dedent(
83 """
83 """
84 async def __wrapper__():
84 async def __wrapper__():
85 try:
85 try:
86 {usercode}
86 {usercode}
87 finally:
87 finally:
88 locals()
88 locals()
89 """
89 """
90 ).format(usercode=indent(code, " " * 8))
90 ).format(usercode=indent(code, " " * 8))
91 return res
91 return res
92
92
93
93
94 class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
94 class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
95 """
95 """
96 Find syntax errors that would be an error in an async repl, but because
96 Find syntax errors that would be an error in an async repl, but because
97 the implementation involves wrapping the repl in an async function, it
97 the implementation involves wrapping the repl in an async function, it
98 is erroneously allowed (e.g. yield or return at the top level)
98 is erroneously allowed (e.g. yield or return at the top level)
99 """
99 """
100 def __init__(self):
101 self.depth = 0
102 super().__init__()
100
103
101 def generic_visit(self, node):
104 def generic_visit(self, node):
102 func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
105 func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
103 invalid_types = (ast.Return, ast.Yield, ast.YieldFrom)
106 invalid_types_by_depth = {
104
107 0: (ast.Return, ast.Yield, ast.YieldFrom),
105 if isinstance(node, func_types):
108 1: (ast.Nonlocal,)
106 return # Don't recurse into functions
109 }
107 elif isinstance(node, invalid_types):
110
111 should_traverse = self.depth < max(invalid_types_by_depth.keys())
112 if isinstance(node, func_types) and should_traverse:
113 self.depth += 1
114 super().generic_visit(node)
115 elif isinstance(node, invalid_types_by_depth[self.depth]):
108 raise SyntaxError()
116 raise SyntaxError()
109 else:
117 else:
110 super().generic_visit(node)
118 super().generic_visit(node)
111
119
112
120
113 def _async_parse_cell(cell: str) -> ast.AST:
121 def _async_parse_cell(cell: str) -> ast.AST:
114 """
122 """
115 This is a compatibility shim for pre-3.7 when async outside of a function
123 This is a compatibility shim for pre-3.7 when async outside of a function
116 is a syntax error at the parse stage.
124 is a syntax error at the parse stage.
117
125
118 It will return an abstract syntax tree parsed as if async and await outside
126 It will return an abstract syntax tree parsed as if async and await outside
119 of a function were not a syntax error.
127 of a function were not a syntax error.
120 """
128 """
121 if sys.version_info < (3, 7):
129 if sys.version_info < (3, 7):
122 # Prior to 3.7 you need to asyncify before parse
130 # Prior to 3.7 you need to asyncify before parse
123 wrapped_parse_tree = ast.parse(_asyncify(cell))
131 wrapped_parse_tree = ast.parse(_asyncify(cell))
124 return wrapped_parse_tree.body[0].body[0]
132 return wrapped_parse_tree.body[0].body[0]
125 else:
133 else:
126 return ast.parse(cell)
134 return ast.parse(cell)
127
135
128
136
129 def _should_be_async(cell: str) -> bool:
137 def _should_be_async(cell: str) -> bool:
130 """Detect if a block of code need to be wrapped in an `async def`
138 """Detect if a block of code need to be wrapped in an `async def`
131
139
132 Attempt to parse the block of code, it it compile we're fine.
140 Attempt to parse the block of code, it it compile we're fine.
133 Otherwise we wrap if and try to compile.
141 Otherwise we wrap if and try to compile.
134
142
135 If it works, assume it should be async. Otherwise Return False.
143 If it works, assume it should be async. Otherwise Return False.
136
144
137 Not handled yet: If the block of code has a return statement as the top
145 Not handled yet: If the block of code has a return statement as the top
138 level, it will be seen as async. This is a know limitation.
146 level, it will be seen as async. This is a know limitation.
139 """
147 """
140
148
141 try:
149 try:
142 # we can't limit ourself to ast.parse, as it __accepts__ to parse on
150 # we can't limit ourself to ast.parse, as it __accepts__ to parse on
143 # 3.7+, but just does not _compile_
151 # 3.7+, but just does not _compile_
144 compile(cell, "<>", "exec")
152 compile(cell, "<>", "exec")
145 return False
153 return False
146 except SyntaxError:
154 except SyntaxError:
147 try:
155 try:
148 parse_tree = _async_parse_cell(cell)
156 parse_tree = _async_parse_cell(cell)
149
157
150 # Raise a SyntaxError if there are top-level return or yields
158 # Raise a SyntaxError if there are top-level return or yields
151 v = _AsyncSyntaxErrorVisitor()
159 v = _AsyncSyntaxErrorVisitor()
152 v.visit(parse_tree)
160 v.visit(parse_tree)
153
161
154 except SyntaxError:
162 except SyntaxError:
155 return False
163 return False
156 return True
164 return True
157 return False
165 return False
@@ -1,277 +1,306 b''
1 """
1 """
2 Test for async helpers.
2 Test for async helpers.
3
3
4 Should only trigger on python 3.5+ or will have syntax errors.
4 Should only trigger on python 3.5+ or will have syntax errors.
5 """
5 """
6
6
7 import sys
7 import sys
8 from itertools import chain, repeat
8 from itertools import chain, repeat
9 import nose.tools as nt
9 import nose.tools as nt
10 from textwrap import dedent, indent
10 from textwrap import dedent, indent
11 from unittest import TestCase
11 from unittest import TestCase
12 from IPython.testing.decorators import skip_without
12 from IPython.testing.decorators import skip_without
13
13
14 ip = get_ipython()
14 ip = get_ipython()
15 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
15 iprc = lambda x: ip.run_cell(dedent(x)).raise_error()
16 iprc_nr = lambda x: ip.run_cell(dedent(x))
16 iprc_nr = lambda x: ip.run_cell(dedent(x))
17
17
18 if sys.version_info > (3, 5):
18 if sys.version_info > (3, 5):
19 from IPython.core.async_helpers import _should_be_async
19 from IPython.core.async_helpers import _should_be_async
20
20
21 class AsyncTest(TestCase):
21 class AsyncTest(TestCase):
22 def test_should_be_async(self):
22 def test_should_be_async(self):
23 nt.assert_false(_should_be_async("False"))
23 nt.assert_false(_should_be_async("False"))
24 nt.assert_true(_should_be_async("await bar()"))
24 nt.assert_true(_should_be_async("await bar()"))
25 nt.assert_true(_should_be_async("x = await bar()"))
25 nt.assert_true(_should_be_async("x = await bar()"))
26 nt.assert_false(
26 nt.assert_false(
27 _should_be_async(
27 _should_be_async(
28 dedent(
28 dedent(
29 """
29 """
30 async def awaitable():
30 async def awaitable():
31 pass
31 pass
32 """
32 """
33 )
33 )
34 )
34 )
35 )
35 )
36
36
37 def _get_top_level_cases(self):
37 def _get_top_level_cases(self):
38 # These are test cases that should be valid in a function
38 # These are test cases that should be valid in a function
39 # but invalid outside of a function.
39 # but invalid outside of a function.
40 test_cases = []
40 test_cases = []
41 test_cases.append(('basic', "{val}"))
41 test_cases.append(('basic', "{val}"))
42
42
43 # Note, in all conditional cases, I use True instead of
43 # Note, in all conditional cases, I use True instead of
44 # False so that the peephole optimizer won't optimize away
44 # False so that the peephole optimizer won't optimize away
45 # the return, so CPython will see this as a syntax error:
45 # the return, so CPython will see this as a syntax error:
46 #
46 #
47 # while True:
47 # while True:
48 # break
48 # break
49 # return
49 # return
50 #
50 #
51 # But not this:
51 # But not this:
52 #
52 #
53 # while False:
53 # while False:
54 # return
54 # return
55 #
55 #
56 # See https://bugs.python.org/issue1875
56 # See https://bugs.python.org/issue1875
57
57
58 test_cases.append(('if', dedent("""
58 test_cases.append(('if', dedent("""
59 if True:
59 if True:
60 {val}
60 {val}
61 """)))
61 """)))
62
62
63 test_cases.append(('while', dedent("""
63 test_cases.append(('while', dedent("""
64 while True:
64 while True:
65 {val}
65 {val}
66 break
66 break
67 """)))
67 """)))
68
68
69 test_cases.append(('try', dedent("""
69 test_cases.append(('try', dedent("""
70 try:
70 try:
71 {val}
71 {val}
72 except:
72 except:
73 pass
73 pass
74 """)))
74 """)))
75
75
76 test_cases.append(('except', dedent("""
76 test_cases.append(('except', dedent("""
77 try:
77 try:
78 pass
78 pass
79 except:
79 except:
80 {val}
80 {val}
81 """)))
81 """)))
82
82
83 test_cases.append(('finally', dedent("""
83 test_cases.append(('finally', dedent("""
84 try:
84 try:
85 pass
85 pass
86 except:
86 except:
87 pass
87 pass
88 finally:
88 finally:
89 {val}
89 {val}
90 """)))
90 """)))
91
91
92 test_cases.append(('for', dedent("""
92 test_cases.append(('for', dedent("""
93 for _ in range(4):
93 for _ in range(4):
94 {val}
94 {val}
95 """)))
95 """)))
96
96
97
97
98 test_cases.append(('nested', dedent("""
98 test_cases.append(('nested', dedent("""
99 if True:
99 if True:
100 while True:
100 while True:
101 {val}
101 {val}
102 break
102 break
103 """)))
103 """)))
104
104
105 test_cases.append(('deep-nested', dedent("""
105 test_cases.append(('deep-nested', dedent("""
106 if True:
106 if True:
107 while True:
107 while True:
108 break
108 break
109 for x in range(3):
109 for x in range(3):
110 if True:
110 if True:
111 while True:
111 while True:
112 for x in range(3):
112 for x in range(3):
113 {val}
113 {val}
114 """)))
114 """)))
115
115
116 return test_cases
116 return test_cases
117
117
118 def _get_ry_syntax_errors(self):
118 def _get_ry_syntax_errors(self):
119 # This is a mix of tests that should be a syntax error if
119 # This is a mix of tests that should be a syntax error if
120 # return or yield whether or not they are in a function
120 # return or yield whether or not they are in a function
121
121
122 test_cases = []
122 test_cases = []
123
123
124 test_cases.append(('class', dedent("""
124 test_cases.append(('class', dedent("""
125 class V:
125 class V:
126 {val}
126 {val}
127 """)))
127 """)))
128
128
129 test_cases.append(('nested-class', dedent("""
129 test_cases.append(('nested-class', dedent("""
130 class V:
130 class V:
131 class C:
131 class C:
132 {val}
132 {val}
133 """)))
133 """)))
134
134
135 return test_cases
135 return test_cases
136
136
137
137
138 def test_top_level_return_error(self):
138 def test_top_level_return_error(self):
139 tl_err_test_cases = self._get_top_level_cases()
139 tl_err_test_cases = self._get_top_level_cases()
140 tl_err_test_cases.extend(self._get_ry_syntax_errors())
140 tl_err_test_cases.extend(self._get_ry_syntax_errors())
141
141
142 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
142 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
143
143
144 for test_name, test_case in tl_err_test_cases:
144 for test_name, test_case in tl_err_test_cases:
145 # This example should work if 'pass' is used as the value
145 # This example should work if 'pass' is used as the value
146 with self.subTest((test_name, 'pass')):
146 with self.subTest((test_name, 'pass')):
147 iprc(test_case.format(val='pass'))
147 iprc(test_case.format(val='pass'))
148
148
149 # It should fail with all the values
149 # It should fail with all the values
150 for val in vals:
150 for val in vals:
151 with self.subTest((test_name, val)):
151 with self.subTest((test_name, val)):
152 msg = "Syntax error not raised for %s, %s" % (test_name, val)
152 msg = "Syntax error not raised for %s, %s" % (test_name, val)
153 with self.assertRaises(SyntaxError, msg=msg):
153 with self.assertRaises(SyntaxError, msg=msg):
154 iprc(test_case.format(val=val))
154 iprc(test_case.format(val=val))
155
155
156 def test_in_func_no_error(self):
156 def test_in_func_no_error(self):
157 # Test that the implementation of top-level return/yield
157 # Test that the implementation of top-level return/yield
158 # detection isn't *too* aggressive, and works inside a function
158 # detection isn't *too* aggressive, and works inside a function
159 func_contexts = []
159 func_contexts = []
160
160
161 func_contexts.append(('func', False, dedent("""
161 func_contexts.append(('func', False, dedent("""
162 def f():""")))
162 def f():""")))
163
163
164 func_contexts.append(('method', False, dedent("""
164 func_contexts.append(('method', False, dedent("""
165 class MyClass:
165 class MyClass:
166 def __init__(self):
166 def __init__(self):
167 """)))
167 """)))
168
168
169 func_contexts.append(('async-func', True, dedent("""
169 func_contexts.append(('async-func', True, dedent("""
170 async def f():""")))
170 async def f():""")))
171
171
172 func_contexts.append(('async-method', True, dedent("""
172 func_contexts.append(('async-method', True, dedent("""
173 class MyClass:
173 class MyClass:
174 async def f(self):""")))
174 async def f(self):""")))
175
175
176 func_contexts.append(('closure', False, dedent("""
176 func_contexts.append(('closure', False, dedent("""
177 def f():
177 def f():
178 def g():
178 def g():
179 """)))
179 """)))
180
180
181 def nest_case(context, case):
181 def nest_case(context, case):
182 # Detect indentation
182 # Detect indentation
183 lines = context.strip().splitlines()
183 lines = context.strip().splitlines()
184 prefix_len = 0
184 prefix_len = 0
185 for c in lines[-1]:
185 for c in lines[-1]:
186 if c != ' ':
186 if c != ' ':
187 break
187 break
188 prefix_len += 1
188 prefix_len += 1
189
189
190 indented_case = indent(case, ' ' * (prefix_len + 4))
190 indented_case = indent(case, ' ' * (prefix_len + 4))
191 return context + '\n' + indented_case
191 return context + '\n' + indented_case
192
192
193 # Gather and run the tests
193 # Gather and run the tests
194
194
195 # yield is allowed in async functions, starting in Python 3.6,
195 # yield is allowed in async functions, starting in Python 3.6,
196 # and yield from is not allowed in any version
196 # and yield from is not allowed in any version
197 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
197 vals = ('return', 'yield', 'yield from (_ for _ in range(3))')
198 async_safe = (True,
198 async_safe = (True,
199 sys.version_info >= (3, 6),
199 sys.version_info >= (3, 6),
200 False)
200 False)
201 vals = tuple(zip(vals, async_safe))
201 vals = tuple(zip(vals, async_safe))
202
202
203 success_tests = zip(self._get_top_level_cases(), repeat(False))
203 success_tests = zip(self._get_top_level_cases(), repeat(False))
204 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
204 failure_tests = zip(self._get_ry_syntax_errors(), repeat(True))
205
205
206 tests = chain(success_tests, failure_tests)
206 tests = chain(success_tests, failure_tests)
207
207
208 for context_name, async_func, context in func_contexts:
208 for context_name, async_func, context in func_contexts:
209 for (test_name, test_case), should_fail in tests:
209 for (test_name, test_case), should_fail in tests:
210 nested_case = nest_case(context, test_case)
210 nested_case = nest_case(context, test_case)
211
211
212 for val, async_safe in vals:
212 for val, async_safe in vals:
213 val_should_fail = (should_fail or
213 val_should_fail = (should_fail or
214 (async_func and not async_safe))
214 (async_func and not async_safe))
215
215
216 test_id = (context_name, test_name, val)
216 test_id = (context_name, test_name, val)
217 cell = nested_case.format(val=val)
217 cell = nested_case.format(val=val)
218
218
219 with self.subTest(test_id):
219 with self.subTest(test_id):
220 if val_should_fail:
220 if val_should_fail:
221 msg = ("SyntaxError not raised for %s" %
221 msg = ("SyntaxError not raised for %s" %
222 str(test_id))
222 str(test_id))
223 with self.assertRaises(SyntaxError, msg=msg):
223 with self.assertRaises(SyntaxError, msg=msg):
224 iprc(cell)
224 iprc(cell)
225
225
226 print(cell)
226 print(cell)
227 else:
227 else:
228 iprc(cell)
228 iprc(cell)
229
229
230 def test_nonlocal(self):
231 # fails if outer scope is not a function scope or if var not defined
232 with self.assertRaises(SyntaxError):
233 iprc("nonlocal x")
234 iprc("""
235 x = 1
236 def f():
237 nonlocal x
238 x = 10000
239 yield x
240 """)
241 iprc("""
242 def f():
243 def g():
244 nonlocal x
245 x = 10000
246 yield x
247 """)
248
249 # works if outer scope is a function scope and var exists
250 iprc("""
251 def f():
252 x = 20
253 def g():
254 nonlocal x
255 x = 10000
256 yield x
257 """)
258
230
259
231 def test_execute(self):
260 def test_execute(self):
232 iprc("""
261 iprc("""
233 import asyncio
262 import asyncio
234 await asyncio.sleep(0.001)
263 await asyncio.sleep(0.001)
235 """
264 """
236 )
265 )
237
266
238 def test_autoawait(self):
267 def test_autoawait(self):
239 iprc("%autoawait False")
268 iprc("%autoawait False")
240 iprc("%autoawait True")
269 iprc("%autoawait True")
241 iprc("""
270 iprc("""
242 from asyncio import sleep
271 from asyncio import sleep
243 await sleep(0.1)
272 await sleep(0.1)
244 """
273 """
245 )
274 )
246
275
247 @skip_without('curio')
276 @skip_without('curio')
248 def test_autoawait_curio(self):
277 def test_autoawait_curio(self):
249 iprc("%autoawait curio")
278 iprc("%autoawait curio")
250
279
251 @skip_without('trio')
280 @skip_without('trio')
252 def test_autoawait_trio(self):
281 def test_autoawait_trio(self):
253 iprc("%autoawait trio")
282 iprc("%autoawait trio")
254
283
255 @skip_without('trio')
284 @skip_without('trio')
256 def test_autoawait_trio_wrong_sleep(self):
285 def test_autoawait_trio_wrong_sleep(self):
257 iprc("%autoawait trio")
286 iprc("%autoawait trio")
258 res = iprc_nr("""
287 res = iprc_nr("""
259 import asyncio
288 import asyncio
260 await asyncio.sleep(0)
289 await asyncio.sleep(0)
261 """)
290 """)
262 with nt.assert_raises(TypeError):
291 with nt.assert_raises(TypeError):
263 res.raise_error()
292 res.raise_error()
264
293
265 @skip_without('trio')
294 @skip_without('trio')
266 def test_autoawait_asyncio_wrong_sleep(self):
295 def test_autoawait_asyncio_wrong_sleep(self):
267 iprc("%autoawait asyncio")
296 iprc("%autoawait asyncio")
268 res = iprc_nr("""
297 res = iprc_nr("""
269 import trio
298 import trio
270 await trio.sleep(0)
299 await trio.sleep(0)
271 """)
300 """)
272 with nt.assert_raises(RuntimeError):
301 with nt.assert_raises(RuntimeError):
273 res.raise_error()
302 res.raise_error()
274
303
275
304
276 def tearDown(self):
305 def tearDown(self):
277 ip.loop_runner = "asyncio"
306 ip.loop_runner = "asyncio"
General Comments 0
You need to be logged in to leave comments. Login now