##// END OF EJS Templates
Implement remaining special `typing` wrappers
krassowski -
Show More
@@ -1,18 +1,23 b''
1 1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 Annotated,
4 AnyStr,
3 5 Callable,
4 6 Dict,
5 7 Literal,
6 8 NamedTuple,
7 9 NewType,
10 Optional,
11 Protocol,
8 12 Set,
9 13 Sequence,
10 14 Tuple,
11 15 Type,
12 Protocol,
16 TypeGuard,
13 17 Union,
14 18 get_args,
15 19 get_origin,
20 is_typeddict,
16 21 )
17 22 import ast
18 23 import builtins
@@ -27,9 +32,9 b' from IPython.utils.decorators import undoc'
27 32
28 33
29 34 if sys.version_info < (3, 11):
30 from typing_extensions import Self
35 from typing_extensions import Self, LiteralString
31 36 else:
32 from typing import Self
37 from typing import Self, LiteralString
33 38
34 39 if sys.version_info < (3, 12):
35 40 from typing_extensions import TypeAliasType
@@ -423,9 +428,37 b' UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {'
423 428 }
424 429
425 430
426 class Duck:
431 class ImpersonatingDuck:
427 432 """A dummy class used to create objects of other classes without calling their ``__init__``"""
428 433
434 # no-op: override __class__ to impersonate
435
436
437 class _Duck:
438 """A dummy class used to create objects pretending to have given attributes"""
439
440 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
441 self.attributes = attributes or {}
442 self.items = items or {}
443
444 def __getattr__(self, attr: str):
445 return self.attributes[attr]
446
447 def __hasattr__(self, attr: str):
448 return attr in self.attributes
449
450 def __dir__(self):
451 return [*dir(super), *self.attributes]
452
453 def __getitem__(self, key: str):
454 return self.items[key]
455
456 def __hasitem__(self, key: str):
457 return self.items[key]
458
459 def _ipython_key_completions_(self):
460 return self.items.keys()
461
429 462
430 463 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
431 464 dunder = None
@@ -617,28 +650,71 b' def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext'
617 650 # if annotation was not stringized, or it was stringized
618 651 # but resolved by signature call we know the return type
619 652 not_empty = sig.return_annotation is not Signature.empty
620 stringized = isinstance(sig.return_annotation, str)
621 653 if not_empty:
622 return_type = (
623 _eval_node_name(sig.return_annotation, context)
624 if stringized
625 else sig.return_annotation
626 )
627 if return_type is Self and hasattr(func, "__self__"):
628 return func.__self__
629 elif get_origin(return_type) is Literal:
630 type_args = get_args(return_type)
631 if len(type_args) == 1:
632 return type_args[0]
633 elif isinstance(return_type, NewType):
634 return _eval_or_create_duck(return_type.__supertype__, node, context)
635 elif isinstance(return_type, TypeAliasType):
636 return _eval_or_create_duck(return_type.__value__, node, context)
637 else:
638 return _eval_or_create_duck(return_type, node, context)
654 return _resolve_annotation(sig.return_annotation, sig, func, node, context)
639 655 return NOT_EVALUATED
640 656
641 657
658 def _resolve_annotation(
659 annotation,
660 sig: Signature,
661 func: Callable,
662 node: ast.Call,
663 context: EvaluationContext,
664 ):
665 """Resolve annotation created by user with `typing` module and custom objects."""
666 annotation = (
667 _eval_node_name(annotation, context)
668 if isinstance(annotation, str)
669 else annotation
670 )
671 origin = get_origin(annotation)
672 if annotation is Self and hasattr(func, "__self__"):
673 return func.__self__
674 elif origin is Literal:
675 type_args = get_args(annotation)
676 if len(type_args) == 1:
677 return type_args[0]
678 elif annotation is LiteralString:
679 return ""
680 elif annotation is AnyStr:
681 index = None
682 for i, (key, value) in enumerate(sig.parameters.items()):
683 if value.annotation is AnyStr:
684 index = i
685 break
686 if index is not None and index < len(node.args):
687 return eval_node(node.args[index], context)
688 elif origin is TypeGuard:
689 return bool()
690 elif origin is Union:
691 attributes = [
692 attr
693 for type_arg in get_args(annotation)
694 for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
695 ]
696 return _Duck(attributes=dict.fromkeys(attributes))
697 elif is_typeddict(annotation):
698 return _Duck(
699 attributes=dict.fromkeys(dir(dict())),
700 items={
701 k: _resolve_annotation(v, sig, func, node, context)
702 for k, v in annotation.__annotations__.items()
703 },
704 )
705 elif hasattr(annotation, "_is_protocol"):
706 return _Duck(attributes=dict.fromkeys(dir(annotation)))
707 elif origin is Annotated:
708 type_arg = get_args(annotation)[0]
709 return _resolve_annotation(type_arg, sig, func, node, context)
710 elif isinstance(annotation, NewType):
711 return _eval_or_create_duck(annotation.__supertype__, node, context)
712 elif isinstance(annotation, TypeAliasType):
713 return _eval_or_create_duck(annotation.__value__, node, context)
714 else:
715 return _eval_or_create_duck(annotation, node, context)
716
717
642 718 def _eval_node_name(node_id: str, context: EvaluationContext):
643 719 policy = EVALUATION_POLICIES[context.evaluation]
644 720 if policy.allow_locals_access and node_id in context.locals:
@@ -671,7 +747,7 b' def _create_duck_for_heap_type(duck_type):'
671 747
672 748 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
673 749 """
674 duck = Duck()
750 duck = ImpersonatingDuck()
675 751 try:
676 752 # this only works for heap types, not builtins
677 753 duck.__class__ = duck_type
@@ -1,6 +1,17 b''
1 1 import sys
2 2 from contextlib import contextmanager
3 from typing import NamedTuple, Literal, NewType
3 from typing import (
4 Annotated,
5 AnyStr,
6 NamedTuple,
7 Literal,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 TypedDict,
14 )
4 15 from functools import partial
5 16 from IPython.core.guarded_eval import (
6 17 EvaluationContext,
@@ -13,9 +24,9 b' import pytest'
13 24
14 25
15 26 if sys.version_info < (3, 11):
16 from typing_extensions import Self
27 from typing_extensions import Self, LiteralString
17 28 else:
18 from typing import Self
29 from typing import Self, LiteralString
19 30
20 31 if sys.version_info < (3, 12):
21 32 from typing_extensions import TypeAliasType
@@ -304,6 +315,21 b' IntTypeAlias = TypeAliasType("IntTypeAlias", int)'
304 315 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
305 316
306 317
318 class TestProtocol(Protocol):
319 def test_method(self) -> bool:
320 pass
321
322
323 class TestProtocolImplementer(TestProtocol):
324 def test_method(self) -> bool:
325 return True
326
327
328 class Movie(TypedDict):
329 name: str
330 year: int
331
332
307 333 class SpecialTyping:
308 334 def custom_int_type(self) -> CustomIntType:
309 335 return CustomIntType(1)
@@ -323,12 +349,40 b' class SpecialTyping:'
323 349 def literal(self) -> Literal[False]:
324 350 return False
325 351
352 def literal_string(self) -> LiteralString:
353 return "test"
354
326 355 def self(self) -> Self:
327 356 return self
328 357
358 def any_str(self, x: AnyStr) -> AnyStr:
359 return x
360
361 def annotated(self) -> Annotated[float, "positive number"]:
362 return 1
363
364 def annotated_self(self) -> Annotated[Self, "self with metadata"]:
365 self._metadata = "test"
366 return self
367
368 def int_type_guard(self, x) -> TypeGuard[int]:
369 return isinstance(x, int)
370
371 def optional_float(self) -> Optional[float]:
372 return 1.0
373
374 def union_str_and_int(self) -> Union[str, int]:
375 return ""
376
377 def protocol(self) -> TestProtocol:
378 return TestProtocolImplementer()
379
380 def typed_dict(self) -> Movie:
381 return {"name": "The Matrix", "year": 1999}
382
329 383
330 384 @pytest.mark.parametrize(
331 "data,good,expected,equality",
385 "data,code,expected,equality",
332 386 [
333 387 [[1, 2, 3], "data.index(2)", 1, True],
334 388 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
@@ -348,13 +402,19 b' class SpecialTyping:'
348 402 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
349 403 [SpecialTyping(), "data.self()", SpecialTyping, False],
350 404 [SpecialTyping(), "data.literal()", False, True],
405 [SpecialTyping(), "data.literal_string()", str, False],
406 [SpecialTyping(), "data.any_str('a')", str, False],
407 [SpecialTyping(), "data.any_str(b'a')", bytes, False],
408 [SpecialTyping(), "data.annotated()", float, False],
409 [SpecialTyping(), "data.annotated_self()", SpecialTyping, False],
410 [SpecialTyping(), "data.int_type_guard()", int, False],
351 411 # test cases for static methods
352 412 [HasStaticMethod, "data.static_method()", HeapType, False],
353 413 ],
354 414 )
355 def test_evaluates_calls(data, good, expected, equality):
415 def test_evaluates_calls(data, code, expected, equality):
356 416 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
357 value = guarded_eval(good, context)
417 value = guarded_eval(code, context)
358 418 if equality:
359 419 assert value == expected
360 420 else:
@@ -362,6 +422,42 b' def test_evaluates_calls(data, good, expected, equality):'
362 422
363 423
364 424 @pytest.mark.parametrize(
425 "data,code,expected_attributes",
426 [
427 [SpecialTyping(), "data.optional_float()", ["is_integer"]],
428 [
429 SpecialTyping(),
430 "data.union_str_and_int()",
431 ["capitalize", "as_integer_ratio"],
432 ],
433 [SpecialTyping(), "data.protocol()", ["test_method"]],
434 [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]],
435 ],
436 )
437 def test_mocks_attributes_of_call_results(data, code, expected_attributes):
438 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
439 result = guarded_eval(code, context)
440 for attr in expected_attributes:
441 assert hasattr(result, attr)
442 assert attr in dir(result)
443
444
445 @pytest.mark.parametrize(
446 "data,code,expected_items",
447 [
448 [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}],
449 ],
450 )
451 def test_mocks_items_of_call_results(data, code, expected_items):
452 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
453 result = guarded_eval(code, context)
454 ipython_keys = result._ipython_key_completions_()
455 for key, value in expected_items.items():
456 assert isinstance(result[key], value)
457 assert key in ipython_keys
458
459
460 @pytest.mark.parametrize(
365 461 "data,bad",
366 462 [
367 463 [[1, 2, 3], "data.append(4)"],
General Comments 0
You need to be logged in to leave comments. Login now