##// END OF EJS Templates
Support stringized return type annotations
krassowski -
Show More
@@ -1,6 +1,5 b''
1 from inspect import isclass, signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Any,
4 Callable,
3 Callable,
5 Dict,
4 Dict,
6 Set,
5 Set,
@@ -558,19 +557,7 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
558 f" not allowed in {context.evaluation} mode",
557 f" not allowed in {context.evaluation} mode",
559 )
558 )
560 if isinstance(node, ast.Name):
559 if isinstance(node, ast.Name):
561 if policy.allow_locals_access and node.id in context.locals:
560 return _eval_node_name(node.id, policy, context)
562 return context.locals[node.id]
563 if policy.allow_globals_access and node.id in context.globals:
564 return context.globals[node.id]
565 if policy.allow_builtins_access and hasattr(builtins, node.id):
566 # note: do not use __builtins__, it is implementation detail of cPython
567 return getattr(builtins, node.id)
568 if not policy.allow_globals_access and not policy.allow_locals_access:
569 raise GuardRejection(
570 f"Namespace access not allowed in {context.evaluation} mode"
571 )
572 else:
573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
574 if isinstance(node, ast.Attribute):
561 if isinstance(node, ast.Attribute):
575 value = eval_node(node.value, context)
562 value = eval_node(node.value, context)
576 if policy.can_get_attr(value, node.attr):
563 if policy.can_get_attr(value, node.attr):
@@ -614,7 +601,9 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
614 raise ValueError("Unhandled node", ast.dump(node))
601 raise ValueError("Unhandled node", ast.dump(node))
615
602
616
603
617 def _eval_return_type(func, policy, node, context):
604 def _eval_return_type(
605 func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext
606 ):
618 """Evaluate return type of a given callable function.
607 """Evaluate return type of a given callable function.
619
608
620 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
609 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
@@ -626,17 +615,38 b' def _eval_return_type(func, policy, node, context):'
626 # if annotation was not stringized, or it was stringized
615 # if annotation was not stringized, or it was stringized
627 # but resolved by signature call we know the return type
616 # but resolved by signature call we know the return type
628 not_empty = sig.return_annotation is not Signature.empty
617 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
618 stringized = isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
619 if not_empty:
620 return_type = (
621 _eval_node_name(sig.return_annotation, policy, context)
622 if stringized
623 else sig.return_annotation
624 )
631 # if allow-listed builtin is on type annotation, instantiate it
625 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
626 if policy.can_call(return_type) and not node.keywords:
633 args = [eval_node(arg, context) for arg in node.args]
627 args = [eval_node(arg, context) for arg in node.args]
634 # if custom class is in type annotation, mock it;
628 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
629 return return_type(*args)
636 return _create_duck_for_type(sig.return_annotation)
630 return _create_duck_for_type(return_type)
637 return NOT_EVALUATED
631 return NOT_EVALUATED
638
632
639
633
634 def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext):
635 if policy.allow_locals_access and node_id in context.locals:
636 return context.locals[node_id]
637 if policy.allow_globals_access and node_id in context.globals:
638 return context.globals[node_id]
639 if policy.allow_builtins_access and hasattr(builtins, node_id):
640 # note: do not use __builtins__, it is implementation detail of cPython
641 return getattr(builtins, node_id)
642 if not policy.allow_globals_access and not policy.allow_locals_access:
643 raise GuardRejection(
644 f"Namespace access not allowed in {context.evaluation} mode"
645 )
646 else:
647 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
648
649
640 def _create_duck_for_type(duck_type):
650 def _create_duck_for_type(duck_type):
641 """Create an imitation of an object of a given type (a duck).
651 """Create an imitation of an object of a given type (a duck).
642
652
@@ -278,11 +278,21 b' class InitReturnsFrozenset:'
278 return frozenset()
278 return frozenset()
279
279
280
280
281 class StringAnnotation:
282 def heap(self) -> "HeapType":
283 return HeapType()
284
285 def copy(self) -> "StringAnnotation":
286 return StringAnnotation()
287
288
281 @pytest.mark.parametrize(
289 @pytest.mark.parametrize(
282 "data,good,expected,equality",
290 "data,good,expected,equality",
283 [
291 [
284 [[1, 2, 3], "data.index(2)", 1, True],
292 [[1, 2, 3], "data.index(2)", 1, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
293 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
294 [StringAnnotation(), "data.heap()", HeapType, False],
295 [StringAnnotation(), "data.copy()", StringAnnotation, False],
286 # test cases for `__call__`
296 # test cases for `__call__`
287 [CallCreatesHeapType(), "data()", HeapType, False],
297 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
298 [CallCreatesBuiltin(), "data()", frozenset, False],
@@ -290,12 +300,12 b' class InitReturnsFrozenset:'
290 [HeapType, "data()", HeapType, False],
300 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
301 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
302 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
303 # test cases for static methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
304 [HasStaticMethod, "data.static_method()", HeapType, False],
295 ],
305 ],
296 )
306 )
297 def test_evaluates_calls(data, good, expected, equality):
307 def test_evaluates_calls(data, good, expected, equality):
298 context = limited(data=data)
308 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
299 value = guarded_eval(good, context)
309 value = guarded_eval(good, context)
300 if equality:
310 if equality:
301 assert value == expected
311 assert value == expected
General Comments 0
You need to be logged in to leave comments. Login now