Show More
@@ -1,4 +1,4 b'' | |||||
1 | from inspect import signature, Signature |
|
1 | from inspect import isclass, signature, Signature | |
2 | from typing import ( |
|
2 | from typing import ( | |
3 | Any, |
|
3 | Any, | |
4 | Callable, |
|
4 | Callable, | |
@@ -337,6 +337,7 b' class _IdentitySubscript:' | |||||
337 | IDENTITY_SUBSCRIPT = _IdentitySubscript() |
|
337 | IDENTITY_SUBSCRIPT = _IdentitySubscript() | |
338 | SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__" |
|
338 | SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__" | |
339 | UNKNOWN_SIGNATURE = Signature() |
|
339 | UNKNOWN_SIGNATURE = Signature() | |
|
340 | NOT_EVALUATED = object() | |||
340 |
|
341 | |||
341 |
|
342 | |||
342 | class GuardRejection(Exception): |
|
343 | class GuardRejection(Exception): | |
@@ -590,6 +591,34 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):' | |||||
590 | if policy.can_call(func) and not node.keywords: |
|
591 | if policy.can_call(func) and not node.keywords: | |
591 | args = [eval_node(arg, context) for arg in node.args] |
|
592 | args = [eval_node(arg, context) for arg in node.args] | |
592 | return func(*args) |
|
593 | return func(*args) | |
|
594 | if isclass(func): | |||
|
595 | # this code path gets entered when calling class e.g. `MyClass()` | |||
|
596 | # or `my_instance.__class__()` - in both cases `func` is `MyClass`. | |||
|
597 | # Should return `MyClass` if `__new__` is not overridden, | |||
|
598 | # otherwise whatever `__new__` return type is. | |||
|
599 | overridden_return_type = _eval_return_type( | |||
|
600 | func.__new__, policy, node, context | |||
|
601 | ) | |||
|
602 | if overridden_return_type is not NOT_EVALUATED: | |||
|
603 | return overridden_return_type | |||
|
604 | return _create_duck_for_type(func) | |||
|
605 | else: | |||
|
606 | return_type = _eval_return_type(func, policy, node, context) | |||
|
607 | if return_type is not NOT_EVALUATED: | |||
|
608 | return return_type | |||
|
609 | raise GuardRejection( | |||
|
610 | "Call for", | |||
|
611 | func, # not joined to avoid calling `repr` | |||
|
612 | f"not allowed in {context.evaluation} mode", | |||
|
613 | ) | |||
|
614 | raise ValueError("Unhandled node", ast.dump(node)) | |||
|
615 | ||||
|
616 | ||||
|
617 | def _eval_return_type(func, policy, node, context): | |||
|
618 | """Evaluate return type of a given callable function. | |||
|
619 | ||||
|
620 | Returns the built-in type, a duck or NOT_EVALUATED sentinel. | |||
|
621 | """ | |||
593 |
|
|
622 | try: | |
594 |
|
|
623 | sig = signature(func) | |
595 |
|
|
624 | except ValueError: | |
@@ -599,24 +628,28 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):' | |||||
599 |
|
|
628 | not_empty = sig.return_annotation is not Signature.empty | |
600 |
|
|
629 | not_stringized = not isinstance(sig.return_annotation, str) | |
601 |
|
|
630 | if not_empty and not_stringized: | |
602 | duck = Duck() |
|
|||
603 |
|
|
631 | # if allow-listed builtin is on type annotation, instantiate it | |
604 |
|
|
632 | if policy.can_call(sig.return_annotation) and not node.keywords: | |
605 |
|
|
633 | args = [eval_node(arg, context) for arg in node.args] | |
|
634 | # if custom class is in type annotation, mock it; | |||
606 |
|
|
635 | return sig.return_annotation(*args) | |
|
636 | return _create_duck_for_type(sig.return_annotation) | |||
|
637 | return NOT_EVALUATED | |||
|
638 | ||||
|
639 | ||||
|
640 | def _create_duck_for_type(duck_type): | |||
|
641 | """Create an imitation of an object of a given type (a duck). | |||
|
642 | ||||
|
643 | Returns the duck or NOT_EVALUATED sentinel if duck could not be created. | |||
|
644 | """ | |||
|
645 | duck = Duck() | |||
607 |
|
|
646 | try: | |
608 | # if custom class is in type annotation, mock it; |
|
|||
609 |
|
|
647 | # this only works for heap types, not builtins | |
610 | duck.__class__ = sig.return_annotation |
|
648 | duck.__class__ = duck_type | |
611 |
|
|
649 | return duck | |
612 |
|
|
650 | except TypeError: | |
613 |
|
|
651 | pass | |
614 | raise GuardRejection( |
|
652 | return NOT_EVALUATED | |
615 | "Call for", |
|
|||
616 | func, # not joined to avoid calling `repr` |
|
|||
617 | f"not allowed in {context.evaluation} mode", |
|
|||
618 | ) |
|
|||
619 | raise ValueError("Unhandled node", ast.dump(node)) |
|
|||
620 |
|
653 | |||
621 |
|
654 | |||
622 | SUPPORTED_EXTERNAL_GETITEM = { |
|
655 | SUPPORTED_EXTERNAL_GETITEM = { |
@@ -267,16 +267,34 b' class CallCreatesBuiltin:' | |||||
267 | return frozenset() |
|
267 | return frozenset() | |
268 |
|
268 | |||
269 |
|
269 | |||
|
270 | class HasStaticMethod: | |||
|
271 | @staticmethod | |||
|
272 | def static_method() -> HeapType: | |||
|
273 | return HeapType() | |||
|
274 | ||||
|
275 | ||||
|
276 | class InitReturnsFrozenset: | |||
|
277 | def __new__(self) -> frozenset: # type:ignore[misc] | |||
|
278 | return frozenset() | |||
|
279 | ||||
|
280 | ||||
270 | @pytest.mark.parametrize( |
|
281 | @pytest.mark.parametrize( | |
271 |
"data,good, |
|
282 | "data,good,expected,equality", | |
272 | [ |
|
283 | [ | |
273 |
[[1, 2, 3], "data.index(2)", |
|
284 | [[1, 2, 3], "data.index(2)", 1, True], | |
274 |
[{"a": 1}, "data.keys().isdisjoint({})", |
|
285 | [{"a": 1}, "data.keys().isdisjoint({})", True, True], | |
275 | [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False], |
|
286 | # test cases for `__call__` | |
276 |
[CallCreates |
|
287 | [CallCreatesHeapType(), "data()", HeapType, False], | |
|
288 | [CallCreatesBuiltin(), "data()", frozenset, False], | |||
|
289 | # Test cases for `__init__` | |||
|
290 | [HeapType, "data()", HeapType, False], | |||
|
291 | [InitReturnsFrozenset, "data()", frozenset, False], | |||
|
292 | [HeapType(), "data.__class__()", HeapType, False], | |||
|
293 | # test cases for static and class methods | |||
|
294 | [HasStaticMethod, "data.static_method()", HeapType, False], | |||
277 | ], |
|
295 | ], | |
278 | ) |
|
296 | ) | |
279 |
def test_evaluates_calls(data, good, |
|
297 | def test_evaluates_calls(data, good, expected, equality): | |
280 | context = limited(data=data) |
|
298 | context = limited(data=data) | |
281 | value = guarded_eval(good, context) |
|
299 | value = guarded_eval(good, context) | |
282 | if equality: |
|
300 | if equality: | |
@@ -284,6 +302,17 b' def test_evaluates_calls(data, good, bad, expected, equality):' | |||||
284 | else: |
|
302 | else: | |
285 | assert isinstance(value, expected) |
|
303 | assert isinstance(value, expected) | |
286 |
|
304 | |||
|
305 | ||||
|
306 | @pytest.mark.parametrize( | |||
|
307 | "data,bad", | |||
|
308 | [ | |||
|
309 | [[1, 2, 3], "data.append(4)"], | |||
|
310 | [{"a": 1}, "data.update()"], | |||
|
311 | ], | |||
|
312 | ) | |||
|
313 | def test_rejects_calls_with_side_effects(data, bad): | |||
|
314 | context = limited(data=data) | |||
|
315 | ||||
287 | with pytest.raises(GuardRejection): |
|
316 | with pytest.raises(GuardRejection): | |
288 | guarded_eval(bad, context) |
|
317 | guarded_eval(bad, context) | |
289 |
|
318 |
General Comments 0
You need to be logged in to leave comments.
Login now