diff --git a/IPython/core/async_helpers.py b/IPython/core/async_helpers.py index d334dba..fa3d5cf 100644 --- a/IPython/core/async_helpers.py +++ b/IPython/core/async_helpers.py @@ -80,6 +80,40 @@ def _asyncify(code: str) -> str: return res +class _AsyncSyntaxErrorVisitor(ast.NodeVisitor): + """ + Find syntax errors that would be an error in an async repl, but because + the implementation involves wrapping the repl in an async function, it + is erroneously allowed (e.g. yield or return at the top level) + """ + def generic_visit(self, node): + func_types = (ast.FunctionDef, ast.AsyncFunctionDef) + invalid_types = (ast.Return, ast.Yield, ast.YieldFrom) + + if isinstance(node, func_types): + return # Don't recurse into functions + elif isinstance(node, invalid_types): + raise SyntaxError() + else: + super().generic_visit(node) + + +def _async_parse_cell(cell: str) -> ast.AST: + """ + This is a compatibility shim for pre-3.7 when async outside of a function + is a syntax error at the parse stage. + + It will return an abstract syntax tree parsed as if async and await outside + of a function were not a syntax error. + """ + if sys.version_info < (3, 7): + # Prior to 3.7 you need to asyncify before parse + wrapped_parse_tree = ast.parse(_asyncify(cell)) + return wrapped_parse_tree.body[0].body[0] + else: + return ast.parse(cell) + + def _should_be_async(cell: str) -> bool: """Detect if a block of code need to be wrapped in an `async def` @@ -99,8 +133,12 @@ def _should_be_async(cell: str) -> bool: return False except SyntaxError: try: - ast.parse(_asyncify(cell)) - # TODO verify ast has not "top level" return or yield. + parse_tree = _async_parse_cell(cell) + + # Raise a SyntaxError if there are top-level return or yields + v = _AsyncSyntaxErrorVisitor() + v.visit(parse_tree) + except SyntaxError: return False return True