##// END OF EJS Templates
Improve inference from return type annotations in completer (#14357)...
M Bussonnier -
r28682:7c22387c merge
parent child Browse files
Show More
@@ -1,16 +1,23 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Any,
3 Annotated,
4 AnyStr,
4 Callable,
5 Callable,
5 Dict,
6 Dict,
7 Literal,
8 NamedTuple,
9 NewType,
10 Optional,
11 Protocol,
6 Set,
12 Set,
7 Sequence,
13 Sequence,
8 Tuple,
14 Tuple,
9 NamedTuple,
10 Type,
15 Type,
11 Literal,
16 TypeGuard,
12 Union,
17 Union,
13 TYPE_CHECKING,
18 get_args,
19 get_origin,
20 is_typeddict,
14 )
21 )
15 import ast
22 import ast
16 import builtins
23 import builtins
@@ -21,15 +28,18 b' from functools import cached_property'
21 from dataclasses import dataclass, field
28 from dataclasses import dataclass, field
22 from types import MethodDescriptorType, ModuleType
29 from types import MethodDescriptorType, ModuleType
23
30
24 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 from IPython.utils.decorators import undoc
31 from IPython.utils.decorators import undoc
26
32
27
33
28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
34 if sys.version_info < (3, 11):
29 from typing_extensions import Protocol
35 from typing_extensions import Self, LiteralString
36 else:
37 from typing import Self, LiteralString
38
39 if sys.version_info < (3, 12):
40 from typing_extensions import TypeAliasType
30 else:
41 else:
31 # do not require on runtime
42 from typing import TypeAliasType
32 Protocol = object # requires Python >=3.8
33
43
34
44
35 @undoc
45 @undoc
@@ -337,6 +347,7 b' class _IdentitySubscript:'
337 IDENTITY_SUBSCRIPT = _IdentitySubscript()
347 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
348 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 UNKNOWN_SIGNATURE = Signature()
349 UNKNOWN_SIGNATURE = Signature()
350 NOT_EVALUATED = object()
340
351
341
352
342 class GuardRejection(Exception):
353 class GuardRejection(Exception):
@@ -417,9 +428,37 b' UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {'
417 }
428 }
418
429
419
430
420 class Duck:
431 class ImpersonatingDuck:
421 """A dummy class used to create objects of other classes without calling their ``__init__``"""
432 """A dummy class used to create objects of other classes without calling their ``__init__``"""
422
433
434 # no-op: override __class__ to impersonate
435
436
437 class _Duck:
438 """A dummy class used to create objects pretending to have given attributes"""
439
440 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
441 self.attributes = attributes or {}
442 self.items = items or {}
443
444 def __getattr__(self, attr: str):
445 return self.attributes[attr]
446
447 def __hasattr__(self, attr: str):
448 return attr in self.attributes
449
450 def __dir__(self):
451 return [*dir(super), *self.attributes]
452
453 def __getitem__(self, key: str):
454 return self.items[key]
455
456 def __hasitem__(self, key: str):
457 return self.items[key]
458
459 def _ipython_key_completions_(self):
460 return self.items.keys()
461
423
462
424 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
463 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 dunder = None
464 dunder = None
@@ -557,19 +596,7 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
557 f" not allowed in {context.evaluation} mode",
596 f" not allowed in {context.evaluation} mode",
558 )
597 )
559 if isinstance(node, ast.Name):
598 if isinstance(node, ast.Name):
560 if policy.allow_locals_access and node.id in context.locals:
599 return _eval_node_name(node.id, context)
561 return context.locals[node.id]
562 if policy.allow_globals_access and node.id in context.globals:
563 return context.globals[node.id]
564 if policy.allow_builtins_access and hasattr(builtins, node.id):
565 # note: do not use __builtins__, it is implementation detail of cPython
566 return getattr(builtins, node.id)
567 if not policy.allow_globals_access and not policy.allow_locals_access:
568 raise GuardRejection(
569 f"Namespace access not allowed in {context.evaluation} mode"
570 )
571 else:
572 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
573 if isinstance(node, ast.Attribute):
600 if isinstance(node, ast.Attribute):
574 value = eval_node(node.value, context)
601 value = eval_node(node.value, context)
575 if policy.can_get_attr(value, node.attr):
602 if policy.can_get_attr(value, node.attr):
@@ -590,27 +617,19 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
590 if policy.can_call(func) and not node.keywords:
617 if policy.can_call(func) and not node.keywords:
591 args = [eval_node(arg, context) for arg in node.args]
618 args = [eval_node(arg, context) for arg in node.args]
592 return func(*args)
619 return func(*args)
593 try:
620 if isclass(func):
594 sig = signature(func)
621 # this code path gets entered when calling class e.g. `MyClass()`
595 except ValueError:
622 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
596 sig = UNKNOWN_SIGNATURE
623 # Should return `MyClass` if `__new__` is not overridden,
597 # if annotation was not stringized, or it was stringized
624 # otherwise whatever `__new__` return type is.
598 # but resolved by signature call we know the return type
625 overridden_return_type = _eval_return_type(func.__new__, node, context)
599 not_empty = sig.return_annotation is not Signature.empty
626 if overridden_return_type is not NOT_EVALUATED:
600 not_stringized = not isinstance(sig.return_annotation, str)
627 return overridden_return_type
601 if not_empty and not_stringized:
628 return _create_duck_for_heap_type(func)
602 duck = Duck()
629 else:
603 # if allow-listed builtin is on type annotation, instantiate it
630 return_type = _eval_return_type(func, node, context)
604 if policy.can_call(sig.return_annotation) and not node.keywords:
631 if return_type is not NOT_EVALUATED:
605 args = [eval_node(arg, context) for arg in node.args]
632 return return_type
606 return sig.return_annotation(*args)
607 try:
608 # if custom class is in type annotation, mock it;
609 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
611 return duck
612 except TypeError:
613 pass
614 raise GuardRejection(
633 raise GuardRejection(
615 "Call for",
634 "Call for",
616 func, # not joined to avoid calling `repr`
635 func, # not joined to avoid calling `repr`
@@ -619,6 +638,125 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
619 raise ValueError("Unhandled node", ast.dump(node))
638 raise ValueError("Unhandled node", ast.dump(node))
620
639
621
640
641 def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
642 """Evaluate return type of a given callable function.
643
644 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
645 """
646 try:
647 sig = signature(func)
648 except ValueError:
649 sig = UNKNOWN_SIGNATURE
650 # if annotation was not stringized, or it was stringized
651 # but resolved by signature call we know the return type
652 not_empty = sig.return_annotation is not Signature.empty
653 if not_empty:
654 return _resolve_annotation(sig.return_annotation, sig, func, node, context)
655 return NOT_EVALUATED
656
657
658 def _resolve_annotation(
659 annotation,
660 sig: Signature,
661 func: Callable,
662 node: ast.Call,
663 context: EvaluationContext,
664 ):
665 """Resolve annotation created by user with `typing` module and custom objects."""
666 annotation = (
667 _eval_node_name(annotation, context)
668 if isinstance(annotation, str)
669 else annotation
670 )
671 origin = get_origin(annotation)
672 if annotation is Self and hasattr(func, "__self__"):
673 return func.__self__
674 elif origin is Literal:
675 type_args = get_args(annotation)
676 if len(type_args) == 1:
677 return type_args[0]
678 elif annotation is LiteralString:
679 return ""
680 elif annotation is AnyStr:
681 index = None
682 for i, (key, value) in enumerate(sig.parameters.items()):
683 if value.annotation is AnyStr:
684 index = i
685 break
686 if index is not None and index < len(node.args):
687 return eval_node(node.args[index], context)
688 elif origin is TypeGuard:
689 return bool()
690 elif origin is Union:
691 attributes = [
692 attr
693 for type_arg in get_args(annotation)
694 for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
695 ]
696 return _Duck(attributes=dict.fromkeys(attributes))
697 elif is_typeddict(annotation):
698 return _Duck(
699 attributes=dict.fromkeys(dir(dict())),
700 items={
701 k: _resolve_annotation(v, sig, func, node, context)
702 for k, v in annotation.__annotations__.items()
703 },
704 )
705 elif hasattr(annotation, "_is_protocol"):
706 return _Duck(attributes=dict.fromkeys(dir(annotation)))
707 elif origin is Annotated:
708 type_arg = get_args(annotation)[0]
709 return _resolve_annotation(type_arg, sig, func, node, context)
710 elif isinstance(annotation, NewType):
711 return _eval_or_create_duck(annotation.__supertype__, node, context)
712 elif isinstance(annotation, TypeAliasType):
713 return _eval_or_create_duck(annotation.__value__, node, context)
714 else:
715 return _eval_or_create_duck(annotation, node, context)
716
717
718 def _eval_node_name(node_id: str, context: EvaluationContext):
719 policy = EVALUATION_POLICIES[context.evaluation]
720 if policy.allow_locals_access and node_id in context.locals:
721 return context.locals[node_id]
722 if policy.allow_globals_access and node_id in context.globals:
723 return context.globals[node_id]
724 if policy.allow_builtins_access and hasattr(builtins, node_id):
725 # note: do not use __builtins__, it is implementation detail of cPython
726 return getattr(builtins, node_id)
727 if not policy.allow_globals_access and not policy.allow_locals_access:
728 raise GuardRejection(
729 f"Namespace access not allowed in {context.evaluation} mode"
730 )
731 else:
732 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
733
734
735 def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
736 policy = EVALUATION_POLICIES[context.evaluation]
737 # if allow-listed builtin is on type annotation, instantiate it
738 if policy.can_call(duck_type) and not node.keywords:
739 args = [eval_node(arg, context) for arg in node.args]
740 return duck_type(*args)
741 # if custom class is in type annotation, mock it
742 return _create_duck_for_heap_type(duck_type)
743
744
745 def _create_duck_for_heap_type(duck_type):
746 """Create an imitation of an object of a given type (a duck).
747
748 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
749 """
750 duck = ImpersonatingDuck()
751 try:
752 # this only works for heap types, not builtins
753 duck.__class__ = duck_type
754 return duck
755 except TypeError:
756 pass
757 return NOT_EVALUATED
758
759
622 SUPPORTED_EXTERNAL_GETITEM = {
760 SUPPORTED_EXTERNAL_GETITEM = {
623 ("pandas", "core", "indexing", "_iLocIndexer"),
761 ("pandas", "core", "indexing", "_iLocIndexer"),
624 ("pandas", "core", "indexing", "_LocIndexer"),
762 ("pandas", "core", "indexing", "_LocIndexer"),
@@ -1,5 +1,17 b''
1 import sys
1 from contextlib import contextmanager
2 from contextlib import contextmanager
2 from typing import NamedTuple
3 from typing import (
4 Annotated,
5 AnyStr,
6 NamedTuple,
7 Literal,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 TypedDict,
14 )
3 from functools import partial
15 from functools import partial
4 from IPython.core.guarded_eval import (
16 from IPython.core.guarded_eval import (
5 EvaluationContext,
17 EvaluationContext,
@@ -11,6 +23,17 b' from IPython.testing import decorators as dec'
11 import pytest
23 import pytest
12
24
13
25
26 if sys.version_info < (3, 11):
27 from typing_extensions import Self, LiteralString
28 else:
29 from typing import Self, LiteralString
30
31 if sys.version_info < (3, 12):
32 from typing_extensions import TypeAliasType
33 else:
34 from typing import TypeAliasType
35
36
14 def create_context(evaluation: str, **kwargs):
37 def create_context(evaluation: str, **kwargs):
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
38 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16
39
@@ -267,23 +290,183 b' class CallCreatesBuiltin:'
267 return frozenset()
290 return frozenset()
268
291
269
292
293 class HasStaticMethod:
294 @staticmethod
295 def static_method() -> HeapType:
296 return HeapType()
297
298
299 class InitReturnsFrozenset:
300 def __new__(self) -> frozenset: # type:ignore[misc]
301 return frozenset()
302
303
304 class StringAnnotation:
305 def heap(self) -> "HeapType":
306 return HeapType()
307
308 def copy(self) -> "StringAnnotation":
309 return StringAnnotation()
310
311
312 CustomIntType = NewType("CustomIntType", int)
313 CustomHeapType = NewType("CustomHeapType", HeapType)
314 IntTypeAlias = TypeAliasType("IntTypeAlias", int)
315 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
316
317
318 class TestProtocol(Protocol):
319 def test_method(self) -> bool:
320 pass
321
322
323 class TestProtocolImplementer(TestProtocol):
324 def test_method(self) -> bool:
325 return True
326
327
328 class Movie(TypedDict):
329 name: str
330 year: int
331
332
333 class SpecialTyping:
334 def custom_int_type(self) -> CustomIntType:
335 return CustomIntType(1)
336
337 def custom_heap_type(self) -> CustomHeapType:
338 return CustomHeapType(HeapType())
339
340 # TODO: remove type:ignore comment once mypy
341 # supports explicit calls to `TypeAliasType`, see:
342 # https://github.com/python/mypy/issues/16614
343 def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type]
344 return 1
345
346 def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type]
347 return 1
348
349 def literal(self) -> Literal[False]:
350 return False
351
352 def literal_string(self) -> LiteralString:
353 return "test"
354
355 def self(self) -> Self:
356 return self
357
358 def any_str(self, x: AnyStr) -> AnyStr:
359 return x
360
361 def annotated(self) -> Annotated[float, "positive number"]:
362 return 1
363
364 def annotated_self(self) -> Annotated[Self, "self with metadata"]:
365 self._metadata = "test"
366 return self
367
368 def int_type_guard(self, x) -> TypeGuard[int]:
369 return isinstance(x, int)
370
371 def optional_float(self) -> Optional[float]:
372 return 1.0
373
374 def union_str_and_int(self) -> Union[str, int]:
375 return ""
376
377 def protocol(self) -> TestProtocol:
378 return TestProtocolImplementer()
379
380 def typed_dict(self) -> Movie:
381 return {"name": "The Matrix", "year": 1999}
382
383
270 @pytest.mark.parametrize(
384 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
385 "data,code,expected,equality",
272 [
386 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
387 [[1, 2, 3], "data.index(2)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
388 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
389 [StringAnnotation(), "data.heap()", HeapType, False],
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
390 [StringAnnotation(), "data.copy()", StringAnnotation, False],
391 # test cases for `__call__`
392 [CallCreatesHeapType(), "data()", HeapType, False],
393 [CallCreatesBuiltin(), "data()", frozenset, False],
394 # Test cases for `__init__`
395 [HeapType, "data()", HeapType, False],
396 [InitReturnsFrozenset, "data()", frozenset, False],
397 [HeapType(), "data.__class__()", HeapType, False],
398 # supported special cases for typing
399 [SpecialTyping(), "data.custom_int_type()", int, False],
400 [SpecialTyping(), "data.custom_heap_type()", HeapType, False],
401 [SpecialTyping(), "data.int_type_alias()", int, False],
402 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
403 [SpecialTyping(), "data.self()", SpecialTyping, False],
404 [SpecialTyping(), "data.literal()", False, True],
405 [SpecialTyping(), "data.literal_string()", str, False],
406 [SpecialTyping(), "data.any_str('a')", str, False],
407 [SpecialTyping(), "data.any_str(b'a')", bytes, False],
408 [SpecialTyping(), "data.annotated()", float, False],
409 [SpecialTyping(), "data.annotated_self()", SpecialTyping, False],
410 [SpecialTyping(), "data.int_type_guard()", int, False],
411 # test cases for static methods
412 [HasStaticMethod, "data.static_method()", HeapType, False],
277 ],
413 ],
278 )
414 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
415 def test_evaluates_calls(data, code, expected, equality):
280 context = limited(data=data)
416 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
281 value = guarded_eval(good, context)
417 value = guarded_eval(code, context)
282 if equality:
418 if equality:
283 assert value == expected
419 assert value == expected
284 else:
420 else:
285 assert isinstance(value, expected)
421 assert isinstance(value, expected)
286
422
423
424 @pytest.mark.parametrize(
425 "data,code,expected_attributes",
426 [
427 [SpecialTyping(), "data.optional_float()", ["is_integer"]],
428 [
429 SpecialTyping(),
430 "data.union_str_and_int()",
431 ["capitalize", "as_integer_ratio"],
432 ],
433 [SpecialTyping(), "data.protocol()", ["test_method"]],
434 [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]],
435 ],
436 )
437 def test_mocks_attributes_of_call_results(data, code, expected_attributes):
438 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
439 result = guarded_eval(code, context)
440 for attr in expected_attributes:
441 assert hasattr(result, attr)
442 assert attr in dir(result)
443
444
445 @pytest.mark.parametrize(
446 "data,code,expected_items",
447 [
448 [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}],
449 ],
450 )
451 def test_mocks_items_of_call_results(data, code, expected_items):
452 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
453 result = guarded_eval(code, context)
454 ipython_keys = result._ipython_key_completions_()
455 for key, value in expected_items.items():
456 assert isinstance(result[key], value)
457 assert key in ipython_keys
458
459
460 @pytest.mark.parametrize(
461 "data,bad",
462 [
463 [[1, 2, 3], "data.append(4)"],
464 [{"a": 1}, "data.update()"],
465 ],
466 )
467 def test_rejects_calls_with_side_effects(data, bad):
468 context = limited(data=data)
469
287 with pytest.raises(GuardRejection):
470 with pytest.raises(GuardRejection):
288 guarded_eval(bad, context)
471 guarded_eval(bad, context)
289
472
@@ -32,7 +32,7 b' dependencies = ['
32 "pygments>=2.4.0",
32 "pygments>=2.4.0",
33 "stack_data",
33 "stack_data",
34 "traitlets>=5.13.0",
34 "traitlets>=5.13.0",
35 "typing_extensions; python_version<'3.10'",
35 "typing_extensions; python_version<'3.12'",
36 ]
36 ]
37 dynamic = ["authors", "license", "version"]
37 dynamic = ["authors", "license", "version"]
38
38
General Comments 0
You need to be logged in to leave comments. Login now