##// END OF EJS Templates
Support stringized return type annotations
krassowski -
Show More
@@ -1,6 +1,5 b''
1 1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 Any,
4 3 Callable,
5 4 Dict,
6 5 Set,
@@ -558,19 +557,7 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
558 557 f" not allowed in {context.evaluation} mode",
559 558 )
560 559 if isinstance(node, ast.Name):
561 if policy.allow_locals_access and node.id in context.locals:
562 return context.locals[node.id]
563 if policy.allow_globals_access and node.id in context.globals:
564 return context.globals[node.id]
565 if policy.allow_builtins_access and hasattr(builtins, node.id):
566 # note: do not use __builtins__, it is implementation detail of cPython
567 return getattr(builtins, node.id)
568 if not policy.allow_globals_access and not policy.allow_locals_access:
569 raise GuardRejection(
570 f"Namespace access not allowed in {context.evaluation} mode"
571 )
572 else:
573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
560 return _eval_node_name(node.id, policy, context)
574 561 if isinstance(node, ast.Attribute):
575 562 value = eval_node(node.value, context)
576 563 if policy.can_get_attr(value, node.attr):
@@ -614,7 +601,9 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
614 601 raise ValueError("Unhandled node", ast.dump(node))
615 602
616 603
617 def _eval_return_type(func, policy, node, context):
604 def _eval_return_type(
605 func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext
606 ):
618 607 """Evaluate return type of a given callable function.
619 608
620 609 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
@@ -626,17 +615,38 b' def _eval_return_type(func, policy, node, context):'
626 615 # if annotation was not stringized, or it was stringized
627 616 # but resolved by signature call we know the return type
628 617 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
618 stringized = isinstance(sig.return_annotation, str)
619 if not_empty:
620 return_type = (
621 _eval_node_name(sig.return_annotation, policy, context)
622 if stringized
623 else sig.return_annotation
624 )
631 625 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
626 if policy.can_call(return_type) and not node.keywords:
633 627 args = [eval_node(arg, context) for arg in node.args]
634 628 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
629 return return_type(*args)
630 return _create_duck_for_type(return_type)
637 631 return NOT_EVALUATED
638 632
639 633
634 def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext):
635 if policy.allow_locals_access and node_id in context.locals:
636 return context.locals[node_id]
637 if policy.allow_globals_access and node_id in context.globals:
638 return context.globals[node_id]
639 if policy.allow_builtins_access and hasattr(builtins, node_id):
640 # note: do not use __builtins__, it is implementation detail of cPython
641 return getattr(builtins, node_id)
642 if not policy.allow_globals_access and not policy.allow_locals_access:
643 raise GuardRejection(
644 f"Namespace access not allowed in {context.evaluation} mode"
645 )
646 else:
647 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
648
649
640 650 def _create_duck_for_type(duck_type):
641 651 """Create an imitation of an object of a given type (a duck).
642 652
@@ -278,11 +278,21 b' class InitReturnsFrozenset:'
278 278 return frozenset()
279 279
280 280
281 class StringAnnotation:
282 def heap(self) -> "HeapType":
283 return HeapType()
284
285 def copy(self) -> "StringAnnotation":
286 return StringAnnotation()
287
288
281 289 @pytest.mark.parametrize(
282 290 "data,good,expected,equality",
283 291 [
284 292 [[1, 2, 3], "data.index(2)", 1, True],
285 293 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
294 [StringAnnotation(), "data.heap()", HeapType, False],
295 [StringAnnotation(), "data.copy()", StringAnnotation, False],
286 296 # test cases for `__call__`
287 297 [CallCreatesHeapType(), "data()", HeapType, False],
288 298 [CallCreatesBuiltin(), "data()", frozenset, False],
@@ -290,12 +300,12 b' class InitReturnsFrozenset:'
290 300 [HeapType, "data()", HeapType, False],
291 301 [InitReturnsFrozenset, "data()", frozenset, False],
292 302 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
303 # test cases for static methods
294 304 [HasStaticMethod, "data.static_method()", HeapType, False],
295 305 ],
296 306 )
297 307 def test_evaluates_calls(data, good, expected, equality):
298 context = limited(data=data)
308 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
299 309 value = guarded_eval(good, context)
300 310 if equality:
301 311 assert value == expected
General Comments 0
You need to be logged in to leave comments. Login now