##// END OF EJS Templates
Improve inference from return type annotations in completer (#14357)...
M Bussonnier -
r28682:7c22387c merge
parent child Browse files
Show More
@@ -1,760 +1,898 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 Any,
3 Annotated,
4 AnyStr,
4 5 Callable,
5 6 Dict,
7 Literal,
8 NamedTuple,
9 NewType,
10 Optional,
11 Protocol,
6 12 Set,
7 13 Sequence,
8 14 Tuple,
9 NamedTuple,
10 15 Type,
11 Literal,
16 TypeGuard,
12 17 Union,
13 TYPE_CHECKING,
18 get_args,
19 get_origin,
20 is_typeddict,
14 21 )
15 22 import ast
16 23 import builtins
17 24 import collections
18 25 import operator
19 26 import sys
20 27 from functools import cached_property
21 28 from dataclasses import dataclass, field
22 29 from types import MethodDescriptorType, ModuleType
23 30
24 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 31 from IPython.utils.decorators import undoc
26 32
27 33
28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
29 from typing_extensions import Protocol
34 if sys.version_info < (3, 11):
35 from typing_extensions import Self, LiteralString
36 else:
37 from typing import Self, LiteralString
38
39 if sys.version_info < (3, 12):
40 from typing_extensions import TypeAliasType
30 41 else:
31 # do not require on runtime
32 Protocol = object # requires Python >=3.8
42 from typing import TypeAliasType
33 43
34 44
35 45 @undoc
36 46 class HasGetItem(Protocol):
37 47 def __getitem__(self, key) -> None:
38 48 ...
39 49
40 50
41 51 @undoc
42 52 class InstancesHaveGetItem(Protocol):
43 53 def __call__(self, *args, **kwargs) -> HasGetItem:
44 54 ...
45 55
46 56
47 57 @undoc
48 58 class HasGetAttr(Protocol):
49 59 def __getattr__(self, key) -> None:
50 60 ...
51 61
52 62
53 63 @undoc
54 64 class DoesNotHaveGetAttr(Protocol):
55 65 pass
56 66
57 67
58 68 # By default `__getattr__` is not explicitly implemented on most objects
59 69 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
60 70
61 71
62 72 def _unbind_method(func: Callable) -> Union[Callable, None]:
63 73 """Get unbound method for given bound method.
64 74
65 75 Returns None if cannot get unbound method, or method is already unbound.
66 76 """
67 77 owner = getattr(func, "__self__", None)
68 78 owner_class = type(owner)
69 79 name = getattr(func, "__name__", None)
70 80 instance_dict_overrides = getattr(owner, "__dict__", None)
71 81 if (
72 82 owner is not None
73 83 and name
74 84 and (
75 85 not instance_dict_overrides
76 86 or (instance_dict_overrides and name not in instance_dict_overrides)
77 87 )
78 88 ):
79 89 return getattr(owner_class, name)
80 90 return None
81 91
82 92
83 93 @undoc
84 94 @dataclass
85 95 class EvaluationPolicy:
86 96 """Definition of evaluation policy."""
87 97
88 98 allow_locals_access: bool = False
89 99 allow_globals_access: bool = False
90 100 allow_item_access: bool = False
91 101 allow_attr_access: bool = False
92 102 allow_builtins_access: bool = False
93 103 allow_all_operations: bool = False
94 104 allow_any_calls: bool = False
95 105 allowed_calls: Set[Callable] = field(default_factory=set)
96 106
97 107 def can_get_item(self, value, item):
98 108 return self.allow_item_access
99 109
100 110 def can_get_attr(self, value, attr):
101 111 return self.allow_attr_access
102 112
103 113 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
104 114 if self.allow_all_operations:
105 115 return True
106 116
107 117 def can_call(self, func):
108 118 if self.allow_any_calls:
109 119 return True
110 120
111 121 if func in self.allowed_calls:
112 122 return True
113 123
114 124 owner_method = _unbind_method(func)
115 125
116 126 if owner_method and owner_method in self.allowed_calls:
117 127 return True
118 128
119 129
120 130 def _get_external(module_name: str, access_path: Sequence[str]):
121 131 """Get value from external module given a dotted access path.
122 132
123 133 Raises:
124 134 * `KeyError` if module is removed not found, and
125 135 * `AttributeError` if acess path does not match an exported object
126 136 """
127 137 member_type = sys.modules[module_name]
128 138 for attr in access_path:
129 139 member_type = getattr(member_type, attr)
130 140 return member_type
131 141
132 142
133 143 def _has_original_dunder_external(
134 144 value,
135 145 module_name: str,
136 146 access_path: Sequence[str],
137 147 method_name: str,
138 148 ):
139 149 if module_name not in sys.modules:
140 150 # LBYLB as it is faster
141 151 return False
142 152 try:
143 153 member_type = _get_external(module_name, access_path)
144 154 value_type = type(value)
145 155 if type(value) == member_type:
146 156 return True
147 157 if method_name == "__getattribute__":
148 158 # we have to short-circuit here due to an unresolved issue in
149 159 # `isinstance` implementation: https://bugs.python.org/issue32683
150 160 return False
151 161 if isinstance(value, member_type):
152 162 method = getattr(value_type, method_name, None)
153 163 member_method = getattr(member_type, method_name, None)
154 164 if member_method == method:
155 165 return True
156 166 except (AttributeError, KeyError):
157 167 return False
158 168
159 169
160 170 def _has_original_dunder(
161 171 value, allowed_types, allowed_methods, allowed_external, method_name
162 172 ):
163 173 # note: Python ignores `__getattr__`/`__getitem__` on instances,
164 174 # we only need to check at class level
165 175 value_type = type(value)
166 176
167 177 # strict type check passes β†’ no need to check method
168 178 if value_type in allowed_types:
169 179 return True
170 180
171 181 method = getattr(value_type, method_name, None)
172 182
173 183 if method is None:
174 184 return None
175 185
176 186 if method in allowed_methods:
177 187 return True
178 188
179 189 for module_name, *access_path in allowed_external:
180 190 if _has_original_dunder_external(value, module_name, access_path, method_name):
181 191 return True
182 192
183 193 return False
184 194
185 195
186 196 @undoc
187 197 @dataclass
188 198 class SelectivePolicy(EvaluationPolicy):
189 199 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
190 200 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
191 201
192 202 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
193 203 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
194 204
195 205 allowed_operations: Set = field(default_factory=set)
196 206 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
197 207
198 208 _operation_methods_cache: Dict[str, Set[Callable]] = field(
199 209 default_factory=dict, init=False
200 210 )
201 211
202 212 def can_get_attr(self, value, attr):
203 213 has_original_attribute = _has_original_dunder(
204 214 value,
205 215 allowed_types=self.allowed_getattr,
206 216 allowed_methods=self._getattribute_methods,
207 217 allowed_external=self.allowed_getattr_external,
208 218 method_name="__getattribute__",
209 219 )
210 220 has_original_attr = _has_original_dunder(
211 221 value,
212 222 allowed_types=self.allowed_getattr,
213 223 allowed_methods=self._getattr_methods,
214 224 allowed_external=self.allowed_getattr_external,
215 225 method_name="__getattr__",
216 226 )
217 227
218 228 accept = False
219 229
220 230 # Many objects do not have `__getattr__`, this is fine.
221 231 if has_original_attr is None and has_original_attribute:
222 232 accept = True
223 233 else:
224 234 # Accept objects without modifications to `__getattr__` and `__getattribute__`
225 235 accept = has_original_attr and has_original_attribute
226 236
227 237 if accept:
228 238 # We still need to check for overriden properties.
229 239
230 240 value_class = type(value)
231 241 if not hasattr(value_class, attr):
232 242 return True
233 243
234 244 class_attr_val = getattr(value_class, attr)
235 245 is_property = isinstance(class_attr_val, property)
236 246
237 247 if not is_property:
238 248 return True
239 249
240 250 # Properties in allowed types are ok (although we do not include any
241 251 # properties in our default allow list currently).
242 252 if type(value) in self.allowed_getattr:
243 253 return True # pragma: no cover
244 254
245 255 # Properties in subclasses of allowed types may be ok if not changed
246 256 for module_name, *access_path in self.allowed_getattr_external:
247 257 try:
248 258 external_class = _get_external(module_name, access_path)
249 259 external_class_attr_val = getattr(external_class, attr)
250 260 except (KeyError, AttributeError):
251 261 return False # pragma: no cover
252 262 return class_attr_val == external_class_attr_val
253 263
254 264 return False
255 265
256 266 def can_get_item(self, value, item):
257 267 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
258 268 return _has_original_dunder(
259 269 value,
260 270 allowed_types=self.allowed_getitem,
261 271 allowed_methods=self._getitem_methods,
262 272 allowed_external=self.allowed_getitem_external,
263 273 method_name="__getitem__",
264 274 )
265 275
266 276 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
267 277 objects = [a]
268 278 if b is not None:
269 279 objects.append(b)
270 280 return all(
271 281 [
272 282 _has_original_dunder(
273 283 obj,
274 284 allowed_types=self.allowed_operations,
275 285 allowed_methods=self._operator_dunder_methods(dunder),
276 286 allowed_external=self.allowed_operations_external,
277 287 method_name=dunder,
278 288 )
279 289 for dunder in dunders
280 290 for obj in objects
281 291 ]
282 292 )
283 293
284 294 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
285 295 if dunder not in self._operation_methods_cache:
286 296 self._operation_methods_cache[dunder] = self._safe_get_methods(
287 297 self.allowed_operations, dunder
288 298 )
289 299 return self._operation_methods_cache[dunder]
290 300
291 301 @cached_property
292 302 def _getitem_methods(self) -> Set[Callable]:
293 303 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
294 304
295 305 @cached_property
296 306 def _getattr_methods(self) -> Set[Callable]:
297 307 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
298 308
299 309 @cached_property
300 310 def _getattribute_methods(self) -> Set[Callable]:
301 311 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
302 312
303 313 def _safe_get_methods(self, classes, name) -> Set[Callable]:
304 314 return {
305 315 method
306 316 for class_ in classes
307 317 for method in [getattr(class_, name, None)]
308 318 if method
309 319 }
310 320
311 321
312 322 class _DummyNamedTuple(NamedTuple):
313 323 """Used internally to retrieve methods of named tuple instance."""
314 324
315 325
316 326 class EvaluationContext(NamedTuple):
317 327 #: Local namespace
318 328 locals: dict
319 329 #: Global namespace
320 330 globals: dict
321 331 #: Evaluation policy identifier
322 332 evaluation: Literal[
323 333 "forbidden", "minimal", "limited", "unsafe", "dangerous"
324 334 ] = "forbidden"
325 335 #: Whether the evalution of code takes place inside of a subscript.
326 336 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
327 337 in_subscript: bool = False
328 338
329 339
330 340 class _IdentitySubscript:
331 341 """Returns the key itself when item is requested via subscript."""
332 342
333 343 def __getitem__(self, key):
334 344 return key
335 345
336 346
337 347 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 348 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 349 UNKNOWN_SIGNATURE = Signature()
350 NOT_EVALUATED = object()
340 351
341 352
342 353 class GuardRejection(Exception):
343 354 """Exception raised when guard rejects evaluation attempt."""
344 355
345 356 pass
346 357
347 358
348 359 def guarded_eval(code: str, context: EvaluationContext):
349 360 """Evaluate provided code in the evaluation context.
350 361
351 362 If evaluation policy given by context is set to ``forbidden``
352 363 no evaluation will be performed; if it is set to ``dangerous``
353 364 standard :func:`eval` will be used; finally, for any other,
354 365 policy :func:`eval_node` will be called on parsed AST.
355 366 """
356 367 locals_ = context.locals
357 368
358 369 if context.evaluation == "forbidden":
359 370 raise GuardRejection("Forbidden mode")
360 371
361 372 # note: not using `ast.literal_eval` as it does not implement
362 373 # getitem at all, for example it fails on simple `[0][1]`
363 374
364 375 if context.in_subscript:
365 376 # syntatic sugar for ellipsis (:) is only available in susbcripts
366 377 # so we need to trick the ast parser into thinking that we have
367 378 # a subscript, but we need to be able to later recognise that we did
368 379 # it so we can ignore the actual __getitem__ operation
369 380 if not code:
370 381 return tuple()
371 382 locals_ = locals_.copy()
372 383 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
373 384 code = SUBSCRIPT_MARKER + "[" + code + "]"
374 385 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
375 386
376 387 if context.evaluation == "dangerous":
377 388 return eval(code, context.globals, context.locals)
378 389
379 390 expression = ast.parse(code, mode="eval")
380 391
381 392 return eval_node(expression, context)
382 393
383 394
384 395 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
385 396 ast.Add: ("__add__",),
386 397 ast.Sub: ("__sub__",),
387 398 ast.Mult: ("__mul__",),
388 399 ast.Div: ("__truediv__",),
389 400 ast.FloorDiv: ("__floordiv__",),
390 401 ast.Mod: ("__mod__",),
391 402 ast.Pow: ("__pow__",),
392 403 ast.LShift: ("__lshift__",),
393 404 ast.RShift: ("__rshift__",),
394 405 ast.BitOr: ("__or__",),
395 406 ast.BitXor: ("__xor__",),
396 407 ast.BitAnd: ("__and__",),
397 408 ast.MatMult: ("__matmul__",),
398 409 }
399 410
400 411 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
401 412 ast.Eq: ("__eq__",),
402 413 ast.NotEq: ("__ne__", "__eq__"),
403 414 ast.Lt: ("__lt__", "__gt__"),
404 415 ast.LtE: ("__le__", "__ge__"),
405 416 ast.Gt: ("__gt__", "__lt__"),
406 417 ast.GtE: ("__ge__", "__le__"),
407 418 ast.In: ("__contains__",),
408 419 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
409 420 }
410 421
411 422 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
412 423 ast.USub: ("__neg__",),
413 424 ast.UAdd: ("__pos__",),
414 425 # we have to check both __inv__ and __invert__!
415 426 ast.Invert: ("__invert__", "__inv__"),
416 427 ast.Not: ("__not__",),
417 428 }
418 429
419 430
420 class Duck:
431 class ImpersonatingDuck:
421 432 """A dummy class used to create objects of other classes without calling their ``__init__``"""
422 433
434 # no-op: override __class__ to impersonate
435
436
437 class _Duck:
438 """A dummy class used to create objects pretending to have given attributes"""
439
440 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
441 self.attributes = attributes or {}
442 self.items = items or {}
443
444 def __getattr__(self, attr: str):
445 return self.attributes[attr]
446
447 def __hasattr__(self, attr: str):
448 return attr in self.attributes
449
450 def __dir__(self):
451 return [*dir(super), *self.attributes]
452
453 def __getitem__(self, key: str):
454 return self.items[key]
455
456 def __hasitem__(self, key: str):
457 return self.items[key]
458
459 def _ipython_key_completions_(self):
460 return self.items.keys()
461
423 462
424 463 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 464 dunder = None
426 465 for op, candidate_dunder in dunders.items():
427 466 if isinstance(node_op, op):
428 467 dunder = candidate_dunder
429 468 return dunder
430 469
431 470
432 471 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 472 """Evaluate AST node in provided context.
434 473
435 474 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436 475
437 476 Does not evaluate actions that always have side effects:
438 477
439 478 - class definitions (``class sth: ...``)
440 479 - function definitions (``def sth: ...``)
441 480 - variable assignments (``x = 1``)
442 481 - augmented assignments (``x += 1``)
443 482 - deletions (``del x``)
444 483
445 484 Does not evaluate operations which do not return values:
446 485
447 486 - assertions (``assert x``)
448 487 - pass (``pass``)
449 488 - imports (``import x``)
450 489 - control flow:
451 490
452 491 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 492 - loops (``for`` and ``while``)
454 493 - exception handling
455 494
456 495 The purpose of this function is to guard against unwanted side-effects;
457 496 it does not give guarantees on protection from malicious code execution.
458 497 """
459 498 policy = EVALUATION_POLICIES[context.evaluation]
460 499 if node is None:
461 500 return None
462 501 if isinstance(node, ast.Expression):
463 502 return eval_node(node.body, context)
464 503 if isinstance(node, ast.BinOp):
465 504 left = eval_node(node.left, context)
466 505 right = eval_node(node.right, context)
467 506 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 507 if dunders:
469 508 if policy.can_operate(dunders, left, right):
470 509 return getattr(left, dunders[0])(right)
471 510 else:
472 511 raise GuardRejection(
473 512 f"Operation (`{dunders}`) for",
474 513 type(left),
475 514 f"not allowed in {context.evaluation} mode",
476 515 )
477 516 if isinstance(node, ast.Compare):
478 517 left = eval_node(node.left, context)
479 518 all_true = True
480 519 negate = False
481 520 for op, right in zip(node.ops, node.comparators):
482 521 right = eval_node(right, context)
483 522 dunder = None
484 523 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 524 if not dunders:
486 525 if isinstance(op, ast.NotIn):
487 526 dunders = COMP_OP_DUNDERS[ast.In]
488 527 negate = True
489 528 if isinstance(op, ast.Is):
490 529 dunder = "is_"
491 530 if isinstance(op, ast.IsNot):
492 531 dunder = "is_"
493 532 negate = True
494 533 if not dunder and dunders:
495 534 dunder = dunders[0]
496 535 if dunder:
497 536 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 537 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 538 result = getattr(operator, dunder)(a, b)
500 539 if negate:
501 540 result = not result
502 541 if not result:
503 542 all_true = False
504 543 left = right
505 544 else:
506 545 raise GuardRejection(
507 546 f"Comparison (`{dunder}`) for",
508 547 type(left),
509 548 f"not allowed in {context.evaluation} mode",
510 549 )
511 550 else:
512 551 raise ValueError(
513 552 f"Comparison `{dunder}` not supported"
514 553 ) # pragma: no cover
515 554 return all_true
516 555 if isinstance(node, ast.Constant):
517 556 return node.value
518 557 if isinstance(node, ast.Tuple):
519 558 return tuple(eval_node(e, context) for e in node.elts)
520 559 if isinstance(node, ast.List):
521 560 return [eval_node(e, context) for e in node.elts]
522 561 if isinstance(node, ast.Set):
523 562 return {eval_node(e, context) for e in node.elts}
524 563 if isinstance(node, ast.Dict):
525 564 return dict(
526 565 zip(
527 566 [eval_node(k, context) for k in node.keys],
528 567 [eval_node(v, context) for v in node.values],
529 568 )
530 569 )
531 570 if isinstance(node, ast.Slice):
532 571 return slice(
533 572 eval_node(node.lower, context),
534 573 eval_node(node.upper, context),
535 574 eval_node(node.step, context),
536 575 )
537 576 if isinstance(node, ast.UnaryOp):
538 577 value = eval_node(node.operand, context)
539 578 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 579 if dunders:
541 580 if policy.can_operate(dunders, value):
542 581 return getattr(value, dunders[0])()
543 582 else:
544 583 raise GuardRejection(
545 584 f"Operation (`{dunders}`) for",
546 585 type(value),
547 586 f"not allowed in {context.evaluation} mode",
548 587 )
549 588 if isinstance(node, ast.Subscript):
550 589 value = eval_node(node.value, context)
551 590 slice_ = eval_node(node.slice, context)
552 591 if policy.can_get_item(value, slice_):
553 592 return value[slice_]
554 593 raise GuardRejection(
555 594 "Subscript access (`__getitem__`) for",
556 595 type(value), # not joined to avoid calling `repr`
557 596 f" not allowed in {context.evaluation} mode",
558 597 )
559 598 if isinstance(node, ast.Name):
560 if policy.allow_locals_access and node.id in context.locals:
561 return context.locals[node.id]
562 if policy.allow_globals_access and node.id in context.globals:
563 return context.globals[node.id]
564 if policy.allow_builtins_access and hasattr(builtins, node.id):
565 # note: do not use __builtins__, it is implementation detail of cPython
566 return getattr(builtins, node.id)
567 if not policy.allow_globals_access and not policy.allow_locals_access:
568 raise GuardRejection(
569 f"Namespace access not allowed in {context.evaluation} mode"
570 )
571 else:
572 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
599 return _eval_node_name(node.id, context)
573 600 if isinstance(node, ast.Attribute):
574 601 value = eval_node(node.value, context)
575 602 if policy.can_get_attr(value, node.attr):
576 603 return getattr(value, node.attr)
577 604 raise GuardRejection(
578 605 "Attribute access (`__getattr__`) for",
579 606 type(value), # not joined to avoid calling `repr`
580 607 f"not allowed in {context.evaluation} mode",
581 608 )
582 609 if isinstance(node, ast.IfExp):
583 610 test = eval_node(node.test, context)
584 611 if test:
585 612 return eval_node(node.body, context)
586 613 else:
587 614 return eval_node(node.orelse, context)
588 615 if isinstance(node, ast.Call):
589 616 func = eval_node(node.func, context)
590 617 if policy.can_call(func) and not node.keywords:
591 618 args = [eval_node(arg, context) for arg in node.args]
592 619 return func(*args)
593 try:
594 sig = signature(func)
595 except ValueError:
596 sig = UNKNOWN_SIGNATURE
597 # if annotation was not stringized, or it was stringized
598 # but resolved by signature call we know the return type
599 not_empty = sig.return_annotation is not Signature.empty
600 not_stringized = not isinstance(sig.return_annotation, str)
601 if not_empty and not_stringized:
602 duck = Duck()
603 # if allow-listed builtin is on type annotation, instantiate it
604 if policy.can_call(sig.return_annotation) and not node.keywords:
605 args = [eval_node(arg, context) for arg in node.args]
606 return sig.return_annotation(*args)
607 try:
608 # if custom class is in type annotation, mock it;
609 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
611 return duck
612 except TypeError:
613 pass
620 if isclass(func):
621 # this code path gets entered when calling class e.g. `MyClass()`
622 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
623 # Should return `MyClass` if `__new__` is not overridden,
624 # otherwise whatever `__new__` return type is.
625 overridden_return_type = _eval_return_type(func.__new__, node, context)
626 if overridden_return_type is not NOT_EVALUATED:
627 return overridden_return_type
628 return _create_duck_for_heap_type(func)
629 else:
630 return_type = _eval_return_type(func, node, context)
631 if return_type is not NOT_EVALUATED:
632 return return_type
614 633 raise GuardRejection(
615 634 "Call for",
616 635 func, # not joined to avoid calling `repr`
617 636 f"not allowed in {context.evaluation} mode",
618 637 )
619 638 raise ValueError("Unhandled node", ast.dump(node))
620 639
621 640
641 def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
642 """Evaluate return type of a given callable function.
643
644 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
645 """
646 try:
647 sig = signature(func)
648 except ValueError:
649 sig = UNKNOWN_SIGNATURE
650 # if annotation was not stringized, or it was stringized
651 # but resolved by signature call we know the return type
652 not_empty = sig.return_annotation is not Signature.empty
653 if not_empty:
654 return _resolve_annotation(sig.return_annotation, sig, func, node, context)
655 return NOT_EVALUATED
656
657
658 def _resolve_annotation(
659 annotation,
660 sig: Signature,
661 func: Callable,
662 node: ast.Call,
663 context: EvaluationContext,
664 ):
665 """Resolve annotation created by user with `typing` module and custom objects."""
666 annotation = (
667 _eval_node_name(annotation, context)
668 if isinstance(annotation, str)
669 else annotation
670 )
671 origin = get_origin(annotation)
672 if annotation is Self and hasattr(func, "__self__"):
673 return func.__self__
674 elif origin is Literal:
675 type_args = get_args(annotation)
676 if len(type_args) == 1:
677 return type_args[0]
678 elif annotation is LiteralString:
679 return ""
680 elif annotation is AnyStr:
681 index = None
682 for i, (key, value) in enumerate(sig.parameters.items()):
683 if value.annotation is AnyStr:
684 index = i
685 break
686 if index is not None and index < len(node.args):
687 return eval_node(node.args[index], context)
688 elif origin is TypeGuard:
689 return bool()
690 elif origin is Union:
691 attributes = [
692 attr
693 for type_arg in get_args(annotation)
694 for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
695 ]
696 return _Duck(attributes=dict.fromkeys(attributes))
697 elif is_typeddict(annotation):
698 return _Duck(
699 attributes=dict.fromkeys(dir(dict())),
700 items={
701 k: _resolve_annotation(v, sig, func, node, context)
702 for k, v in annotation.__annotations__.items()
703 },
704 )
705 elif hasattr(annotation, "_is_protocol"):
706 return _Duck(attributes=dict.fromkeys(dir(annotation)))
707 elif origin is Annotated:
708 type_arg = get_args(annotation)[0]
709 return _resolve_annotation(type_arg, sig, func, node, context)
710 elif isinstance(annotation, NewType):
711 return _eval_or_create_duck(annotation.__supertype__, node, context)
712 elif isinstance(annotation, TypeAliasType):
713 return _eval_or_create_duck(annotation.__value__, node, context)
714 else:
715 return _eval_or_create_duck(annotation, node, context)
716
717
718 def _eval_node_name(node_id: str, context: EvaluationContext):
719 policy = EVALUATION_POLICIES[context.evaluation]
720 if policy.allow_locals_access and node_id in context.locals:
721 return context.locals[node_id]
722 if policy.allow_globals_access and node_id in context.globals:
723 return context.globals[node_id]
724 if policy.allow_builtins_access and hasattr(builtins, node_id):
725 # note: do not use __builtins__, it is implementation detail of cPython
726 return getattr(builtins, node_id)
727 if not policy.allow_globals_access and not policy.allow_locals_access:
728 raise GuardRejection(
729 f"Namespace access not allowed in {context.evaluation} mode"
730 )
731 else:
732 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
733
734
735 def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
736 policy = EVALUATION_POLICIES[context.evaluation]
737 # if allow-listed builtin is on type annotation, instantiate it
738 if policy.can_call(duck_type) and not node.keywords:
739 args = [eval_node(arg, context) for arg in node.args]
740 return duck_type(*args)
741 # if custom class is in type annotation, mock it
742 return _create_duck_for_heap_type(duck_type)
743
744
745 def _create_duck_for_heap_type(duck_type):
746 """Create an imitation of an object of a given type (a duck).
747
748 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
749 """
750 duck = ImpersonatingDuck()
751 try:
752 # this only works for heap types, not builtins
753 duck.__class__ = duck_type
754 return duck
755 except TypeError:
756 pass
757 return NOT_EVALUATED
758
759
622 760 SUPPORTED_EXTERNAL_GETITEM = {
623 761 ("pandas", "core", "indexing", "_iLocIndexer"),
624 762 ("pandas", "core", "indexing", "_LocIndexer"),
625 763 ("pandas", "DataFrame"),
626 764 ("pandas", "Series"),
627 765 ("numpy", "ndarray"),
628 766 ("numpy", "void"),
629 767 }
630 768
631 769
632 770 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
633 771 dict,
634 772 str, # type: ignore[arg-type]
635 773 bytes, # type: ignore[arg-type]
636 774 list,
637 775 tuple,
638 776 collections.defaultdict,
639 777 collections.deque,
640 778 collections.OrderedDict,
641 779 collections.ChainMap,
642 780 collections.UserDict,
643 781 collections.UserList,
644 782 collections.UserString, # type: ignore[arg-type]
645 783 _DummyNamedTuple,
646 784 _IdentitySubscript,
647 785 }
648 786
649 787
650 788 def _list_methods(cls, source=None):
651 789 """For use on immutable objects or with methods returning a copy"""
652 790 return [getattr(cls, k) for k in (source if source else dir(cls))]
653 791
654 792
655 793 dict_non_mutating_methods = ("copy", "keys", "values", "items")
656 794 list_non_mutating_methods = ("copy", "index", "count")
657 795 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
658 796
659 797
660 798 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
661 799
662 800 NUMERICS = {int, float, complex}
663 801
664 802 ALLOWED_CALLS = {
665 803 bytes,
666 804 *_list_methods(bytes),
667 805 dict,
668 806 *_list_methods(dict, dict_non_mutating_methods),
669 807 dict_keys.isdisjoint,
670 808 list,
671 809 *_list_methods(list, list_non_mutating_methods),
672 810 set,
673 811 *_list_methods(set, set_non_mutating_methods),
674 812 frozenset,
675 813 *_list_methods(frozenset),
676 814 range,
677 815 str,
678 816 *_list_methods(str),
679 817 tuple,
680 818 *_list_methods(tuple),
681 819 *NUMERICS,
682 820 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
683 821 collections.deque,
684 822 *_list_methods(collections.deque, list_non_mutating_methods),
685 823 collections.defaultdict,
686 824 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
687 825 collections.OrderedDict,
688 826 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
689 827 collections.UserDict,
690 828 *_list_methods(collections.UserDict, dict_non_mutating_methods),
691 829 collections.UserList,
692 830 *_list_methods(collections.UserList, list_non_mutating_methods),
693 831 collections.UserString,
694 832 *_list_methods(collections.UserString, dir(str)),
695 833 collections.Counter,
696 834 *_list_methods(collections.Counter, dict_non_mutating_methods),
697 835 collections.Counter.elements,
698 836 collections.Counter.most_common,
699 837 }
700 838
701 839 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
702 840 *BUILTIN_GETITEM,
703 841 set,
704 842 frozenset,
705 843 object,
706 844 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
707 845 *NUMERICS,
708 846 dict_keys,
709 847 MethodDescriptorType,
710 848 ModuleType,
711 849 }
712 850
713 851
714 852 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
715 853
716 854 EVALUATION_POLICIES = {
717 855 "minimal": EvaluationPolicy(
718 856 allow_builtins_access=True,
719 857 allow_locals_access=False,
720 858 allow_globals_access=False,
721 859 allow_item_access=False,
722 860 allow_attr_access=False,
723 861 allowed_calls=set(),
724 862 allow_any_calls=False,
725 863 allow_all_operations=False,
726 864 ),
727 865 "limited": SelectivePolicy(
728 866 allowed_getitem=BUILTIN_GETITEM,
729 867 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
730 868 allowed_getattr=BUILTIN_GETATTR,
731 869 allowed_getattr_external={
732 870 # pandas Series/Frame implements custom `__getattr__`
733 871 ("pandas", "DataFrame"),
734 872 ("pandas", "Series"),
735 873 },
736 874 allowed_operations=BUILTIN_OPERATIONS,
737 875 allow_builtins_access=True,
738 876 allow_locals_access=True,
739 877 allow_globals_access=True,
740 878 allowed_calls=ALLOWED_CALLS,
741 879 ),
742 880 "unsafe": EvaluationPolicy(
743 881 allow_builtins_access=True,
744 882 allow_locals_access=True,
745 883 allow_globals_access=True,
746 884 allow_attr_access=True,
747 885 allow_item_access=True,
748 886 allow_any_calls=True,
749 887 allow_all_operations=True,
750 888 ),
751 889 }
752 890
753 891
754 892 __all__ = [
755 893 "guarded_eval",
756 894 "eval_node",
757 895 "GuardRejection",
758 896 "EvaluationContext",
759 897 "_unbind_method",
760 898 ]
@@ -1,602 +1,785 b''
1 import sys
1 2 from contextlib import contextmanager
2 from typing import NamedTuple
3 from typing import (
4 Annotated,
5 AnyStr,
6 NamedTuple,
7 Literal,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 TypedDict,
14 )
3 15 from functools import partial
4 16 from IPython.core.guarded_eval import (
5 17 EvaluationContext,
6 18 GuardRejection,
7 19 guarded_eval,
8 20 _unbind_method,
9 21 )
10 22 from IPython.testing import decorators as dec
11 23 import pytest
12 24
13 25
26 if sys.version_info < (3, 11):
27 from typing_extensions import Self, LiteralString
28 else:
29 from typing import Self, LiteralString
30
31 if sys.version_info < (3, 12):
32 from typing_extensions import TypeAliasType
33 else:
34 from typing import TypeAliasType
35
36
14 37 def create_context(evaluation: str, **kwargs):
15 38 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16 39
17 40
18 41 forbidden = partial(create_context, "forbidden")
19 42 minimal = partial(create_context, "minimal")
20 43 limited = partial(create_context, "limited")
21 44 unsafe = partial(create_context, "unsafe")
22 45 dangerous = partial(create_context, "dangerous")
23 46
24 47 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 48 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26 49
27 50
28 51 @contextmanager
29 52 def module_not_installed(module: str):
30 53 import sys
31 54
32 55 try:
33 56 to_restore = sys.modules[module]
34 57 del sys.modules[module]
35 58 except KeyError:
36 59 to_restore = None
37 60 try:
38 61 yield
39 62 finally:
40 63 sys.modules[module] = to_restore
41 64
42 65
43 66 def test_external_not_installed():
44 67 """
45 68 Because attribute check requires checking if object is not of allowed
46 69 external type, this tests logic for absence of external module.
47 70 """
48 71
49 72 class Custom:
50 73 def __init__(self):
51 74 self.test = 1
52 75
53 76 def __getattr__(self, key):
54 77 return key
55 78
56 79 with module_not_installed("pandas"):
57 80 context = limited(x=Custom())
58 81 with pytest.raises(GuardRejection):
59 82 guarded_eval("x.test", context)
60 83
61 84
62 85 @dec.skip_without("pandas")
63 86 def test_external_changed_api(monkeypatch):
64 87 """Check that the execution rejects if external API changed paths"""
65 88 import pandas as pd
66 89
67 90 series = pd.Series([1], index=["a"])
68 91
69 92 with monkeypatch.context() as m:
70 93 m.delattr(pd, "Series")
71 94 context = limited(data=series)
72 95 with pytest.raises(GuardRejection):
73 96 guarded_eval("data.iloc[0]", context)
74 97
75 98
76 99 @dec.skip_without("pandas")
77 100 def test_pandas_series_iloc():
78 101 import pandas as pd
79 102
80 103 series = pd.Series([1], index=["a"])
81 104 context = limited(data=series)
82 105 assert guarded_eval("data.iloc[0]", context) == 1
83 106
84 107
85 108 def test_rejects_custom_properties():
86 109 class BadProperty:
87 110 @property
88 111 def iloc(self):
89 112 return [None]
90 113
91 114 series = BadProperty()
92 115 context = limited(data=series)
93 116
94 117 with pytest.raises(GuardRejection):
95 118 guarded_eval("data.iloc[0]", context)
96 119
97 120
98 121 @dec.skip_without("pandas")
99 122 def test_accepts_non_overriden_properties():
100 123 import pandas as pd
101 124
102 125 class GoodProperty(pd.Series):
103 126 pass
104 127
105 128 series = GoodProperty([1], index=["a"])
106 129 context = limited(data=series)
107 130
108 131 assert guarded_eval("data.iloc[0]", context) == 1
109 132
110 133
111 134 @dec.skip_without("pandas")
112 135 def test_pandas_series():
113 136 import pandas as pd
114 137
115 138 context = limited(data=pd.Series([1], index=["a"]))
116 139 assert guarded_eval('data["a"]', context) == 1
117 140 with pytest.raises(KeyError):
118 141 guarded_eval('data["c"]', context)
119 142
120 143
121 144 @dec.skip_without("pandas")
122 145 def test_pandas_bad_series():
123 146 import pandas as pd
124 147
125 148 class BadItemSeries(pd.Series):
126 149 def __getitem__(self, key):
127 150 return "CUSTOM_ITEM"
128 151
129 152 class BadAttrSeries(pd.Series):
130 153 def __getattr__(self, key):
131 154 return "CUSTOM_ATTR"
132 155
133 156 bad_series = BadItemSeries([1], index=["a"])
134 157 context = limited(data=bad_series)
135 158
136 159 with pytest.raises(GuardRejection):
137 160 guarded_eval('data["a"]', context)
138 161 with pytest.raises(GuardRejection):
139 162 guarded_eval('data["c"]', context)
140 163
141 164 # note: here result is a bit unexpected because
142 165 # pandas `__getattr__` calls `__getitem__`;
143 166 # FIXME - special case to handle it?
144 167 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145 168
146 169 context = unsafe(data=bad_series)
147 170 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148 171
149 172 bad_attr_series = BadAttrSeries([1], index=["a"])
150 173 context = limited(data=bad_attr_series)
151 174 assert guarded_eval('data["a"]', context) == 1
152 175 with pytest.raises(GuardRejection):
153 176 guarded_eval("data.a", context)
154 177
155 178
156 179 @dec.skip_without("pandas")
157 180 def test_pandas_dataframe_loc():
158 181 import pandas as pd
159 182 from pandas.testing import assert_series_equal
160 183
161 184 data = pd.DataFrame([{"a": 1}])
162 185 context = limited(data=data)
163 186 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164 187
165 188
166 189 def test_named_tuple():
167 190 class GoodNamedTuple(NamedTuple):
168 191 a: str
169 192 pass
170 193
171 194 class BadNamedTuple(NamedTuple):
172 195 a: str
173 196
174 197 def __getitem__(self, key):
175 198 return None
176 199
177 200 good = GoodNamedTuple(a="x")
178 201 bad = BadNamedTuple(a="x")
179 202
180 203 context = limited(data=good)
181 204 assert guarded_eval("data[0]", context) == "x"
182 205
183 206 context = limited(data=bad)
184 207 with pytest.raises(GuardRejection):
185 208 guarded_eval("data[0]", context)
186 209
187 210
188 211 def test_dict():
189 212 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 213 assert guarded_eval('data["a"]', context) == 1
191 214 assert guarded_eval('data["b"]', context) == {"x": 2}
192 215 assert guarded_eval('data["b"]["x"]', context) == 2
193 216 assert guarded_eval('data["x", "y"]', context) == 3
194 217
195 218 assert guarded_eval("data.keys", context)
196 219
197 220
198 221 def test_set():
199 222 context = limited(data={"a", "b"})
200 223 assert guarded_eval("data.difference", context)
201 224
202 225
203 226 def test_list():
204 227 context = limited(data=[1, 2, 3])
205 228 assert guarded_eval("data[1]", context) == 2
206 229 assert guarded_eval("data.copy", context)
207 230
208 231
209 232 def test_dict_literal():
210 233 context = limited()
211 234 assert guarded_eval("{}", context) == {}
212 235 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213 236
214 237
215 238 def test_list_literal():
216 239 context = limited()
217 240 assert guarded_eval("[]", context) == []
218 241 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219 242
220 243
221 244 def test_set_literal():
222 245 context = limited()
223 246 assert guarded_eval("set()", context) == set()
224 247 assert guarded_eval('{"a"}', context) == {"a"}
225 248
226 249
227 250 def test_evaluates_if_expression():
228 251 context = limited()
229 252 assert guarded_eval("2 if True else 3", context) == 2
230 253 assert guarded_eval("4 if False else 5", context) == 5
231 254
232 255
233 256 def test_object():
234 257 obj = object()
235 258 context = limited(obj=obj)
236 259 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237 260
238 261
239 262 @pytest.mark.parametrize(
240 263 "code,expected",
241 264 [
242 265 ["int.numerator", int.numerator],
243 266 ["float.is_integer", float.is_integer],
244 267 ["complex.real", complex.real],
245 268 ],
246 269 )
247 270 def test_number_attributes(code, expected):
248 271 assert guarded_eval(code, limited()) == expected
249 272
250 273
251 274 def test_method_descriptor():
252 275 context = limited()
253 276 assert guarded_eval("list.copy.__name__", context) == "copy"
254 277
255 278
256 279 class HeapType:
257 280 pass
258 281
259 282
260 283 class CallCreatesHeapType:
261 284 def __call__(self) -> HeapType:
262 285 return HeapType()
263 286
264 287
265 288 class CallCreatesBuiltin:
266 289 def __call__(self) -> frozenset:
267 290 return frozenset()
268 291
269 292
293 class HasStaticMethod:
294 @staticmethod
295 def static_method() -> HeapType:
296 return HeapType()
297
298
299 class InitReturnsFrozenset:
300 def __new__(self) -> frozenset: # type:ignore[misc]
301 return frozenset()
302
303
304 class StringAnnotation:
305 def heap(self) -> "HeapType":
306 return HeapType()
307
308 def copy(self) -> "StringAnnotation":
309 return StringAnnotation()
310
311
312 CustomIntType = NewType("CustomIntType", int)
313 CustomHeapType = NewType("CustomHeapType", HeapType)
314 IntTypeAlias = TypeAliasType("IntTypeAlias", int)
315 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
316
317
318 class TestProtocol(Protocol):
319 def test_method(self) -> bool:
320 pass
321
322
323 class TestProtocolImplementer(TestProtocol):
324 def test_method(self) -> bool:
325 return True
326
327
328 class Movie(TypedDict):
329 name: str
330 year: int
331
332
333 class SpecialTyping:
334 def custom_int_type(self) -> CustomIntType:
335 return CustomIntType(1)
336
337 def custom_heap_type(self) -> CustomHeapType:
338 return CustomHeapType(HeapType())
339
340 # TODO: remove type:ignore comment once mypy
341 # supports explicit calls to `TypeAliasType`, see:
342 # https://github.com/python/mypy/issues/16614
343 def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type]
344 return 1
345
346 def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type]
347 return 1
348
349 def literal(self) -> Literal[False]:
350 return False
351
352 def literal_string(self) -> LiteralString:
353 return "test"
354
355 def self(self) -> Self:
356 return self
357
358 def any_str(self, x: AnyStr) -> AnyStr:
359 return x
360
361 def annotated(self) -> Annotated[float, "positive number"]:
362 return 1
363
364 def annotated_self(self) -> Annotated[Self, "self with metadata"]:
365 self._metadata = "test"
366 return self
367
368 def int_type_guard(self, x) -> TypeGuard[int]:
369 return isinstance(x, int)
370
371 def optional_float(self) -> Optional[float]:
372 return 1.0
373
374 def union_str_and_int(self) -> Union[str, int]:
375 return ""
376
377 def protocol(self) -> TestProtocol:
378 return TestProtocolImplementer()
379
380 def typed_dict(self) -> Movie:
381 return {"name": "The Matrix", "year": 1999}
382
383
270 384 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
385 "data,code,expected,equality",
272 386 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
387 [[1, 2, 3], "data.index(2)", 1, True],
388 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
389 [StringAnnotation(), "data.heap()", HeapType, False],
390 [StringAnnotation(), "data.copy()", StringAnnotation, False],
391 # test cases for `__call__`
392 [CallCreatesHeapType(), "data()", HeapType, False],
393 [CallCreatesBuiltin(), "data()", frozenset, False],
394 # Test cases for `__init__`
395 [HeapType, "data()", HeapType, False],
396 [InitReturnsFrozenset, "data()", frozenset, False],
397 [HeapType(), "data.__class__()", HeapType, False],
398 # supported special cases for typing
399 [SpecialTyping(), "data.custom_int_type()", int, False],
400 [SpecialTyping(), "data.custom_heap_type()", HeapType, False],
401 [SpecialTyping(), "data.int_type_alias()", int, False],
402 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
403 [SpecialTyping(), "data.self()", SpecialTyping, False],
404 [SpecialTyping(), "data.literal()", False, True],
405 [SpecialTyping(), "data.literal_string()", str, False],
406 [SpecialTyping(), "data.any_str('a')", str, False],
407 [SpecialTyping(), "data.any_str(b'a')", bytes, False],
408 [SpecialTyping(), "data.annotated()", float, False],
409 [SpecialTyping(), "data.annotated_self()", SpecialTyping, False],
410 [SpecialTyping(), "data.int_type_guard()", int, False],
411 # test cases for static methods
412 [HasStaticMethod, "data.static_method()", HeapType, False],
277 413 ],
278 414 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
280 context = limited(data=data)
281 value = guarded_eval(good, context)
415 def test_evaluates_calls(data, code, expected, equality):
416 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
417 value = guarded_eval(code, context)
282 418 if equality:
283 419 assert value == expected
284 420 else:
285 421 assert isinstance(value, expected)
286 422
423
424 @pytest.mark.parametrize(
425 "data,code,expected_attributes",
426 [
427 [SpecialTyping(), "data.optional_float()", ["is_integer"]],
428 [
429 SpecialTyping(),
430 "data.union_str_and_int()",
431 ["capitalize", "as_integer_ratio"],
432 ],
433 [SpecialTyping(), "data.protocol()", ["test_method"]],
434 [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]],
435 ],
436 )
437 def test_mocks_attributes_of_call_results(data, code, expected_attributes):
438 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
439 result = guarded_eval(code, context)
440 for attr in expected_attributes:
441 assert hasattr(result, attr)
442 assert attr in dir(result)
443
444
445 @pytest.mark.parametrize(
446 "data,code,expected_items",
447 [
448 [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}],
449 ],
450 )
451 def test_mocks_items_of_call_results(data, code, expected_items):
452 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
453 result = guarded_eval(code, context)
454 ipython_keys = result._ipython_key_completions_()
455 for key, value in expected_items.items():
456 assert isinstance(result[key], value)
457 assert key in ipython_keys
458
459
460 @pytest.mark.parametrize(
461 "data,bad",
462 [
463 [[1, 2, 3], "data.append(4)"],
464 [{"a": 1}, "data.update()"],
465 ],
466 )
467 def test_rejects_calls_with_side_effects(data, bad):
468 context = limited(data=data)
469
287 470 with pytest.raises(GuardRejection):
288 471 guarded_eval(bad, context)
289 472
290 473
291 474 @pytest.mark.parametrize(
292 475 "code,expected",
293 476 [
294 477 ["(1\n+\n1)", 2],
295 478 ["list(range(10))[-1:]", [9]],
296 479 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
297 480 ],
298 481 )
299 482 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
300 483 def test_evaluates_complex_cases(code, expected, context):
301 484 assert guarded_eval(code, context()) == expected
302 485
303 486
304 487 @pytest.mark.parametrize(
305 488 "code,expected",
306 489 [
307 490 ["1", 1],
308 491 ["1.0", 1.0],
309 492 ["0xdeedbeef", 0xDEEDBEEF],
310 493 ["True", True],
311 494 ["None", None],
312 495 ["{}", {}],
313 496 ["[]", []],
314 497 ],
315 498 )
316 499 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
317 500 def test_evaluates_literals(code, expected, context):
318 501 assert guarded_eval(code, context()) == expected
319 502
320 503
321 504 @pytest.mark.parametrize(
322 505 "code,expected",
323 506 [
324 507 ["-5", -5],
325 508 ["+5", +5],
326 509 ["~5", -6],
327 510 ],
328 511 )
329 512 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
330 513 def test_evaluates_unary_operations(code, expected, context):
331 514 assert guarded_eval(code, context()) == expected
332 515
333 516
334 517 @pytest.mark.parametrize(
335 518 "code,expected",
336 519 [
337 520 ["1 + 1", 2],
338 521 ["3 - 1", 2],
339 522 ["2 * 3", 6],
340 523 ["5 // 2", 2],
341 524 ["5 / 2", 2.5],
342 525 ["5**2", 25],
343 526 ["2 >> 1", 1],
344 527 ["2 << 1", 4],
345 528 ["1 | 2", 3],
346 529 ["1 & 1", 1],
347 530 ["1 & 2", 0],
348 531 ],
349 532 )
350 533 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
351 534 def test_evaluates_binary_operations(code, expected, context):
352 535 assert guarded_eval(code, context()) == expected
353 536
354 537
355 538 @pytest.mark.parametrize(
356 539 "code,expected",
357 540 [
358 541 ["2 > 1", True],
359 542 ["2 < 1", False],
360 543 ["2 <= 1", False],
361 544 ["2 <= 2", True],
362 545 ["1 >= 2", False],
363 546 ["2 >= 2", True],
364 547 ["2 == 2", True],
365 548 ["1 == 2", False],
366 549 ["1 != 2", True],
367 550 ["1 != 1", False],
368 551 ["1 < 4 < 3", False],
369 552 ["(1 < 4) < 3", True],
370 553 ["4 > 3 > 2 > 1", True],
371 554 ["4 > 3 > 2 > 9", False],
372 555 ["1 < 2 < 3 < 4", True],
373 556 ["9 < 2 < 3 < 4", False],
374 557 ["1 < 2 > 1 > 0 > -1 < 1", True],
375 558 ["1 in [1] in [[1]]", True],
376 559 ["1 in [1] in [[2]]", False],
377 560 ["1 in [1]", True],
378 561 ["0 in [1]", False],
379 562 ["1 not in [1]", False],
380 563 ["0 not in [1]", True],
381 564 ["True is True", True],
382 565 ["False is False", True],
383 566 ["True is False", False],
384 567 ["True is not True", False],
385 568 ["False is not True", True],
386 569 ],
387 570 )
388 571 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
389 572 def test_evaluates_comparisons(code, expected, context):
390 573 assert guarded_eval(code, context()) == expected
391 574
392 575
393 576 def test_guards_comparisons():
394 577 class GoodEq(int):
395 578 pass
396 579
397 580 class BadEq(int):
398 581 def __eq__(self, other):
399 582 assert False
400 583
401 584 context = limited(bad=BadEq(1), good=GoodEq(1))
402 585
403 586 with pytest.raises(GuardRejection):
404 587 guarded_eval("bad == 1", context)
405 588
406 589 with pytest.raises(GuardRejection):
407 590 guarded_eval("bad != 1", context)
408 591
409 592 with pytest.raises(GuardRejection):
410 593 guarded_eval("1 == bad", context)
411 594
412 595 with pytest.raises(GuardRejection):
413 596 guarded_eval("1 != bad", context)
414 597
415 598 assert guarded_eval("good == 1", context) is True
416 599 assert guarded_eval("good != 1", context) is False
417 600 assert guarded_eval("1 == good", context) is True
418 601 assert guarded_eval("1 != good", context) is False
419 602
420 603
421 604 def test_guards_unary_operations():
422 605 class GoodOp(int):
423 606 pass
424 607
425 608 class BadOpInv(int):
426 609 def __inv__(self, other):
427 610 assert False
428 611
429 612 class BadOpInverse(int):
430 613 def __inv__(self, other):
431 614 assert False
432 615
433 616 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
434 617
435 618 with pytest.raises(GuardRejection):
436 619 guarded_eval("~bad1", context)
437 620
438 621 with pytest.raises(GuardRejection):
439 622 guarded_eval("~bad2", context)
440 623
441 624
442 625 def test_guards_binary_operations():
443 626 class GoodOp(int):
444 627 pass
445 628
446 629 class BadOp(int):
447 630 def __add__(self, other):
448 631 assert False
449 632
450 633 context = limited(good=GoodOp(1), bad=BadOp(1))
451 634
452 635 with pytest.raises(GuardRejection):
453 636 guarded_eval("1 + bad", context)
454 637
455 638 with pytest.raises(GuardRejection):
456 639 guarded_eval("bad + 1", context)
457 640
458 641 assert guarded_eval("good + 1", context) == 2
459 642 assert guarded_eval("1 + good", context) == 2
460 643
461 644
462 645 def test_guards_attributes():
463 646 class GoodAttr(float):
464 647 pass
465 648
466 649 class BadAttr1(float):
467 650 def __getattr__(self, key):
468 651 assert False
469 652
470 653 class BadAttr2(float):
471 654 def __getattribute__(self, key):
472 655 assert False
473 656
474 657 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
475 658
476 659 with pytest.raises(GuardRejection):
477 660 guarded_eval("bad1.as_integer_ratio", context)
478 661
479 662 with pytest.raises(GuardRejection):
480 663 guarded_eval("bad2.as_integer_ratio", context)
481 664
482 665 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
483 666
484 667
485 668 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
486 669 def test_access_builtins(context):
487 670 assert guarded_eval("round", context()) == round
488 671
489 672
490 673 def test_access_builtins_fails():
491 674 context = limited()
492 675 with pytest.raises(NameError):
493 676 guarded_eval("this_is_not_builtin", context)
494 677
495 678
496 679 def test_rejects_forbidden():
497 680 context = forbidden()
498 681 with pytest.raises(GuardRejection):
499 682 guarded_eval("1", context)
500 683
501 684
502 685 def test_guards_locals_and_globals():
503 686 context = EvaluationContext(
504 687 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
505 688 )
506 689
507 690 with pytest.raises(GuardRejection):
508 691 guarded_eval("local_a", context)
509 692
510 693 with pytest.raises(GuardRejection):
511 694 guarded_eval("global_b", context)
512 695
513 696
514 697 def test_access_locals_and_globals():
515 698 context = EvaluationContext(
516 699 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
517 700 )
518 701 assert guarded_eval("local_a", context) == "a"
519 702 assert guarded_eval("global_b", context) == "b"
520 703
521 704
522 705 @pytest.mark.parametrize(
523 706 "code",
524 707 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
525 708 )
526 709 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
527 710 def test_rejects_side_effect_syntax(code, context):
528 711 with pytest.raises(SyntaxError):
529 712 guarded_eval(code, context)
530 713
531 714
532 715 def test_subscript():
533 716 context = EvaluationContext(
534 717 locals={}, globals={}, evaluation="limited", in_subscript=True
535 718 )
536 719 empty_slice = slice(None, None, None)
537 720 assert guarded_eval("", context) == tuple()
538 721 assert guarded_eval(":", context) == empty_slice
539 722 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
540 723 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
541 724
542 725
543 726 def test_unbind_method():
544 727 class X(list):
545 728 def index(self, k):
546 729 return "CUSTOM"
547 730
548 731 x = X()
549 732 assert _unbind_method(x.index) is X.index
550 733 assert _unbind_method([].index) is list.index
551 734 assert _unbind_method(list.index) is None
552 735
553 736
554 737 def test_assumption_instance_attr_do_not_matter():
555 738 """This is semi-specified in Python documentation.
556 739
557 740 However, since the specification says 'not guaranteed
558 741 to work' rather than 'is forbidden to work', future
559 742 versions could invalidate this assumptions. This test
560 743 is meant to catch such a change if it ever comes true.
561 744 """
562 745
563 746 class T:
564 747 def __getitem__(self, k):
565 748 return "a"
566 749
567 750 def __getattr__(self, k):
568 751 return "a"
569 752
570 753 def f(self):
571 754 return "b"
572 755
573 756 t = T()
574 757 t.__getitem__ = f
575 758 t.__getattr__ = f
576 759 assert t[1] == "a"
577 760 assert t[1] == "a"
578 761
579 762
580 763 def test_assumption_named_tuples_share_getitem():
581 764 """Check assumption on named tuples sharing __getitem__"""
582 765 from typing import NamedTuple
583 766
584 767 class A(NamedTuple):
585 768 pass
586 769
587 770 class B(NamedTuple):
588 771 pass
589 772
590 773 assert A.__getitem__ == B.__getitem__
591 774
592 775
593 776 @dec.skip_without("numpy")
594 777 def test_module_access():
595 778 import numpy
596 779
597 780 context = limited(numpy=numpy)
598 781 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
599 782
600 783 context = minimal(numpy=numpy)
601 784 with pytest.raises(GuardRejection):
602 785 guarded_eval("np.linalg.norm", context)
@@ -1,206 +1,206 b''
1 1 [build-system]
2 2 requires = ["setuptools>=61.2"]
3 3 # We need access to the 'setupbase' module at build time.
4 4 # Hence we declare a custom build backend.
5 5 build-backend = "_build_meta" # just re-exports setuptools.build_meta definitions
6 6 backend-path = ["."]
7 7
8 8 [project]
9 9 name = "ipython"
10 10 description = "IPython: Productive Interactive Computing"
11 11 keywords = ["Interactive", "Interpreter", "Shell", "Embedding"]
12 12 classifiers = [
13 13 "Framework :: IPython",
14 14 "Framework :: Jupyter",
15 15 "Intended Audience :: Developers",
16 16 "Intended Audience :: Science/Research",
17 17 "License :: OSI Approved :: BSD License",
18 18 "Programming Language :: Python",
19 19 "Programming Language :: Python :: 3",
20 20 "Programming Language :: Python :: 3 :: Only",
21 21 "Topic :: System :: Shells",
22 22 ]
23 23 requires-python = ">=3.10"
24 24 dependencies = [
25 25 'colorama; sys_platform == "win32"',
26 26 "decorator",
27 27 "exceptiongroup; python_version<'3.11'",
28 28 "jedi>=0.16",
29 29 "matplotlib-inline",
30 30 'pexpect>4.3; sys_platform != "win32" and sys_platform != "emscripten"',
31 31 "prompt_toolkit>=3.0.41,<3.1.0",
32 32 "pygments>=2.4.0",
33 33 "stack_data",
34 34 "traitlets>=5.13.0",
35 "typing_extensions; python_version<'3.10'",
35 "typing_extensions; python_version<'3.12'",
36 36 ]
37 37 dynamic = ["authors", "license", "version"]
38 38
39 39 [project.entry-points."pygments.lexers"]
40 40 ipythonconsole = "IPython.lib.lexers:IPythonConsoleLexer"
41 41 ipython = "IPython.lib.lexers:IPythonLexer"
42 42 ipython3 = "IPython.lib.lexers:IPython3Lexer"
43 43
44 44 [project.scripts]
45 45 ipython = "IPython:start_ipython"
46 46 ipython3 = "IPython:start_ipython"
47 47
48 48 [project.readme]
49 49 file = "long_description.rst"
50 50 content-type = "text/x-rst"
51 51
52 52 [project.urls]
53 53 Homepage = "https://ipython.org"
54 54 Documentation = "https://ipython.readthedocs.io/"
55 55 Funding = "https://numfocus.org/"
56 56 Source = "https://github.com/ipython/ipython"
57 57 Tracker = "https://github.com/ipython/ipython/issues"
58 58
59 59 [project.optional-dependencies]
60 60 black = [
61 61 "black",
62 62 ]
63 63 doc = [
64 64 "ipykernel",
65 65 "setuptools>=18.5",
66 66 "sphinx>=1.3",
67 67 "sphinx-rtd-theme",
68 68 "sphinxcontrib-jquery",
69 69 "docrepr",
70 70 "matplotlib",
71 71 "stack_data",
72 72 "typing_extensions",
73 73 "exceptiongroup",
74 74 "ipython[test]",
75 75 ]
76 76 kernel = [
77 77 "ipykernel",
78 78 ]
79 79 nbconvert = [
80 80 "nbconvert",
81 81 ]
82 82 nbformat = [
83 83 "nbformat",
84 84 ]
85 85 notebook = [
86 86 "ipywidgets",
87 87 "notebook",
88 88 ]
89 89 parallel = [
90 90 "ipyparallel",
91 91 ]
92 92 qtconsole = [
93 93 "qtconsole",
94 94 ]
95 95 terminal = []
96 96 test = [
97 97 "pytest<8",
98 98 "pytest-asyncio<0.22",
99 99 "testpath",
100 100 "pickleshare",
101 101 ]
102 102 test_extra = [
103 103 "ipython[test]",
104 104 "curio",
105 105 "matplotlib!=3.2.0",
106 106 "nbformat",
107 107 "numpy>=1.23",
108 108 "pandas",
109 109 "trio",
110 110 ]
111 111 all = [
112 112 "ipython[black,doc,kernel,nbconvert,nbformat,notebook,parallel,qtconsole,terminal]",
113 113 "ipython[test,test_extra]",
114 114 ]
115 115
116 116 [tool.mypy]
117 117 python_version = "3.10"
118 118 ignore_missing_imports = true
119 119 follow_imports = 'silent'
120 120 exclude = [
121 121 'test_\.+\.py',
122 122 'IPython.utils.tests.test_wildcard',
123 123 'testing',
124 124 'tests',
125 125 'PyColorize.py',
126 126 '_process_win32_controller.py',
127 127 'IPython/core/application.py',
128 128 'IPython/core/profileapp.py',
129 129 'IPython/lib/deepreload.py',
130 130 'IPython/sphinxext/ipython_directive.py',
131 131 'IPython/terminal/ipapp.py',
132 132 'IPython/utils/_process_win32.py',
133 133 'IPython/utils/path.py',
134 134 ]
135 135
136 136 [tool.pytest.ini_options]
137 137 addopts = [
138 138 "--durations=10",
139 139 "-pIPython.testing.plugin.pytest_ipdoctest",
140 140 "--ipdoctest-modules",
141 141 "--ignore=docs",
142 142 "--ignore=examples",
143 143 "--ignore=htmlcov",
144 144 "--ignore=ipython_kernel",
145 145 "--ignore=ipython_parallel",
146 146 "--ignore=results",
147 147 "--ignore=tmp",
148 148 "--ignore=tools",
149 149 "--ignore=traitlets",
150 150 "--ignore=IPython/core/tests/daft_extension",
151 151 "--ignore=IPython/sphinxext",
152 152 "--ignore=IPython/terminal/pt_inputhooks",
153 153 "--ignore=IPython/__main__.py",
154 154 "--ignore=IPython/external/qt_for_kernel.py",
155 155 "--ignore=IPython/html/widgets/widget_link.py",
156 156 "--ignore=IPython/html/widgets/widget_output.py",
157 157 "--ignore=IPython/terminal/console.py",
158 158 "--ignore=IPython/utils/_process_cli.py",
159 159 "--ignore=IPython/utils/_process_posix.py",
160 160 "--ignore=IPython/utils/_process_win32.py",
161 161 "--ignore=IPython/utils/_process_win32_controller.py",
162 162 "--ignore=IPython/utils/daemonize.py",
163 163 "--ignore=IPython/utils/eventful.py",
164 164 "--ignore=IPython/kernel",
165 165 "--ignore=IPython/consoleapp.py",
166 166 "--ignore=IPython/core/inputsplitter.py",
167 167 "--ignore=IPython/lib/kernel.py",
168 168 "--ignore=IPython/utils/jsonutil.py",
169 169 "--ignore=IPython/utils/localinterfaces.py",
170 170 "--ignore=IPython/utils/log.py",
171 171 "--ignore=IPython/utils/signatures.py",
172 172 "--ignore=IPython/utils/traitlets.py",
173 173 "--ignore=IPython/utils/version.py"
174 174 ]
175 175 doctest_optionflags = [
176 176 "NORMALIZE_WHITESPACE",
177 177 "ELLIPSIS"
178 178 ]
179 179 ipdoctest_optionflags = [
180 180 "NORMALIZE_WHITESPACE",
181 181 "ELLIPSIS"
182 182 ]
183 183 asyncio_mode = "strict"
184 184
185 185 [tool.pyright]
186 186 pythonPlatform="All"
187 187
188 188 [tool.setuptools]
189 189 zip-safe = false
190 190 platforms = ["Linux", "Mac OSX", "Windows"]
191 191 license-files = ["LICENSE"]
192 192 include-package-data = false
193 193
194 194 [tool.setuptools.packages.find]
195 195 exclude = ["setupext"]
196 196 namespaces = false
197 197
198 198 [tool.setuptools.package-data]
199 199 "IPython" = ["py.typed"]
200 200 "IPython.core" = ["profile/README*"]
201 201 "IPython.core.tests" = ["*.png", "*.jpg", "daft_extension/*.py"]
202 202 "IPython.lib.tests" = ["*.wav"]
203 203 "IPython.testing.plugin" = ["*.txt"]
204 204
205 205 [tool.setuptools.dynamic]
206 206 version = {attr = "IPython.core.release.__version__"}
General Comments 0
You need to be logged in to leave comments. Login now