##// END OF EJS Templates
Infer type for `__init__` calls (including `__new__` mods)
krassowski -
Show More
@@ -1,4 +1,4 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 3 Any,
4 4 Callable,
@@ -337,6 +337,7 b' class _IdentitySubscript:'
337 337 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 338 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 339 UNKNOWN_SIGNATURE = Signature()
340 NOT_EVALUATED = object()
340 341
341 342
342 343 class GuardRejection(Exception):
@@ -590,6 +591,34 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
590 591 if policy.can_call(func) and not node.keywords:
591 592 args = [eval_node(arg, context) for arg in node.args]
592 593 return func(*args)
594 if isclass(func):
595 # this code path gets entered when calling class e.g. `MyClass()`
596 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
597 # Should return `MyClass` if `__new__` is not overridden,
598 # otherwise whatever `__new__` return type is.
599 overridden_return_type = _eval_return_type(
600 func.__new__, policy, node, context
601 )
602 if overridden_return_type is not NOT_EVALUATED:
603 return overridden_return_type
604 return _create_duck_for_type(func)
605 else:
606 return_type = _eval_return_type(func, policy, node, context)
607 if return_type is not NOT_EVALUATED:
608 return return_type
609 raise GuardRejection(
610 "Call for",
611 func, # not joined to avoid calling `repr`
612 f"not allowed in {context.evaluation} mode",
613 )
614 raise ValueError("Unhandled node", ast.dump(node))
615
616
617 def _eval_return_type(func, policy, node, context):
618 """Evaluate return type of a given callable function.
619
620 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
621 """
593 622 try:
594 623 sig = signature(func)
595 624 except ValueError:
@@ -599,24 +628,28 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
599 628 not_empty = sig.return_annotation is not Signature.empty
600 629 not_stringized = not isinstance(sig.return_annotation, str)
601 630 if not_empty and not_stringized:
602 duck = Duck()
603 631 # if allow-listed builtin is on type annotation, instantiate it
604 632 if policy.can_call(sig.return_annotation) and not node.keywords:
605 633 args = [eval_node(arg, context) for arg in node.args]
634 # if custom class is in type annotation, mock it;
606 635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
637 return NOT_EVALUATED
638
639
640 def _create_duck_for_type(duck_type):
641 """Create an imitation of an object of a given type (a duck).
642
643 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
644 """
645 duck = Duck()
607 646 try:
608 # if custom class is in type annotation, mock it;
609 647 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
648 duck.__class__ = duck_type
611 649 return duck
612 650 except TypeError:
613 651 pass
614 raise GuardRejection(
615 "Call for",
616 func, # not joined to avoid calling `repr`
617 f"not allowed in {context.evaluation} mode",
618 )
619 raise ValueError("Unhandled node", ast.dump(node))
652 return NOT_EVALUATED
620 653
621 654
622 655 SUPPORTED_EXTERNAL_GETITEM = {
@@ -267,16 +267,34 b' class CallCreatesBuiltin:'
267 267 return frozenset()
268 268
269 269
270 class HasStaticMethod:
271 @staticmethod
272 def static_method() -> HeapType:
273 return HeapType()
274
275
276 class InitReturnsFrozenset:
277 def __new__(self) -> frozenset: # type:ignore[misc]
278 return frozenset()
279
280
270 281 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
282 "data,good,expected,equality",
272 283 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
284 [[1, 2, 3], "data.index(2)", 1, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
286 # test cases for `__call__`
287 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
289 # Test cases for `__init__`
290 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
277 295 ],
278 296 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
297 def test_evaluates_calls(data, good, expected, equality):
280 298 context = limited(data=data)
281 299 value = guarded_eval(good, context)
282 300 if equality:
@@ -284,6 +302,17 b' def test_evaluates_calls(data, good, bad, expected, equality):'
284 302 else:
285 303 assert isinstance(value, expected)
286 304
305
306 @pytest.mark.parametrize(
307 "data,bad",
308 [
309 [[1, 2, 3], "data.append(4)"],
310 [{"a": 1}, "data.update()"],
311 ],
312 )
313 def test_rejects_calls_with_side_effects(data, bad):
314 context = limited(data=data)
315
287 316 with pytest.raises(GuardRejection):
288 317 guarded_eval(bad, context)
289 318
General Comments 0
You need to be logged in to leave comments. Login now