From c4bf6326fe0726afb7d36bc8b67ed3a30ac09809 2018-08-27 13:15:00
From: Paul Ganssle
Date: 2018-08-27 13:15:00
Subject: [PATCH] Improve async detection mechanism with blacklist

Because the async repl works by wrapping any code that raises SyntaxError
in an async function and trying to execute it again, cell bodies that are
invalid at the top level but valid in functions and methods (e.g. return
and yield statements) currently allow executing invalid code.

This patch blacklists return and yield statements outside of a function
or method to restore the proper SyntaxError behavior. The blacklist
tracks function nesting depth, so it also rejects nonlocal declarations
that have no enclosing function inside the cell. Yield inside a top-level
lambda is still accepted, and decorators, default values and annotations
are checked in the scope that evaluates them.
---
diff --git a/IPython/core/async_helpers.py b/IPython/core/async_helpers.py
index d334dba..fa3d5cf 100644
--- a/IPython/core/async_helpers.py
+++ b/IPython/core/async_helpers.py
@@ -80,6 +80,80 @@ def _asyncify(code: str) -> str:
     return res
 
 
+class _AsyncSyntaxErrorVisitor(ast.NodeVisitor):
+    """
+    Find syntax errors that would be an error in an async repl, but because
+    the implementation involves wrapping the repl in an async function, it
+    is erroneously allowed (e.g. yield or return at the top level).
+
+    Function nesting depth is tracked so that only constructs that are
+    invalid in the cell itself are rejected: ``return``, ``yield`` and
+    ``yield from`` outside of any function or lambda, and ``nonlocal``
+    declarations that have no enclosing function inside the cell.
+    """
+    # Node types that are syntax errors at a given function nesting depth
+    # (0 is the top level of the cell); deeper nodes are never rejected.
+    _invalid_types_by_depth = {
+        0: (ast.Return, ast.Yield, ast.YieldFrom, ast.Nonlocal),
+        1: (ast.Nonlocal,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.depth = 0
+
+    def generic_visit(self, node):
+        func_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
+        invalid_types = self._invalid_types_by_depth.get(self.depth, ())
+
+        if isinstance(node, func_types):
+            self._visit_function(node)
+        elif isinstance(node, invalid_types):
+            raise SyntaxError()
+        else:
+            super().generic_visit(node)
+
+    def _visit_function(self, node):
+        """
+        Check a function or lambda definition.
+
+        Decorators, default values and annotations are evaluated in the
+        scope that defines the function, so they are checked at the current
+        depth; only the body is checked one level deeper.
+        """
+        for decorator in getattr(node, 'decorator_list', ()):
+            self.visit(decorator)
+        self.visit(node.args)
+        returns = getattr(node, 'returns', None)
+        if returns is not None:
+            self.visit(returns)
+
+        # A lambda body is a single expression; a def body is a statement list
+        body = node.body if isinstance(node.body, list) else [node.body]
+        self.depth += 1
+        try:
+            for child in body:
+                self.visit(child)
+        finally:
+            self.depth -= 1
+
+
+def _async_parse_cell(cell: str) -> ast.AST:
+    """
+    This is a compatibility shim for pre-3.7 Python, where async outside of a
+    function is a syntax error at the parse stage.
+
+    It will return an abstract syntax tree parsed as if async and await outside
+    of a function were not a syntax error.
+    """
+    if sys.version_info < (3, 7):
+        # Prior to 3.7 you need to asyncify before parse
+        wrapped_parse_tree = ast.parse(_asyncify(cell))
+        return wrapped_parse_tree.body[0].body[0]
+    else:
+        return ast.parse(cell)
+
+
 def _should_be_async(cell: str) -> bool:
     """Detect if a block of code need to be wrapped in an `async def`
 
@@ -99,8 +173,12 @@ def _should_be_async(cell: str) -> bool:
         return False
     except SyntaxError:
         try:
-            ast.parse(_asyncify(cell))
-            # TODO verify ast has not "top level" return or yield.
+            parse_tree = _async_parse_cell(cell)
+
+            # Raise SyntaxError for code only valid inside the async wrapper
+            visitor = _AsyncSyntaxErrorVisitor()
+            visitor.visit(parse_tree)
+
         except SyntaxError:
             return False
         return True