diff --git a/IPython/core/guarded_eval.py b/IPython/core/guarded_eval.py index 0e7d4bf..d8ca8ce 100644 --- a/IPython/core/guarded_eval.py +++ b/IPython/core/guarded_eval.py @@ -2,14 +2,21 @@ from inspect import isclass, signature, Signature from typing import ( Callable, Dict, + Literal, + NamedTuple, + NewType, Set, Sequence, Tuple, - NamedTuple, Type, - Literal, + Protocol, Union, - TYPE_CHECKING, + get_args, + get_origin, +) +from typing_extensions import ( + Self, # Python >=3.10 + TypeAliasType, # Python >=3.12 ) import ast import builtins @@ -20,15 +27,8 @@ from functools import cached_property from dataclasses import dataclass, field from types import MethodDescriptorType, ModuleType -from IPython.utils.docs import GENERATING_DOCUMENTATION -from IPython.utils.decorators import undoc - -if TYPE_CHECKING or GENERATING_DOCUMENTATION: - from typing_extensions import Protocol -else: - # do not require on runtime - Protocol = object # requires Python >=3.8 +from IPython.utils.decorators import undoc @undoc @@ -557,7 +557,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): f" not allowed in {context.evaluation} mode", ) if isinstance(node, ast.Name): - return _eval_node_name(node.id, policy, context) + return _eval_node_name(node.id, context) if isinstance(node, ast.Attribute): value = eval_node(node.value, context) if policy.can_get_attr(value, node.attr): @@ -583,14 +583,12 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): # or `my_instance.__class__()` - in both cases `func` is `MyClass`. # Should return `MyClass` if `__new__` is not overridden, # otherwise whatever `__new__` return type is. - overridden_return_type = _eval_return_type( - func.__new__, policy, node, context - ) + overridden_return_type = _eval_return_type(func.__new__, node, context) if overridden_return_type is not NOT_EVALUATED: return overridden_return_type - return _create_duck_for_type(func) + return _create_duck_for_heap_type(func) else: - return_type = _eval_return_type(func, policy, node, context) + return_type = _eval_return_type(func, node, context) if return_type is not NOT_EVALUATED: return return_type raise GuardRejection( @@ -601,9 +599,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext): raise ValueError("Unhandled node", ast.dump(node)) -def _eval_return_type( - func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext -): +def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext): """Evaluate return type of a given callable function. Returns the built-in type, a duck or NOT_EVALUATED sentinel. @@ -618,20 +614,27 @@ def _eval_return_type( stringized = isinstance(sig.return_annotation, str) if not_empty: return_type = ( - _eval_node_name(sig.return_annotation, policy, context) + _eval_node_name(sig.return_annotation, context) if stringized else sig.return_annotation ) - # if allow-listed builtin is on type annotation, instantiate it - if policy.can_call(return_type) and not node.keywords: - args = [eval_node(arg, context) for arg in node.args] - # if custom class is in type annotation, mock it; - return return_type(*args) - return _create_duck_for_type(return_type) + if return_type is Self and hasattr(func, "__self__"): + return func.__self__ + elif get_origin(return_type) is Literal: + type_args = get_args(return_type) + if len(type_args) == 1: + return type_args[0] + elif isinstance(return_type, NewType): + return _eval_or_create_duck(return_type.__supertype__, node, context) + elif isinstance(return_type, TypeAliasType): + return _eval_or_create_duck(return_type.__value__, node, context) + else: + return _eval_or_create_duck(return_type, node, context) return NOT_EVALUATED -def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext): +def _eval_node_name(node_id: str, context: EvaluationContext): + policy = EVALUATION_POLICIES[context.evaluation] if policy.allow_locals_access and node_id in context.locals: return context.locals[node_id] if policy.allow_globals_access and node_id in context.globals: @@ -647,7 +650,17 @@ def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationC raise NameError(f"{node_id} not found in locals, globals, nor builtins") -def _create_duck_for_type(duck_type): +def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext): + policy = EVALUATION_POLICIES[context.evaluation] + # if allow-listed builtin is on type annotation, instantiate it + if policy.can_call(duck_type) and not node.keywords: + args = [eval_node(arg, context) for arg in node.args] + return duck_type(*args) + # if custom class is in type annotation, mock it + return _create_duck_for_heap_type(duck_type) + + +def _create_duck_for_heap_type(duck_type): """Create an imitation of an object of a given type (a duck). Returns the duck or NOT_EVALUATED sentinel if duck could not be created. diff --git a/IPython/core/tests/test_guarded_eval.py b/IPython/core/tests/test_guarded_eval.py index 68bde5d..9d8052f 100644 --- a/IPython/core/tests/test_guarded_eval.py +++ b/IPython/core/tests/test_guarded_eval.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import NamedTuple +from typing import NamedTuple, Literal, NewType from functools import partial from IPython.core.guarded_eval import ( EvaluationContext, @@ -7,6 +7,10 @@ from IPython.core.guarded_eval import ( guarded_eval, _unbind_method, ) +from typing_extensions import ( + Self, # Python >=3.10 + TypeAliasType, # Python >=3.12 +) from IPython.testing import decorators as dec import pytest @@ -286,6 +290,35 @@ class StringAnnotation: return StringAnnotation() +CustomIntType = NewType("CustomIntType", int) +CustomHeapType = NewType("CustomHeapType", HeapType) +IntTypeAlias = TypeAliasType("IntTypeAlias", int) +HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType) + + +class SpecialTyping: + def custom_int_type(self) -> CustomIntType: + return CustomIntType(1) + + def custom_heap_type(self) -> CustomHeapType: + return CustomHeapType(HeapType()) + + # TODO: remove type:ignore comment once mypy + # supports explicit calls to `TypeAliasType`, see: + # https://github.com/python/mypy/issues/16614 + def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type] + return 1 + + def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type] + return 1 + + def literal(self) -> Literal[False]: + return False + + def self(self) -> Self: + return self + + @pytest.mark.parametrize( "data,good,expected,equality", [ @@ -300,6 +333,13 @@ class StringAnnotation: [HeapType, "data()", HeapType, False], [InitReturnsFrozenset, "data()", frozenset, False], [HeapType(), "data.__class__()", HeapType, False], + # supported special cases for typing + [SpecialTyping(), "data.custom_int_type()", int, False], + [SpecialTyping(), "data.custom_heap_type()", HeapType, False], + [SpecialTyping(), "data.int_type_alias()", int, False], + [SpecialTyping(), "data.heap_type_alias()", HeapType, False], + [SpecialTyping(), "data.self()", SpecialTyping, False], + [SpecialTyping(), "data.literal()", False, True], # test cases for static methods [HasStaticMethod, "data.static_method()", HeapType, False], ], diff --git a/pyproject.toml b/pyproject.toml index 6dc1389..151c61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "pygments>=2.4.0", "stack_data", "traitlets>=5.13.0", - "typing_extensions; python_version<'3.10'", + "typing_extensions; python_version<'3.12'", ] dynamic = ["authors", "license", "version"]