From 7c22387ceef8de64cc1e687cfbbf3d8021c237c6 2024-02-27 10:07:08 From: M Bussonnier Date: 2024-02-27 10:07:08 Subject: [PATCH] Improve inference from return type annotations in completer (#14357) Addresses the issue of non-jedi completer not picking type annotations for `__init__()` brought up in https://github.com/ipython/ipython/issues/14336 ![completion_works](https://github.com/ipython/ipython/assets/5832902/73d44e26-123e-4691-87a6-e4d92c6f5061) Follow-up to https://github.com/ipython/ipython/pull/14185 Supports: - [x] `Annotated` - [x] `AnyStr` - [x] `Literal` - [x] `LiteralString` - [x] `NewType` - [x] `Optional` - [x] `Protocol` - [x] `Self` - [x] `TypeAliasType` (`type` keyword in Python 3.12+) - [x] `TypedDict` - [x] `TypeGuard` - [x] `Union` Limitations: - no type narrowing: ambiguous return types from `Union`, and `Optional` will always return all possible values - generics (`TypeVar` and `Generic`) are not supported (except for `AnyStr`) - old style `TypeAlias` (deprecated in Python 3.12) is not supported --- diff --git a/IPython/core/guarded_eval.py b/IPython/core/guarded_eval.py index a304aff..d8ac992 100644 --- a/IPython/core/guarded_eval.py +++ b/IPython/core/guarded_eval.py @@ -1,16 +1,23 @@ -from inspect import signature, Signature +from inspect import isclass, signature, Signature from typing import ( - Any, + Annotated, + AnyStr, Callable, Dict, + Literal, + NamedTuple, + NewType, + Optional, + Protocol, Set, Sequence, Tuple, - NamedTuple, Type, - Literal, + TypeGuard, Union, - TYPE_CHECKING, + get_args, + get_origin, + is_typeddict, ) import ast import builtins @@ -21,15 +28,18 @@ from functools import cached_property from dataclasses import dataclass, field from types import MethodDescriptorType, ModuleType -from IPython.utils.docs import GENERATING_DOCUMENTATION from IPython.utils.decorators import undoc -if TYPE_CHECKING or GENERATING_DOCUMENTATION: - from typing_extensions import Protocol +if sys.version_info < (3, 11): + from
typing_extensions import Self, LiteralString +else: + from typing import Self, LiteralString + +if sys.version_info < (3, 12): + from typing_extensions import TypeAliasType else: - # do not require on runtime - Protocol = object # requires Python >=3.8 + from typing import TypeAliasType @undoc @@ -337,6 +347,7 @@ class _IdentitySubscript: IDENTITY_SUBSCRIPT = _IdentitySubscript() SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__" UNKNOWN_SIGNATURE = Signature() +NOT_EVALUATED = object() class GuardRejection(Exception): @@ -417,9 +428,40 @@ UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = { } -class Duck: +class ImpersonatingDuck: """A dummy class used to create objects of other classes without calling their ``__init__``""" + # no-op: override __class__ to impersonate + + +class _Duck: + """A dummy class used to create objects pretending to have given attributes""" + + def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None): + self.attributes = attributes or {} + self.items = items or {} + + def __getattr__(self, attr: str): + try: + return self.attributes[attr] + except KeyError: + raise AttributeError(attr) from None + + def __hasattr__(self, attr: str): + return attr in self.attributes + + def __dir__(self): + return [*dir(object), *self.attributes] + + def __getitem__(self, key: str): + return self.items[key] + + def __hasitem__(self, key: str): + return key in self.items + + def _ipython_key_completions_(self): + return self.items.keys() + def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]: dunder = None @@ -557,19 +599,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): f" not allowed in {context.evaluation} mode", ) if isinstance(node, ast.Name): - if policy.allow_locals_access and node.id in context.locals: - return context.locals[node.id] - if policy.allow_globals_access and node.id in context.globals: - return context.globals[node.id] - if policy.allow_builtins_access and hasattr(builtins, node.id): - # note: do not use __builtins__, it is
implementation detail of cPython - return getattr(builtins, node.id) - if not policy.allow_globals_access and not policy.allow_locals_access: - raise GuardRejection( - f"Namespace access not allowed in {context.evaluation} mode" - ) - else: - raise NameError(f"{node.id} not found in locals, globals, nor builtins") + return _eval_node_name(node.id, context) if isinstance(node, ast.Attribute): value = eval_node(node.value, context) if policy.can_get_attr(value, node.attr): @@ -590,27 +620,19 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): if policy.can_call(func) and not node.keywords: args = [eval_node(arg, context) for arg in node.args] return func(*args) - try: - sig = signature(func) - except ValueError: - sig = UNKNOWN_SIGNATURE - # if annotation was not stringized, or it was stringized - # but resolved by signature call we know the return type - not_empty = sig.return_annotation is not Signature.empty - not_stringized = not isinstance(sig.return_annotation, str) - if not_empty and not_stringized: - duck = Duck() - # if allow-listed builtin is on type annotation, instantiate it - if policy.can_call(sig.return_annotation) and not node.keywords: - args = [eval_node(arg, context) for arg in node.args] - return sig.return_annotation(*args) - try: - # if custom class is in type annotation, mock it; - # this only works for heap types, not builtins - duck.__class__ = sig.return_annotation - return duck - except TypeError: - pass + if isclass(func): + # this code path gets entered when calling class e.g. `MyClass()` + # or `my_instance.__class__()` - in both cases `func` is `MyClass`. + # Should return `MyClass` if `__new__` is not overridden, + # otherwise whatever `__new__` return type is.
+ overridden_return_type = _eval_return_type(func.__new__, node, context) + if overridden_return_type is not NOT_EVALUATED: + return overridden_return_type + return _create_duck_for_heap_type(func) + else: + return_type = _eval_return_type(func, node, context) + if return_type is not NOT_EVALUATED: + return return_type raise GuardRejection( "Call for", func, # not joined to avoid calling `repr` @@ -619,6 +641,127 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): raise ValueError("Unhandled node", ast.dump(node)) +def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext): + """Evaluate return type of a given callable function. + + Returns the built-in type, a duck or NOT_EVALUATED sentinel. + """ + try: + sig = signature(func) + except ValueError: + sig = UNKNOWN_SIGNATURE + # string annotations (e.g. forward references) are looked up by name + # in `_resolve_annotation`; other annotations are resolved as-is + not_empty = sig.return_annotation is not Signature.empty + if not_empty: + return _resolve_annotation(sig.return_annotation, sig, func, node, context) + return NOT_EVALUATED + + +def _resolve_annotation( + annotation, + sig: Signature, + func: Callable, + node: ast.Call, + context: EvaluationContext, +): + """Resolve annotation created by user with `typing` module and custom objects.""" + annotation = ( + _eval_node_name(annotation, context) + if isinstance(annotation, str) + else annotation + ) + origin = get_origin(annotation) + if annotation is Self and hasattr(func, "__self__"): + return func.__self__ + elif origin is Literal: + type_args = get_args(annotation) + if len(type_args) == 1: + return type_args[0] + elif annotation is LiteralString: + return "" + elif annotation is AnyStr: + index = None + for i, (key, value) in enumerate(sig.parameters.items()): + if value.annotation is AnyStr: + index = i + break + if index is not None and index < len(node.args): + return eval_node(node.args[index], context) + elif origin is
TypeGuard: + return bool() + elif origin is Union: + attributes = [ + attr + for type_arg in get_args(annotation) + for attr in dir(_resolve_annotation(type_arg, sig, func, node, context)) + ] + return _Duck(attributes=dict.fromkeys(attributes)) + elif is_typeddict(annotation): + return _Duck( + attributes=dict.fromkeys(dir(dict())), + items={ + k: _resolve_annotation(v, sig, func, node, context) + for k, v in annotation.__annotations__.items() + }, + ) + elif getattr(annotation, "_is_protocol", False): + return _Duck(attributes=dict.fromkeys(dir(annotation))) + elif origin is Annotated: + type_arg = get_args(annotation)[0] + return _resolve_annotation(type_arg, sig, func, node, context) + elif isinstance(annotation, NewType): + return _eval_or_create_duck(annotation.__supertype__, node, context) + elif isinstance(annotation, TypeAliasType): + return _eval_or_create_duck(annotation.__value__, node, context) + else: + return _eval_or_create_duck(annotation, node, context) + # reached only when a branch above could not resolve the annotation + return NOT_EVALUATED + + +def _eval_node_name(node_id: str, context: EvaluationContext): + policy = EVALUATION_POLICIES[context.evaluation] + if policy.allow_locals_access and node_id in context.locals: + return context.locals[node_id] + if policy.allow_globals_access and node_id in context.globals: + return context.globals[node_id] + if policy.allow_builtins_access and hasattr(builtins, node_id): + # note: do not use __builtins__, it is implementation detail of cPython + return getattr(builtins, node_id) + if not policy.allow_globals_access and not policy.allow_locals_access: + raise GuardRejection( + f"Namespace access not allowed in {context.evaluation} mode" + ) + else: + raise NameError(f"{node_id} not found in locals, globals, nor builtins") + + +def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext): + policy = EVALUATION_POLICIES[context.evaluation] + # if allow-listed builtin is on type annotation, instantiate it + if policy.can_call(duck_type) and not node.keywords: + args = [eval_node(arg,
context) for arg in node.args] + return duck_type(*args) + # if custom class is in type annotation, mock it + return _create_duck_for_heap_type(duck_type) + + +def _create_duck_for_heap_type(duck_type): + """Create an imitation of an object of a given type (a duck). + + Returns the duck or NOT_EVALUATED sentinel if duck could not be created. + """ + duck = ImpersonatingDuck() + try: + # this only works for heap types, not builtins + duck.__class__ = duck_type + return duck + except TypeError: + pass + return NOT_EVALUATED + + SUPPORTED_EXTERNAL_GETITEM = { ("pandas", "core", "indexing", "_iLocIndexer"), ("pandas", "core", "indexing", "_LocIndexer"), diff --git a/IPython/core/tests/test_guarded_eval.py b/IPython/core/tests/test_guarded_eval.py index 13f9091..f9057d8 100644 --- a/IPython/core/tests/test_guarded_eval.py +++ b/IPython/core/tests/test_guarded_eval.py @@ -1,5 +1,17 @@ +import sys from contextlib import contextmanager -from typing import NamedTuple +from typing import ( + Annotated, + AnyStr, + NamedTuple, + Literal, + NewType, + Optional, + Protocol, + TypeGuard, + Union, + TypedDict, +) from functools import partial from IPython.core.guarded_eval import ( EvaluationContext, @@ -11,6 +23,17 @@ from IPython.testing import decorators as dec import pytest +if sys.version_info < (3, 11): + from typing_extensions import Self, LiteralString +else: + from typing import Self, LiteralString + +if sys.version_info < (3, 12): + from typing_extensions import TypeAliasType +else: + from typing import TypeAliasType + + def create_context(evaluation: str, **kwargs): return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation) @@ -267,23 +290,185 @@ class CallCreatesBuiltin: return frozenset() +class HasStaticMethod: + @staticmethod + def static_method() -> HeapType: + return HeapType() + + +class InitReturnsFrozenset: + def __new__(self) -> frozenset: # type:ignore[misc] + return frozenset() + + +class StringAnnotation: + def heap(self) -> "HeapType": +
return HeapType() + + def copy(self) -> "StringAnnotation": + return StringAnnotation() + + +CustomIntType = NewType("CustomIntType", int) +CustomHeapType = NewType("CustomHeapType", HeapType) +IntTypeAlias = TypeAliasType("IntTypeAlias", int) +HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType) + + +class TestProtocol(Protocol): + def test_method(self) -> bool: + pass + + +class TestProtocolImplementer(TestProtocol): + def test_method(self) -> bool: + return True + + +class Movie(TypedDict): + name: str + year: int + + +class SpecialTyping: + def custom_int_type(self) -> CustomIntType: + return CustomIntType(1) + + def custom_heap_type(self) -> CustomHeapType: + return CustomHeapType(HeapType()) + + # TODO: remove type:ignore comment once mypy + # supports explicit calls to `TypeAliasType`, see: + # https://github.com/python/mypy/issues/16614 + def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type] + return 1 + + def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type] + return HeapType() + + def literal(self) -> Literal[False]: + return False + + def literal_string(self) -> LiteralString: + return "test" + + def self(self) -> Self: + return self + + def any_str(self, x: AnyStr) -> AnyStr: + return x + + def annotated(self) -> Annotated[float, "positive number"]: + return 1 + + def annotated_self(self) -> Annotated[Self, "self with metadata"]: + self._metadata = "test" + return self + + def int_type_guard(self, x) -> TypeGuard[int]: + return isinstance(x, int) + + def optional_float(self) -> Optional[float]: + return 1.0 + + def union_str_and_int(self) -> Union[str, int]: + return "" + + def protocol(self) -> TestProtocol: + return TestProtocolImplementer() + + def typed_dict(self) -> Movie: + return {"name": "The Matrix", "year": 1999} + + @pytest.mark.parametrize( - "data,good,bad,expected, equality", + "data,code,expected,equality", [ - [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True], - [{"a": 1}, "data.keys().isdisjoint({})",
"data.update()", True, True], - [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False], - [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False], + [[1, 2, 3], "data.index(2)", 1, True], + [{"a": 1}, "data.keys().isdisjoint({})", True, True], + [StringAnnotation(), "data.heap()", HeapType, False], + [StringAnnotation(), "data.copy()", StringAnnotation, False], + # test cases for `__call__` + [CallCreatesHeapType(), "data()", HeapType, False], + [CallCreatesBuiltin(), "data()", frozenset, False], + # Test cases for `__init__` + [HeapType, "data()", HeapType, False], + [InitReturnsFrozenset, "data()", frozenset, False], + [HeapType(), "data.__class__()", HeapType, False], + # supported special cases for typing + [SpecialTyping(), "data.custom_int_type()", int, False], + [SpecialTyping(), "data.custom_heap_type()", HeapType, False], + [SpecialTyping(), "data.int_type_alias()", int, False], + [SpecialTyping(), "data.heap_type_alias()", HeapType, False], + [SpecialTyping(), "data.self()", SpecialTyping, False], + [SpecialTyping(), "data.literal()", False, True], + [SpecialTyping(), "data.literal_string()", str, False], + [SpecialTyping(), "data.any_str('a')", str, False], + [SpecialTyping(), "data.any_str(b'a')", bytes, False], + [SpecialTyping(), "data.annotated()", float, False], + [SpecialTyping(), "data.annotated_self()", SpecialTyping, False], + [SpecialTyping(), "data.int_type_guard()", int, False], + # test cases for static methods + [HasStaticMethod, "data.static_method()", HeapType, False], ], ) -def test_evaluates_calls(data, good, bad, expected, equality): - context = limited(data=data) - value = guarded_eval(good, context) +def test_evaluates_calls(data, code, expected, equality): + context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation) + value = guarded_eval(code, context) if equality: assert value == expected else: assert isinstance(value, expected) + +@pytest.mark.parametrize( +
"data,code,expected_attributes", + [ + [SpecialTyping(), "data.optional_float()", ["is_integer"]], + [ + SpecialTyping(), + "data.union_str_and_int()", + ["capitalize", "as_integer_ratio"], + ], + [SpecialTyping(), "data.protocol()", ["test_method"]], + [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]], + ], +) +def test_mocks_attributes_of_call_results(data, code, expected_attributes): + context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation) + result = guarded_eval(code, context) + for attr in expected_attributes: + assert hasattr(result, attr) + assert attr in dir(result) + # a name missing from the annotation must raise AttributeError, not KeyError + assert not hasattr(result, "attribute_absent_from_annotation") + + +@pytest.mark.parametrize( + "data,code,expected_items", + [ + [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}], + ], +) +def test_mocks_items_of_call_results(data, code, expected_items): + context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation) + result = guarded_eval(code, context) + ipython_keys = result._ipython_key_completions_() + for key, value in expected_items.items(): + assert isinstance(result[key], value) + assert key in ipython_keys + + +@pytest.mark.parametrize( + "data,bad", + [ + [[1, 2, 3], "data.append(4)"], + [{"a": 1}, "data.update()"], + ], +) +def test_rejects_calls_with_side_effects(data, bad): + context = limited(data=data) + with pytest.raises(GuardRejection): guarded_eval(bad, context) diff --git a/pyproject.toml b/pyproject.toml index 6dc1389..151c61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "pygments>=2.4.0", "stack_data", "traitlets>=5.13.0", - "typing_extensions; python_version<'3.10'", + "typing_extensions; python_version<'3.12'", ] dynamic = ["authors", "license", "version"]