From 51647c4f269c7285be60ffb8173713258ed46af2 2024-02-26 11:15:55
From: krassowski <5832902+krassowski@users.noreply.github.com>
Date: 2024-02-26 11:15:55
Subject: [PATCH] Support stringized return type annotations

Resolve string (forward-reference) return annotations by looking the name
up in the evaluation namespace. Annotations which cannot be resolved leave
the return type unknown (NOT_EVALUATED), as before, instead of raising a
NameError about a name that does not appear in the evaluated expression.
---
diff --git a/IPython/core/guarded_eval.py b/IPython/core/guarded_eval.py
index 24cc549..0e7d4bf 100644
--- a/IPython/core/guarded_eval.py
+++ b/IPython/core/guarded_eval.py
@@ -1,6 +1,5 @@
 from inspect import isclass, signature, Signature
 from typing import (
-    Any,
     Callable,
     Dict,
     Set,
@@ -558,19 +557,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
             f" not allowed in {context.evaluation} mode",
         )
     if isinstance(node, ast.Name):
-        if policy.allow_locals_access and node.id in context.locals:
-            return context.locals[node.id]
-        if policy.allow_globals_access and node.id in context.globals:
-            return context.globals[node.id]
-        if policy.allow_builtins_access and hasattr(builtins, node.id):
-            # note: do not use __builtins__, it is implementation detail of cPython
-            return getattr(builtins, node.id)
-        if not policy.allow_globals_access and not policy.allow_locals_access:
-            raise GuardRejection(
-                f"Namespace access not allowed in {context.evaluation} mode"
-            )
-        else:
-            raise NameError(f"{node.id} not found in locals, globals, nor builtins")
+        return _eval_node_name(node.id, policy, context)
     if isinstance(node, ast.Attribute):
         value = eval_node(node.value, context)
         if policy.can_get_attr(value, node.attr):
@@ -614,7 +601,9 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
     raise ValueError("Unhandled node", ast.dump(node))
 
 
-def _eval_return_type(func, policy, node, context):
+def _eval_return_type(
+    func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext
+):
     """Evaluate return type of a given callable function.
 
     Returns the built-in type, a duck or NOT_EVALUATED sentinel.
@@ -626,17 +615,47 @@ def _eval_return_type(func, policy, node, context):
     # if annotation was not stringized, or it was stringized
     # but resolved by signature call we know the return type
     not_empty = sig.return_annotation is not Signature.empty
-    not_stringized = not isinstance(sig.return_annotation, str)
-    if not_empty and not_stringized:
+    if not_empty:
+        return_type = sig.return_annotation
+        if isinstance(return_type, str):
+            # stringized annotation (forward reference): resolve it by name;
+            # if that fails the return type is unknown - do not leak a
+            # NameError about a name which is not in the evaluated expression
+            try:
+                return_type = _eval_node_name(return_type, policy, context)
+            except (NameError, GuardRejection):
+                return NOT_EVALUATED
         # if allow-listed builtin is on type annotation, instantiate it
-        if policy.can_call(sig.return_annotation) and not node.keywords:
+        if policy.can_call(return_type) and not node.keywords:
             args = [eval_node(arg, context) for arg in node.args]
-            # if custom class is in type annotation, mock it;
-            return sig.return_annotation(*args)
-        return _create_duck_for_type(sig.return_annotation)
+            return return_type(*args)
+        # if custom class is in type annotation, mock it
+        return _create_duck_for_type(return_type)
     return NOT_EVALUATED
 
 
+def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext):
+    """Look up ``node_id`` in the namespaces that ``policy`` allows.
+
+    Locals are searched first, then globals, then builtins. Raises
+    ``GuardRejection`` when the name is not found and neither locals nor
+    globals access is allowed, and ``NameError`` otherwise.
+    """
+    if policy.allow_locals_access and node_id in context.locals:
+        return context.locals[node_id]
+    if policy.allow_globals_access and node_id in context.globals:
+        return context.globals[node_id]
+    if policy.allow_builtins_access and hasattr(builtins, node_id):
+        # note: do not use __builtins__, it is an implementation detail of CPython
+        return getattr(builtins, node_id)
+    if not policy.allow_globals_access and not policy.allow_locals_access:
+        raise GuardRejection(
+            f"Namespace access not allowed in {context.evaluation} mode"
+        )
+    else:
+        raise NameError(f"{node_id} not found in locals, globals, nor builtins")
+
+
 def _create_duck_for_type(duck_type):
     """Create an imitation of an object of a given type (a duck).
 
diff --git a/IPython/core/tests/test_guarded_eval.py b/IPython/core/tests/test_guarded_eval.py
index 393b12d..68bde5d 100644
--- a/IPython/core/tests/test_guarded_eval.py
+++ b/IPython/core/tests/test_guarded_eval.py
@@ -278,11 +278,30 @@ class InitReturnsFrozenset:
         return frozenset()
 
 
+class StringAnnotation:
+    def heap(self) -> "HeapType":
+        return HeapType()
+
+    def copy(self) -> "StringAnnotation":
+        return StringAnnotation()
+
+    def unresolvable(self) -> "NotInNamespace":
+        return None
+
+
+def test_rejects_unresolvable_string_annotation():
+    context = limited(data=StringAnnotation())
+    with pytest.raises(GuardRejection):
+        guarded_eval("data.unresolvable()", context)
+
+
 @pytest.mark.parametrize(
     "data,good,expected,equality",
     [
         [[1, 2, 3], "data.index(2)", 1, True],
         [{"a": 1}, "data.keys().isdisjoint({})", True, True],
+        [StringAnnotation(), "data.heap()", HeapType, False],
+        [StringAnnotation(), "data.copy()", StringAnnotation, False],
         # test cases for `__call__`
         [CallCreatesHeapType(), "data()", HeapType, False],
         [CallCreatesBuiltin(), "data()", frozenset, False],
@@ -290,12 +309,12 @@ class InitReturnsFrozenset:
         [HeapType, "data()", HeapType, False],
         [InitReturnsFrozenset, "data()", frozenset, False],
         [HeapType(), "data.__class__()", HeapType, False],
-        # test cases for static and class methods
+        # test cases for static methods
         [HasStaticMethod, "data.static_method()", HeapType, False],
     ],
 )
 def test_evaluates_calls(data, good, expected, equality):
-    context = limited(data=data)
+    context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
     value = guarded_eval(good, context)
     if equality:
         assert value == expected