##// END OF EJS Templates
Infer type for `__init__` calls (including `__new__` mods)
krassowski -
Show More
@@ -1,4 +1,4 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Any,
3 Any,
4 Callable,
4 Callable,
@@ -337,6 +337,7 b' class _IdentitySubscript:'
337 IDENTITY_SUBSCRIPT = _IdentitySubscript()
337 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
338 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 UNKNOWN_SIGNATURE = Signature()
339 UNKNOWN_SIGNATURE = Signature()
340 NOT_EVALUATED = object()
340
341
341
342
342 class GuardRejection(Exception):
343 class GuardRejection(Exception):
@@ -590,27 +591,21 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
590 if policy.can_call(func) and not node.keywords:
591 if policy.can_call(func) and not node.keywords:
591 args = [eval_node(arg, context) for arg in node.args]
592 args = [eval_node(arg, context) for arg in node.args]
592 return func(*args)
593 return func(*args)
593 try:
594 if isclass(func):
594 sig = signature(func)
595 # this code path gets entered when calling class e.g. `MyClass()`
595 except ValueError:
596 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
596 sig = UNKNOWN_SIGNATURE
597 # Should return `MyClass` if `__new__` is not overridden,
597 # if annotation was not stringized, or it was stringized
598 # otherwise whatever `__new__` return type is.
598 # but resolved by signature call we know the return type
599 overridden_return_type = _eval_return_type(
599 not_empty = sig.return_annotation is not Signature.empty
600 func.__new__, policy, node, context
600 not_stringized = not isinstance(sig.return_annotation, str)
601 )
601 if not_empty and not_stringized:
602 if overridden_return_type is not NOT_EVALUATED:
602 duck = Duck()
603 return overridden_return_type
603 # if allow-listed builtin is on type annotation, instantiate it
604 return _create_duck_for_type(func)
604 if policy.can_call(sig.return_annotation) and not node.keywords:
605 else:
605 args = [eval_node(arg, context) for arg in node.args]
606 return_type = _eval_return_type(func, policy, node, context)
606 return sig.return_annotation(*args)
607 if return_type is not NOT_EVALUATED:
607 try:
608 return return_type
608 # if custom class is in type annotation, mock it;
609 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
611 return duck
612 except TypeError:
613 pass
614 raise GuardRejection(
609 raise GuardRejection(
615 "Call for",
610 "Call for",
616 func, # not joined to avoid calling `repr`
611 func, # not joined to avoid calling `repr`
@@ -619,6 +614,44 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
619 raise ValueError("Unhandled node", ast.dump(node))
614 raise ValueError("Unhandled node", ast.dump(node))
620
615
621
616
617 def _eval_return_type(func, policy, node, context):
618 """Evaluate return type of a given callable function.
619
620 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
621 """
622 try:
623 sig = signature(func)
624 except ValueError:
625 sig = UNKNOWN_SIGNATURE
626 # if annotation was not stringized, or it was stringized
627 # but resolved by signature call we know the return type
628 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
631 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
633 args = [eval_node(arg, context) for arg in node.args]
634 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
637 return NOT_EVALUATED
638
639
640 def _create_duck_for_type(duck_type):
641 """Create an imitation of an object of a given type (a duck).
642
643 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
644 """
645 duck = Duck()
646 try:
647 # this only works for heap types, not builtins
648 duck.__class__ = duck_type
649 return duck
650 except TypeError:
651 pass
652 return NOT_EVALUATED
653
654
622 SUPPORTED_EXTERNAL_GETITEM = {
655 SUPPORTED_EXTERNAL_GETITEM = {
623 ("pandas", "core", "indexing", "_iLocIndexer"),
656 ("pandas", "core", "indexing", "_iLocIndexer"),
624 ("pandas", "core", "indexing", "_LocIndexer"),
657 ("pandas", "core", "indexing", "_LocIndexer"),
@@ -267,16 +267,34 b' class CallCreatesBuiltin:'
267 return frozenset()
267 return frozenset()
268
268
269
269
270 class HasStaticMethod:
271 @staticmethod
272 def static_method() -> HeapType:
273 return HeapType()
274
275
276 class InitReturnsFrozenset:
277 def __new__(self) -> frozenset: # type:ignore[misc]
278 return frozenset()
279
280
270 @pytest.mark.parametrize(
281 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
282 "data,good,expected,equality",
272 [
283 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
284 [[1, 2, 3], "data.index(2)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
286 # test cases for `__call__`
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
287 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
289 # Test cases for `__init__`
290 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
277 ],
295 ],
278 )
296 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
297 def test_evaluates_calls(data, good, expected, equality):
280 context = limited(data=data)
298 context = limited(data=data)
281 value = guarded_eval(good, context)
299 value = guarded_eval(good, context)
282 if equality:
300 if equality:
@@ -284,6 +302,17 b' def test_evaluates_calls(data, good, bad, expected, equality):'
284 else:
302 else:
285 assert isinstance(value, expected)
303 assert isinstance(value, expected)
286
304
305
306 @pytest.mark.parametrize(
307 "data,bad",
308 [
309 [[1, 2, 3], "data.append(4)"],
310 [{"a": 1}, "data.update()"],
311 ],
312 )
313 def test_rejects_calls_with_side_effects(data, bad):
314 context = limited(data=data)
315
287 with pytest.raises(GuardRejection):
316 with pytest.raises(GuardRejection):
288 guarded_eval(bad, context)
317 guarded_eval(bad, context)
289
318
General Comments 0
You need to be logged in to leave comments. Login now