From 55e79ef42b5f74b76617c96c788642208d59e26b 2024-02-26 21:58:26
From: krassowski <5832902+krassowski@users.noreply.github.com>
Date: 2024-02-26 21:58:26
Subject: [PATCH] Implement remaining special `typing` wrappers

---
diff --git a/IPython/core/guarded_eval.py b/IPython/core/guarded_eval.py
index 4ae5a41..d8ac992 100644
--- a/IPython/core/guarded_eval.py
+++ b/IPython/core/guarded_eval.py
@@ -1,18 +1,23 @@
 from inspect import isclass, signature, Signature
 from typing import (
+    Annotated,
+    AnyStr,
     Callable,
     Dict,
     Literal,
     NamedTuple,
     NewType,
+    Optional,
+    Protocol,
     Set,
     Sequence,
     Tuple,
     Type,
-    Protocol,
+    TypeGuard,
     Union,
     get_args,
     get_origin,
+    is_typeddict,
 )
 import ast
 import builtins
@@ -27,9 +32,9 @@ from IPython.utils.decorators import undoc
 
 
 if sys.version_info < (3, 11):
-    from typing_extensions import Self
+    from typing_extensions import Self, LiteralString
 else:
-    from typing import Self
+    from typing import Self, LiteralString
 
 if sys.version_info < (3, 12):
     from typing_extensions import TypeAliasType
@@ -423,9 +428,47 @@ UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
 }
 
 
-class Duck:
+class ImpersonatingDuck:
     """A dummy class used to create objects of other classes without calling their ``__init__``"""
 
+    # no-op: override __class__ to impersonate
+
+
+class _Duck:
+    """A dummy class used to create objects pretending to have given attributes.
+
+    ``attributes`` maps attribute names to the values returned on access, and
+    ``items`` maps keys to the values returned by subscription. Used to stand
+    in for values known only from type annotations (e.g. unions, protocols and
+    ``TypedDict`` types).
+    """
+
+    def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
+        self.attributes = attributes or {}
+        self.items = items or {}
+
+    def __getattr__(self, attr: str):
+        # Only reached when normal lookup fails. Read via ``__dict__`` so that a
+        # half-initialised instance (e.g. during ``copy``) cannot recurse, and
+        # raise AttributeError (not KeyError) so ``hasattr``/``getattr`` work.
+        try:
+            return self.__dict__["attributes"][attr]
+        except KeyError:
+            raise AttributeError(attr) from None
+
+    def __dir__(self):
+        # Real attributes of this instance plus the mocked ones.
+        return [*object.__dir__(self), *self.attributes]
+
+    def __getitem__(self, key: str):
+        return self.items[key]
+
+    def __contains__(self, key) -> bool:
+        return key in self.items
+
+    def _ipython_key_completions_(self):
+        return self.items.keys()
+
 
 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
     dunder = None
@@ -617,28 +660,80 @@ def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext
     # if annotation was not stringized, or it was stringized
     # but resolved by signature call we know the return type
     not_empty = sig.return_annotation is not Signature.empty
-    stringized = isinstance(sig.return_annotation, str)
     if not_empty:
-        return_type = (
-            _eval_node_name(sig.return_annotation, context)
-            if stringized
-            else sig.return_annotation
-        )
-        if return_type is Self and hasattr(func, "__self__"):
-            return func.__self__
-        elif get_origin(return_type) is Literal:
-            type_args = get_args(return_type)
-            if len(type_args) == 1:
-                return type_args[0]
-        elif isinstance(return_type, NewType):
-            return _eval_or_create_duck(return_type.__supertype__, node, context)
-        elif isinstance(return_type, TypeAliasType):
-            return _eval_or_create_duck(return_type.__value__, node, context)
-        else:
-            return _eval_or_create_duck(return_type, node, context)
+        return _resolve_annotation(sig.return_annotation, sig, func, node, context)
     return NOT_EVALUATED
 
 
+def _resolve_annotation(
+    annotation,
+    sig: Signature,
+    func: Callable,
+    node: ast.Call,
+    context: EvaluationContext,
+):
+    """Resolve annotation created by user with `typing` module and custom objects.
+
+    Returns a value (or a duck standing in for one) matching ``annotation``,
+    or ``NOT_EVALUATED`` when no value can be inferred from it.
+    """
+    annotation = (
+        _eval_node_name(annotation, context)
+        if isinstance(annotation, str)
+        else annotation
+    )
+    origin = get_origin(annotation)
+    if annotation is Self and hasattr(func, "__self__"):
+        return func.__self__
+    elif origin is Literal:
+        type_args = get_args(annotation)
+        if len(type_args) == 1:
+            return type_args[0]
+    elif annotation is LiteralString:
+        return ""
+    elif annotation is AnyStr:
+        index = None
+        for i, param in enumerate(sig.parameters.values()):
+            if param.annotation is AnyStr:
+                index = i
+                break
+        if index is not None and index < len(node.args):
+            return eval_node(node.args[index], context)
+    elif origin is TypeGuard:
+        return bool()
+    elif origin is Union:
+        attributes = [
+            attr
+            for type_arg in get_args(annotation)
+            for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
+        ]
+        return _Duck(attributes=dict.fromkeys(attributes))
+    elif is_typeddict(annotation):
+        return _Duck(
+            attributes=dict.fromkeys(dir(dict())),
+            items={
+                k: _resolve_annotation(v, sig, func, node, context)
+                for k, v in annotation.__annotations__.items()
+            },
+        )
+    elif getattr(annotation, "_is_protocol", False):
+        # Only actual protocols; concrete subclasses of a protocol also carry
+        # ``_is_protocol`` (set to False) and should be duck-typed normally.
+        return _Duck(attributes=dict.fromkeys(dir(annotation)))
+    elif origin is Annotated:
+        type_arg = get_args(annotation)[0]
+        return _resolve_annotation(type_arg, sig, func, node, context)
+    elif isinstance(annotation, NewType):
+        return _eval_or_create_duck(annotation.__supertype__, node, context)
+    elif isinstance(annotation, TypeAliasType):
+        return _eval_or_create_duck(annotation.__value__, node, context)
+    else:
+        return _eval_or_create_duck(annotation, node, context)
+    # Branches above that cannot infer a value (e.g. a multi-value ``Literal``
+    # or ``AnyStr`` without a matching argument) fall through to here.
+    return NOT_EVALUATED
+
+
 def _eval_node_name(node_id: str, context: EvaluationContext):
     policy = EVALUATION_POLICIES[context.evaluation]
     if policy.allow_locals_access and node_id in context.locals:
@@ -671,7 +766,7 @@ def _create_duck_for_heap_type(duck_type):
 
     Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
     """
-    duck = Duck()
+    duck = ImpersonatingDuck()
     try:
         # this only works for heap types, not builtins
         duck.__class__ = duck_type
diff --git a/IPython/core/tests/test_guarded_eval.py b/IPython/core/tests/test_guarded_eval.py
index cb45db5..f9057d8 100644
--- a/IPython/core/tests/test_guarded_eval.py
+++ b/IPython/core/tests/test_guarded_eval.py
@@ -1,6 +1,17 @@
 import sys
 from contextlib import contextmanager
-from typing import NamedTuple, Literal, NewType
+from typing import (
+    Annotated,
+    AnyStr,
+    NamedTuple,
+    Literal,
+    NewType,
+    Optional,
+    Protocol,
+    TypeGuard,
+    Union,
+    TypedDict,
+)
 from functools import partial
 from IPython.core.guarded_eval import (
     EvaluationContext,
@@ -13,9 +24,9 @@ import pytest
 
 
 if sys.version_info < (3, 11):
-    from typing_extensions import Self
+    from typing_extensions import Self, LiteralString
 else:
-    from typing import Self
+    from typing import Self, LiteralString
 
 if sys.version_info < (3, 12):
     from typing_extensions import TypeAliasType
@@ -304,6 +315,21 @@ IntTypeAlias = TypeAliasType("IntTypeAlias", int)
 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
 
 
+class CustomProtocol(Protocol):
+    def test_method(self) -> bool:
+        pass
+
+
+class CustomProtocolImplementer(CustomProtocol):
+    def test_method(self) -> bool:
+        return True
+
+
+class Movie(TypedDict):
+    name: str
+    year: int
+
+
 class SpecialTyping:
     def custom_int_type(self) -> CustomIntType:
         return CustomIntType(1)
@@ -323,12 +349,40 @@ class SpecialTyping:
     def literal(self) -> Literal[False]:
         return False
 
+    def literal_string(self) -> LiteralString:
+        return "test"
+
     def self(self) -> Self:
         return self
 
+    def any_str(self, x: AnyStr) -> AnyStr:
+        return x
+
+    def annotated(self) -> Annotated[float, "positive number"]:
+        return 1
+
+    def annotated_self(self) -> Annotated[Self, "self with metadata"]:
+        self._metadata = "test"
+        return self
+
+    def int_type_guard(self, x) -> TypeGuard[int]:
+        return isinstance(x, int)
+
+    def optional_float(self) -> Optional[float]:
+        return 1.0
+
+    def union_str_and_int(self) -> Union[str, int]:
+        return ""
+
+    def protocol(self) -> CustomProtocol:
+        return CustomProtocolImplementer()
+
+    def typed_dict(self) -> Movie:
+        return {"name": "The Matrix", "year": 1999}
+
 
 @pytest.mark.parametrize(
-    "data,good,expected,equality",
+    "data,code,expected,equality",
     [
         [[1, 2, 3], "data.index(2)", 1, True],
         [{"a": 1}, "data.keys().isdisjoint({})", True, True],
@@ -348,13 +402,19 @@ class SpecialTyping:
         [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
         [SpecialTyping(), "data.self()", SpecialTyping, False],
         [SpecialTyping(), "data.literal()", False, True],
+        [SpecialTyping(), "data.literal_string()", str, False],
+        [SpecialTyping(), "data.any_str('a')", str, False],
+        [SpecialTyping(), "data.any_str(b'a')", bytes, False],
+        [SpecialTyping(), "data.annotated()", float, False],
+        [SpecialTyping(), "data.annotated_self()", SpecialTyping, False],
+        [SpecialTyping(), "data.int_type_guard()", bool, False],
         # test cases for static methods
         [HasStaticMethod, "data.static_method()", HeapType, False],
     ],
 )
-def test_evaluates_calls(data, good, expected, equality):
+def test_evaluates_calls(data, code, expected, equality):
     context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
-    value = guarded_eval(good, context)
+    value = guarded_eval(code, context)
     if equality:
         assert value == expected
     else:
@@ -362,6 +422,42 @@ def test_evaluates_calls(data, good, expected, equality):
 
 
 @pytest.mark.parametrize(
+    "data,code,expected_attributes",
+    [
+        [SpecialTyping(), "data.optional_float()", ["is_integer"]],
+        [
+            SpecialTyping(),
+            "data.union_str_and_int()",
+            ["capitalize", "as_integer_ratio"],
+        ],
+        [SpecialTyping(), "data.protocol()", ["test_method"]],
+        [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]],
+    ],
+)
+def test_mocks_attributes_of_call_results(data, code, expected_attributes):
+    context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
+    result = guarded_eval(code, context)
+    for attr in expected_attributes:
+        assert hasattr(result, attr)
+        assert attr in dir(result)
+
+
+@pytest.mark.parametrize(
+    "data,code,expected_items",
+    [
+        [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}],
+    ],
+)
+def test_mocks_items_of_call_results(data, code, expected_items):
+    context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
+    result = guarded_eval(code, context)
+    ipython_keys = result._ipython_key_completions_()
+    for key, value in expected_items.items():
+        assert isinstance(result[key], value)
+        assert key in ipython_keys
+
+
+@pytest.mark.parametrize(
     "data,bad",
     [
         [[1, 2, 3], "data.append(4)"],