##// END OF EJS Templates
Implement remaining special `typing` wrappers
krassowski -
Show More
@@ -1,822 +1,898 b''
1 1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 Annotated,
4 AnyStr,
3 5 Callable,
4 6 Dict,
5 7 Literal,
6 8 NamedTuple,
7 9 NewType,
10 Optional,
11 Protocol,
8 12 Set,
9 13 Sequence,
10 14 Tuple,
11 15 Type,
12 Protocol,
16 TypeGuard,
13 17 Union,
14 18 get_args,
15 19 get_origin,
20 is_typeddict,
16 21 )
17 22 import ast
18 23 import builtins
19 24 import collections
20 25 import operator
21 26 import sys
22 27 from functools import cached_property
23 28 from dataclasses import dataclass, field
24 29 from types import MethodDescriptorType, ModuleType
25 30
26 31 from IPython.utils.decorators import undoc
27 32
28 33
29 34 if sys.version_info < (3, 11):
30 from typing_extensions import Self
35 from typing_extensions import Self, LiteralString
31 36 else:
32 from typing import Self
37 from typing import Self, LiteralString
33 38
34 39 if sys.version_info < (3, 12):
35 40 from typing_extensions import TypeAliasType
36 41 else:
37 42 from typing import TypeAliasType
38 43
39 44
@undoc
class HasGetItem(Protocol):
    """Structural type of objects that can be subscripted (``obj[key]``)."""

    def __getitem__(self, key) -> None:
        """Subscript hook; the protocol only requires that it exists."""
44 49
45 50
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type of callables (e.g. classes) producing subscriptable objects."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Calling the object yields something matching ``HasGetItem``."""
50 55
51 56
@undoc
class HasGetAttr(Protocol):
    """Structural type of objects that define ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Attribute fallback hook; the protocol only requires that it exists."""
56 61
57 62
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Empty protocol: counterpart of ``HasGetAttr`` with no required members."""
61 66
62 67
# Most objects do not implement `__getattr__` explicitly, hence both
# alternatives (with and without the hook) are accepted.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
65 70
66 71
67 72 def _unbind_method(func: Callable) -> Union[Callable, None]:
68 73 """Get unbound method for given bound method.
69 74
70 75 Returns None if cannot get unbound method, or method is already unbound.
71 76 """
72 77 owner = getattr(func, "__self__", None)
73 78 owner_class = type(owner)
74 79 name = getattr(func, "__name__", None)
75 80 instance_dict_overrides = getattr(owner, "__dict__", None)
76 81 if (
77 82 owner is not None
78 83 and name
79 84 and (
80 85 not instance_dict_overrides
81 86 or (instance_dict_overrides and name not in instance_dict_overrides)
82 87 )
83 88 ):
84 89 return getattr(owner_class, name)
85 90 return None
86 91
87 92
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``allow_*`` flag switches one capability on or off; the ``can_*``
    methods are the predicates consulted during evaluation. All predicates
    return a boolean.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    # callables that may be called even when `allow_any_calls` is False
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Tell whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Tell whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Tell whether the operation described by ``dunders`` may be applied."""
        if self.allow_all_operations:
            return True
        return False

    def can_call(self, func):
        """Tell whether ``func`` may be called.

        A call is allowed when any call is allowed, when ``func`` itself is
        allow-listed, or when the unbound method behind a bound ``func`` is
        allow-listed.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)

        if owner_method and owner_method in self.allowed_calls:
            return True

        return False
123 128
124 129
125 130 def _get_external(module_name: str, access_path: Sequence[str]):
126 131 """Get value from external module given a dotted access path.
127 132
128 133 Raises:
129 134 * `KeyError` if module is removed not found, and
130 135 * `AttributeError` if acess path does not match an exported object
131 136 """
132 137 member_type = sys.modules[module_name]
133 138 for attr in access_path:
134 139 member_type = getattr(member_type, attr)
135 140 return member_type
136 141
137 142
138 143 def _has_original_dunder_external(
139 144 value,
140 145 module_name: str,
141 146 access_path: Sequence[str],
142 147 method_name: str,
143 148 ):
144 149 if module_name not in sys.modules:
145 150 # LBYLB as it is faster
146 151 return False
147 152 try:
148 153 member_type = _get_external(module_name, access_path)
149 154 value_type = type(value)
150 155 if type(value) == member_type:
151 156 return True
152 157 if method_name == "__getattribute__":
153 158 # we have to short-circuit here due to an unresolved issue in
154 159 # `isinstance` implementation: https://bugs.python.org/issue32683
155 160 return False
156 161 if isinstance(value, member_type):
157 162 method = getattr(value_type, method_name, None)
158 163 member_method = getattr(member_type, method_name, None)
159 164 if member_method == method:
160 165 return True
161 166 except (AttributeError, KeyError):
162 167 return False
163 168
164 169
165 170 def _has_original_dunder(
166 171 value, allowed_types, allowed_methods, allowed_external, method_name
167 172 ):
168 173 # note: Python ignores `__getattr__`/`__getitem__` on instances,
169 174 # we only need to check at class level
170 175 value_type = type(value)
171 176
172 177 # strict type check passes β†’ no need to check method
173 178 if value_type in allowed_types:
174 179 return True
175 180
176 181 method = getattr(value_type, method_name, None)
177 182
178 183 if method is None:
179 184 return None
180 185
181 186 if method in allowed_methods:
182 187 return True
183 188
184 189 for module_name, *access_path in allowed_external:
185 190 if _has_original_dunder_external(value, module_name, access_path, method_name):
186 191 return True
187 192
188 193 return False
189 194
190 195
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy allowing item access, attribute access and operations selectively.

    An object is accepted when its type is allow-listed, or when the relevant
    dunder method of its type is the one of an allow-listed type (built-in
    sets) or of an externally defined type (``*_external`` sets, given as
    ``(module_name, *access_path)`` tuples).
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # per-dunder cache of the methods collected from `allowed_operations`
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Allow attribute access if `__getattr__`/`__getattribute__` were not modified.

        Additionally, a property found on the class of `value` is only
        accepted if the class is allow-listed or the property is the very
        object defined by one of the allow-listed external classes.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external class is consulted: sets are unordered, so stopping
        # at the first entry would make the outcome depend on iteration order.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue  # pragma: no cover
            # identity: the property must be the external class's own object
            if class_attr_val is external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Allow an operation if every dunder is unmodified on every operand.

        `b` is only checked when it is not None.
        """
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=self.allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return (and cache) the `dunder` methods of all allowed operation types."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect attribute `name` from each class, skipping missing or falsy ones."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
315 320
316 321
317 322 class _DummyNamedTuple(NamedTuple):
318 323 """Used internally to retrieve methods of named tuple instance."""
319 324
320 325
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which a piece of code is evaluated."""

    #: Namespace consulted for local names
    locals: dict
    #: Namespace consulted for global names
    globals: dict
    #: Identifier of the evaluation policy to apply
    evaluation: Literal[
        "forbidden", "minimal", "limited", "unsafe", "dangerous"
    ] = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
333 338
334 339
335 340 class _IdentitySubscript:
336 341 """Returns the key itself when item is requested via subscript."""
337 342
338 343 def __getitem__(self, key):
339 344 return key
340 345
341 346
# Shared instance which echoes back whatever key it is subscripted with.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which the identity subscript is exposed to evaluated code.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Stand-in signature (no parameters, no return annotation).
UNKNOWN_SIGNATURE = Signature()
# Sentinel meaning "no value could be produced"; compared by identity.
NOT_EVALUATED = object()
346 351
347 352
class GuardRejection(Exception):
    """Raised when the guard refuses to evaluate an expression."""
352 357
353 358
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    Depending on ``context.evaluation``:

    - ``forbidden``: nothing is evaluated, :class:`GuardRejection` is raised;
    - ``dangerous``: the code is passed to the standard :func:`eval`;
    - any other policy: the code is parsed and :func:`eval_node` is called
      on the resulting AST.

    When ``context.in_subscript`` is set, the code is evaluated as if it was
    written between square brackets; empty code then yields an empty tuple.
    """
    mode = context.evaluation

    if mode == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # syntactic sugar for slices (`:`) is only available in subscripts
        # so we need to trick the ast parser into thinking that we have
        # a subscript, but we need to be able to later recognise that we did
        # it so we can ignore the actual __getitem__ operation
        if not code:
            return tuple()
        scoped_locals = context.locals.copy()
        scoped_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = SUBSCRIPT_MARKER + "[" + code + "]"
        fields = context._asdict()
        fields["locals"] = scoped_locals
        context = EvaluationContext(**fields)

    if mode == "dangerous":
        # NOTE(security): unrestricted `eval`, no policy is applied here
        return eval(code, context.globals, context.locals)

    return eval_node(ast.parse(code, mode="eval"), context)
388 393
389 394
# Binary operator AST node -> dunder method(s) implementing it.
BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
    # arithmetic
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    # bitwise
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    # matrix multiplication
    ast.MatMult: ("__matmul__",),
}

# Comparison AST node -> dunder methods; the first entry is the one applied.
COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

# Unary operator AST node -> dunder methods; the first entry is the one applied.
UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    ast.Not: ("__not__",),
}
424 429
425 430
class ImpersonatingDuck:
    """Placeholder used to obtain objects of other classes without running their ``__init__``.

    The class is intentionally empty: an instance impersonates another class
    once its ``__class__`` is reassigned.
    """
435
436
437 class _Duck:
438 """A dummy class used to create objects pretending to have given attributes"""
439
440 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
441 self.attributes = attributes or {}
442 self.items = items or {}
443
444 def __getattr__(self, attr: str):
445 return self.attributes[attr]
446
447 def __hasattr__(self, attr: str):
448 return attr in self.attributes
449
450 def __dir__(self):
451 return [*dir(super), *self.attributes]
452
453 def __getitem__(self, key: str):
454 return self.items[key]
455
456 def __hasitem__(self, key: str):
457 return self.items[key]
458
459 def _ipython_key_completions_(self):
460 return self.items.keys()
461
429 462
430 463 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
431 464 dunder = None
432 465 for op, candidate_dunder in dunders.items():
433 466 if isinstance(node_op, op):
434 467 dunder = candidate_dunder
435 468 return dunder
436 469
437 470
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.

    Does not evaluate actions that always have side effects:

    - class definitions (``class sth: ...``)
    - function definitions (``def sth: ...``)
    - variable assignments (``x = 1``)
    - augmented assignments (``x += 1``)
    - deletions (``del x``)

    Does not evaluate operations which do not return values:

    - assertions (``assert x``)
    - pass (``pass``)
    - imports (``import x``)
    - control flow:

        - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
        - loops (``for`` and ``while``)
        - exception handling

    The purpose of this function is to guard against unwanted side-effects;
    it does not give guarantees on protection from malicious code execution.

    Operators are applied through the :mod:`operator` module so that they
    follow regular Python semantics (e.g. ``1 + 2.0`` uses the reflected
    operation of the right operand, and ``not x`` works even though objects
    have no ``__not__`` attribute).

    Raises:
        GuardRejection: if the policy selected by ``context.evaluation``
            does not allow the requested operation, access or call.
        ValueError: if the node type is not handled.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    if node is None:
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.BinOp):
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, left, right):
                # `operator.__add__(a, b)` is `a + b`; calling `a.__add__(b)`
                # directly would return `NotImplemented` for mixed operands
                return getattr(operator, dunders[0])(left, right)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(left),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Compare):
        left = eval_node(node.left, context)
        all_true = True
        for op, right in zip(node.ops, node.comparators):
            right = eval_node(right, context)
            # both are per-comparison state: a negated operator (`not in`,
            # `is not`) must not leak into the next link of a chain
            negate = False
            dunder = None
            dunders = _find_dunder(op, COMP_OP_DUNDERS)
            if not dunders:
                if isinstance(op, ast.NotIn):
                    dunders = COMP_OP_DUNDERS[ast.In]
                    negate = True
                if isinstance(op, ast.Is):
                    dunder = "is_"
                if isinstance(op, ast.IsNot):
                    dunder = "is_"
                    negate = True
            if not dunder and dunders:
                dunder = dunders[0]
            if dunder:
                # containment is defined by the container, i.e. right operand
                a, b = (right, left) if dunder == "__contains__" else (left, right)
                if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
                    result = getattr(operator, dunder)(a, b)
                    if negate:
                        result = not result
                    if not result:
                        all_true = False
                    left = right
                else:
                    raise GuardRejection(
                        f"Comparison (`{dunder}`) for",
                        type(left),
                        f"not allowed in {context.evaluation} mode",
                    )
            else:
                raise ValueError(
                    f"Comparison `{dunder}` not supported"
                )  # pragma: no cover
        return all_true
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        return dict(
            zip(
                [eval_node(k, context) for k in node.keys],
                [eval_node(v, context) for v in node.values],
            )
        )
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context),
        )
    if isinstance(node, ast.UnaryOp):
        value = eval_node(node.operand, context)
        dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, value):
                # `operator` provides `__not__` too, unlike the objects
                return getattr(operator, dunders[0])(value)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(value),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            "Subscript access (`__getitem__`) for",
            type(value),  # not joined to avoid calling `repr`
            f" not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Name):
        return _eval_node_name(node.id, context)
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            "Attribute access (`__getattr__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        else:
            return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        if policy.can_call(func) and not node.keywords:
            args = [eval_node(arg, context) for arg in node.args]
            return func(*args)
        if isclass(func):
            # this code path gets entered when calling class e.g. `MyClass()`
            # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
            # Should return `MyClass` if `__new__` is not overridden,
            # otherwise whatever `__new__` return type is.
            overridden_return_type = _eval_return_type(func.__new__, node, context)
            if overridden_return_type is not NOT_EVALUATED:
                return overridden_return_type
            return _create_duck_for_heap_type(func)
        else:
            return_type = _eval_return_type(func, node, context)
            if return_type is not NOT_EVALUATED:
                return return_type
        raise GuardRejection(
            "Call for",
            func,  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    raise ValueError("Unhandled node", ast.dump(node))
606 639
607 640
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
    """Evaluate return type of a given callable function.

    The return annotation is taken from the signature of ``func`` and
    resolved with ``_resolve_annotation``.

    Returns the built-in type, a duck or NOT_EVALUATED sentinel; the
    sentinel is returned when there is no return annotation or no signature
    can be obtained.
    """
    try:
        sig = signature(func)
    except (ValueError, TypeError):
        # ValueError: no signature can be provided;
        # TypeError: the object is not supported (e.g. it is not callable)
        sig = UNKNOWN_SIGNATURE
    # if annotation was not stringized, or it was stringized
    # but resolved by signature call we know the return type
    not_empty = sig.return_annotation is not Signature.empty
    if not_empty:
        return _resolve_annotation(sig.return_annotation, sig, func, node, context)
    return NOT_EVALUATED
640 656
641 657
def _resolve_annotation(
    annotation,
    sig: Signature,
    func: Callable,
    node: ast.Call,
    context: EvaluationContext,
):
    """Resolve annotation created by user with `typing` module and custom objects.

    A string annotation is first looked up by name in the context. The
    result is a value standing in for what a call annotated this way would
    return, or the NOT_EVALUATED sentinel when no such value can be
    determined (e.g. a `Literal` with several values, or `AnyStr` with no
    matching positional argument in the call).
    """
    annotation = (
        _eval_node_name(annotation, context)
        if isinstance(annotation, str)
        else annotation
    )
    origin = get_origin(annotation)
    if annotation is Self and hasattr(func, "__self__"):
        return func.__self__
    elif origin is Literal:
        type_args = get_args(annotation)
        if len(type_args) == 1:
            return type_args[0]
    elif annotation is LiteralString:
        return ""
    elif annotation is AnyStr:
        # the result has the type of the first argument annotated as AnyStr
        index = None
        for i, parameter in enumerate(sig.parameters.values()):
            if parameter.annotation is AnyStr:
                index = i
                break
        if index is not None and index < len(node.args):
            return eval_node(node.args[index], context)
    elif origin is TypeGuard:
        return bool()
    elif origin is Union:
        # a duck exposing the attributes of all union members
        attributes = [
            attr
            for type_arg in get_args(annotation)
            for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
        ]
        return _Duck(attributes=dict.fromkeys(attributes))
    elif is_typeddict(annotation):
        return _Duck(
            attributes=dict.fromkeys(dir(dict())),
            items={
                k: _resolve_annotation(v, sig, func, node, context)
                for k, v in annotation.__annotations__.items()
            },
        )
    elif hasattr(annotation, "_is_protocol"):
        return _Duck(attributes=dict.fromkeys(dir(annotation)))
    elif origin is Annotated:
        type_arg = get_args(annotation)[0]
        return _resolve_annotation(type_arg, sig, func, node, context)
    elif isinstance(annotation, NewType):
        return _eval_or_create_duck(annotation.__supertype__, node, context)
    elif isinstance(annotation, TypeAliasType):
        return _eval_or_create_duck(annotation.__value__, node, context)
    else:
        return _eval_or_create_duck(annotation, node, context)
    # a branch matched but could not produce a value; report it explicitly
    # rather than returning None, which callers would take for a result
    return NOT_EVALUATED
716
717
def _eval_node_name(node_id: str, context: EvaluationContext):
    """Look up a name in locals, globals and builtins, as far as the policy allows.

    Raises `GuardRejection` if the name was not found and neither locals nor
    globals may be accessed, and `NameError` if it was not found otherwise.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    locals_allowed = policy.allow_locals_access
    globals_allowed = policy.allow_globals_access

    if locals_allowed and node_id in context.locals:
        return context.locals[node_id]
    if globals_allowed and node_id in context.globals:
        return context.globals[node_id]
    if policy.allow_builtins_access and hasattr(builtins, node_id):
        # note: do not use __builtins__, it is implementation detail of cPython
        return getattr(builtins, node_id)

    if globals_allowed or locals_allowed:
        raise NameError(f"{node_id} not found in locals, globals, nor builtins")
    raise GuardRejection(
        f"Namespace access not allowed in {context.evaluation} mode"
    )
657 733
658 734
def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
    """Instantiate ``duck_type`` if the policy allows calling it, else imitate it.

    The arguments of the call ``node`` are evaluated and passed to the type;
    calls with keyword arguments are never performed.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    may_instantiate = policy.can_call(duck_type) and not node.keywords
    if not may_instantiate:
        # if custom class is in type annotation, mock it
        return _create_duck_for_heap_type(duck_type)
    # if allow-listed builtin is on type annotation, instantiate it
    positional = [eval_node(arg, context) for arg in node.args]
    return duck_type(*positional)
667 743
668 744
def _create_duck_for_heap_type(duck_type):
    """Create an imitation of an object of a given type (a duck).

    Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
    """
    duck = ImpersonatingDuck()
    try:
        # this only works for heap types, not builtins
        duck.__class__ = duck_type
    except TypeError:
        return NOT_EVALUATED
    return duck
682 758
683 759
# External types, given as (module name, *access path), for which subscript
# access is supported.
SUPPORTED_EXTERNAL_GETITEM = {
    ("numpy", "ndarray"),
    ("numpy", "void"),
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    ("pandas", "core", "indexing", "_iLocIndexer"),
    ("pandas", "core", "indexing", "_LocIndexer"),
}


# Types defined in this module or the standard library for which subscript
# access is supported.
BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
    # builtins
    dict,
    str,  # type: ignore[arg-type]
    bytes,  # type: ignore[arg-type]
    list,
    tuple,
    # collections
    collections.defaultdict,
    collections.deque,
    collections.OrderedDict,
    collections.ChainMap,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    # helpers defined in this module
    _DummyNamedTuple,
    _IdentitySubscript,
}
710 786
711 787
712 788 def _list_methods(cls, source=None):
713 789 """For use on immutable objects or with methods returning a copy"""
714 790 return [getattr(cls, k) for k in (source if source else dir(cls))]
715 791
716 792
# Names of methods which do not modify the object they are called on.
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# everything `set` shares with the immutable `frozenset`
set_non_mutating_methods = set(dir(set)).intersection(dir(frozenset))


# type of the object returned by `dict.keys()`
dict_keys: Type[collections.abc.KeysView] = type({}.keys())

NUMERICS = {int, float, complex}
725 801
# Callables which may be called: the types themselves and a selection of
# their methods (all methods for immutable types, non-mutating ones otherwise).
ALLOWED_CALLS = {
    # immutable builtins: every method
    bytes,
    *_list_methods(bytes),
    frozenset,
    *_list_methods(frozenset),
    str,
    *_list_methods(str),
    tuple,
    *_list_methods(tuple),
    range,
    # numeric types: every method
    *NUMERICS,
    *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
    # mutable builtins: selected methods only
    dict,
    *_list_methods(dict, dict_non_mutating_methods),
    dict_keys.isdisjoint,
    list,
    *_list_methods(list, list_non_mutating_methods),
    set,
    *_list_methods(set, set_non_mutating_methods),
    # collections: selected methods only
    collections.deque,
    *_list_methods(collections.deque, list_non_mutating_methods),
    collections.defaultdict,
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.OrderedDict,
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    collections.UserDict,
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    collections.UserList,
    *_list_methods(collections.UserList, list_non_mutating_methods),
    collections.UserString,
    *_list_methods(collections.UserString, dir(str)),
    collections.Counter,
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    collections.Counter.elements,
    collections.Counter.most_common,
}
762 838
# Types for which attribute access is supported: everything that supports
# subscript access plus the types listed below.
BUILTIN_GETATTR: Set[MayHaveGetattr] = {
    *BUILTIN_GETITEM,
    set,
    frozenset,
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    *NUMERICS,
    dict_keys,
    MethodDescriptorType,
    ModuleType,
}


# Operations are supported for the same types as attribute access (a copy).
BUILTIN_OPERATIONS = set(BUILTIN_GETATTR)
777 853
# Policy objects by evaluation mode name.
EVALUATION_POLICIES = {
    # builtins only; no namespace, item or attribute access, no calls
    "minimal": EvaluationPolicy(
        allow_locals_access=False,
        allow_globals_access=False,
        allow_builtins_access=True,
        allow_item_access=False,
        allow_attr_access=False,
        allow_all_operations=False,
        allow_any_calls=False,
        allowed_calls=set(),
    ),
    # namespaces are accessible; items, attributes, operations and calls
    # are restricted to the allow-lists
    "limited": SelectivePolicy(
        allow_locals_access=True,
        allow_globals_access=True,
        allow_builtins_access=True,
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_operations=BUILTIN_OPERATIONS,
        allowed_calls=ALLOWED_CALLS,
    ),
    # everything is allowed
    "unsafe": EvaluationPolicy(
        allow_locals_access=True,
        allow_globals_access=True,
        allow_builtins_access=True,
        allow_item_access=True,
        allow_attr_access=True,
        allow_all_operations=True,
        allow_any_calls=True,
    ),
}
814 890
815 891
# Names exported by `from ... import *`.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    "_unbind_method",
]
@@ -1,689 +1,785 b''
1 1 import sys
2 2 from contextlib import contextmanager
3 from typing import NamedTuple, Literal, NewType
3 from typing import (
4 Annotated,
5 AnyStr,
6 NamedTuple,
7 Literal,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 TypedDict,
14 )
4 15 from functools import partial
5 16 from IPython.core.guarded_eval import (
6 17 EvaluationContext,
7 18 GuardRejection,
8 19 guarded_eval,
9 20 _unbind_method,
10 21 )
11 22 from IPython.testing import decorators as dec
12 23 import pytest
13 24
14 25
15 26 if sys.version_info < (3, 11):
16 from typing_extensions import Self
27 from typing_extensions import Self, LiteralString
17 28 else:
18 from typing import Self
29 from typing import Self, LiteralString
19 30
20 31 if sys.version_info < (3, 12):
21 32 from typing_extensions import TypeAliasType
22 33 else:
23 34 from typing import TypeAliasType
24 35
25 36
def create_context(evaluation: str, **kwargs):
    """Build an ``EvaluationContext`` for the given evaluation mode.

    Keyword arguments become the context's locals; globals are left empty.
    """
    local_namespace = kwargs
    return EvaluationContext(
        locals=local_namespace,
        globals={},
        evaluation=evaluation,
    )
28 39
29 40
# One context factory per evaluation mode; each accepts the locals as kwargs.
forbidden, minimal, limited, unsafe, dangerous = (
    partial(create_context, mode)
    for mode in ("forbidden", "minimal", "limited", "unsafe", "dangerous")
)

# Factory groups used to parametrize tests over evaluation modes.
LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
MINIMAL_OR_HIGHER = [minimal] + LIMITED_OR_HIGHER
38 49
39 50
@contextmanager
def module_not_installed(module: str):
    """Temporarily remove ``module`` from ``sys.modules``.

    On exit the entry is put back exactly as it was on entry. If the module
    had no entry, any entry created inside the block is removed rather than
    replaced with ``None``: a ``None`` value in ``sys.modules`` makes every
    later ``import`` of that name fail, which would leak into other tests.
    """
    import sys

    # Sentinel distinguishes "no entry" from an entry whose value is None.
    absent = object()
    previous = sys.modules.pop(module, absent)
    try:
        yield
    finally:
        if previous is absent:
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = previous
53 64
54 65
def test_external_not_installed():
    """
    Because attribute check requires checking if object is not of allowed
    external type, this tests logic for absence of external module.
    """

    class Custom:
        # Defines a custom `__getattr__`, alongside a real instance attribute.
        def __getattr__(self, key):
            return key

        def __init__(self):
            self.test = 1

    with module_not_installed("pandas"):
        ctx = limited(x=Custom())
        with pytest.raises(GuardRejection):
            guarded_eval("x.test", ctx)
72 83
73 84
@dec.skip_without("pandas")
def test_external_changed_api(monkeypatch):
    """Check that the execution rejects if external API changed paths"""
    import pandas as pd

    # Build the object first; `pd.Series` is removed from the module below.
    data = pd.Series([1], index=["a"])

    with monkeypatch.context() as patcher:
        patcher.delattr(pd, "Series")
        ctx = limited(data=data)
        with pytest.raises(GuardRejection):
            guarded_eval("data.iloc[0]", ctx)
86 97
87 98
@dec.skip_without("pandas")
def test_pandas_series_iloc():
    """`.iloc[...]` on a pandas Series evaluates under the limited policy."""
    import pandas as pd

    ctx = limited(data=pd.Series([1], index=["a"]))
    result = guarded_eval("data.iloc[0]", ctx)
    assert result == 1
95 106
96 107
def test_rejects_custom_properties():
    """A user-defined `iloc` property is rejected under the limited policy."""

    class BadProperty:
        @property
        def iloc(self):
            return [None]

    ctx = limited(data=BadProperty())

    with pytest.raises(GuardRejection):
        guarded_eval("data.iloc[0]", ctx)
108 119
109 120
@dec.skip_without("pandas")
def test_accepts_non_overriden_properties():
    """A Series subclass that overrides nothing is still accepted."""
    import pandas as pd

    class GoodProperty(pd.Series):
        """Plain subclass without any overrides."""

    ctx = limited(data=GoodProperty([1], index=["a"]))
    result = guarded_eval("data.iloc[0]", ctx)

    assert result == 1
121 132
122 133
@dec.skip_without("pandas")
def test_pandas_series():
    """Item access on a Series works; a missing key raises `KeyError`."""
    import pandas as pd

    series = pd.Series([1], index=["a"])
    ctx = limited(data=series)
    assert guarded_eval('data["a"]', ctx) == 1
    with pytest.raises(KeyError):
        guarded_eval('data["c"]', ctx)
131 142
132 143
@dec.skip_without("pandas")
def test_pandas_bad_series():
    """Series subclasses overriding `__getitem__`/`__getattr__` are guarded."""
    import pandas as pd

    class BadItemSeries(pd.Series):
        def __getitem__(self, key):
            return "CUSTOM_ITEM"

    class BadAttrSeries(pd.Series):
        def __getattr__(self, key):
            return "CUSTOM_ATTR"

    item_override = BadItemSeries([1], index=["a"])
    ctx = limited(data=item_override)

    # Item access is rejected whether or not the key exists.
    for expression in ('data["a"]', 'data["c"]'):
        with pytest.raises(GuardRejection):
            guarded_eval(expression, ctx)

    # note: here result is a bit unexpected because
    # pandas `__getattr__` calls `__getitem__`;
    # FIXME - special case to handle it?
    assert guarded_eval("data.a", ctx) == "CUSTOM_ITEM"

    ctx = unsafe(data=item_override)
    assert guarded_eval('data["a"]', ctx) == "CUSTOM_ITEM"

    attr_override = BadAttrSeries([1], index=["a"])
    ctx = limited(data=attr_override)
    assert guarded_eval('data["a"]', ctx) == 1
    with pytest.raises(GuardRejection):
        guarded_eval("data.a", ctx)
166 177
167 178
@dec.skip_without("pandas")
def test_pandas_dataframe_loc():
    """`.loc[...]` on a DataFrame returns the same column as plain indexing."""
    import pandas as pd
    from pandas.testing import assert_series_equal

    frame = pd.DataFrame([{"a": 1}])
    ctx = limited(data=frame)
    evaluated = guarded_eval('data.loc[:, "a"]', ctx)
    assert_series_equal(evaluated, frame["a"])
176 187
177 188
def test_named_tuple():
    """Indexing a named tuple is allowed unless `__getitem__` is overridden."""

    class GoodNamedTuple(NamedTuple):
        a: str

    class BadNamedTuple(NamedTuple):
        a: str

        def __getitem__(self, key):
            return None

    good_ctx = limited(data=GoodNamedTuple(a="x"))
    assert guarded_eval("data[0]", good_ctx) == "x"

    bad_ctx = limited(data=BadNamedTuple(a="x"))
    with pytest.raises(GuardRejection):
        guarded_eval("data[0]", bad_ctx)
198 209
199 210
def test_dict():
    """Dict item access (plain, nested, tuple key) and `keys` attribute work."""
    payload = {"a": 1, "b": {"x": 2}, ("x", "y"): 3}
    ctx = limited(data=payload)

    expectations = (
        ('data["a"]', 1),
        ('data["b"]', {"x": 2}),
        ('data["b"]["x"]', 2),
        ('data["x", "y"]', 3),
    )
    for expression, wanted in expectations:
        assert guarded_eval(expression, ctx) == wanted

    assert guarded_eval("data.keys", ctx)
208 219
209 220
def test_set():
    """Attribute access on a set is allowed under the limited policy."""
    ctx = limited(data={"a", "b"})
    method = guarded_eval("data.difference", ctx)
    assert method
213 224
214 225
def test_list():
    """Item and attribute access on a list are allowed."""
    ctx = limited(data=[1, 2, 3])
    second = guarded_eval("data[1]", ctx)
    assert second == 2
    assert guarded_eval("data.copy", ctx)
219 230
220 231
def test_dict_literal():
    """Dict literals evaluate to the equivalent dict."""
    ctx = limited()
    for source, wanted in (("{}", {}), ('{"a": 1}', {"a": 1})):
        assert guarded_eval(source, ctx) == wanted
225 236
226 237
def test_list_literal():
    """List literals evaluate to the equivalent list."""
    ctx = limited()
    for source, wanted in (("[]", []), ('[1, "a"]', [1, "a"])):
        assert guarded_eval(source, ctx) == wanted
231 242
232 243
def test_set_literal():
    """`set()` and set literals evaluate to the equivalent set."""
    ctx = limited()
    for source, wanted in (("set()", set()), ('{"a"}', {"a"})):
        assert guarded_eval(source, ctx) == wanted
237 248
238 249
def test_evaluates_if_expression():
    """Conditional expressions pick the branch matching the condition."""
    ctx = limited()
    taken_when_true = guarded_eval("2 if True else 3", ctx)
    assert taken_when_true == 2
    taken_when_false = guarded_eval("4 if False else 5", ctx)
    assert taken_when_false == 5
243 254
244 255
def test_object():
    """Attribute access on a bare `object` instance is allowed."""
    instance = object()
    ctx = limited(obj=instance)
    assert guarded_eval("obj.__dir__", ctx) == instance.__dir__
249 260
250 261
@pytest.mark.parametrize(
    "code,expected",
    [
        ("int.numerator", int.numerator),
        ("float.is_integer", float.is_integer),
        ("complex.real", complex.real),
    ],
)
def test_number_attributes(code, expected):
    """Attributes of the builtin numeric types are accessible."""
    ctx = limited()
    assert guarded_eval(code, ctx) == expected
261 272
262 273
def test_method_descriptor():
    """Attributes of a method descriptor (`list.copy`) are accessible."""
    ctx = limited()
    name = guarded_eval("list.copy.__name__", ctx)
    assert name == "copy"
266 277
267 278
class HeapType:
    """Empty user-defined class used as a return type in the call tests."""
270 281
271 282
class CallCreatesHeapType:
    """Callable whose `__call__` is annotated to return `HeapType`."""

    def __call__(self) -> HeapType:
        created = HeapType()
        return created
275 286
276 287
class CallCreatesBuiltin:
    """Callable whose `__call__` is annotated to return a builtin type."""

    def __call__(self) -> frozenset:
        empty = frozenset()
        return empty
280 291
281 292
class HasStaticMethod:
    """Exposes a static method annotated to return `HeapType`."""

    @staticmethod
    def static_method() -> HeapType:
        created = HeapType()
        return created
286 297
287 298
class InitReturnsFrozenset:
    """Class whose `__new__` returns a `frozenset` instead of an instance."""

    def __new__(self) -> frozenset:  # type:ignore[misc]
        empty = frozenset()
        return empty
291 302
292 303
class StringAnnotation:
    """Methods whose return annotations are given as strings."""

    def copy(self) -> "StringAnnotation":
        duplicate = StringAnnotation()
        return duplicate

    def heap(self) -> "HeapType":
        created = HeapType()
        return created
299 310
300 311
# `NewType` wrappers around a builtin and around a user-defined class.
CustomIntType = NewType("CustomIntType", int)
CustomHeapType = NewType("CustomHeapType", HeapType)

# `TypeAliasType` aliases for the same two underlying types.
IntTypeAlias = TypeAliasType("IntTypeAlias", int)
HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
305 316
306 317
class TestProtocol(Protocol):
    """Protocol with a single method, used as a return annotation."""

    def test_method(self) -> bool:
        ...
321
322
class TestProtocolImplementer(TestProtocol):
    """Concrete implementation of `TestProtocol`."""

    def test_method(self) -> bool:
        outcome = True
        return outcome
326
327
class Movie(TypedDict):
    """Typed dictionary with a string `name` and an integer `year`."""

    name: str
    year: int
331
332
class SpecialTyping:
    """Fixture whose methods are annotated with special `typing` constructs.

    The parametrized tests evaluate calls to these methods and check the
    type, attributes or items of the result. Each method therefore returns
    a value that matches its own return annotation.
    """

    def custom_int_type(self) -> CustomIntType:
        return CustomIntType(1)

    def custom_heap_type(self) -> CustomHeapType:
        return CustomHeapType(HeapType())

    # TODO: remove type:ignore comment once mypy
    # supports explicit calls to `TypeAliasType`, see:
    # https://github.com/python/mypy/issues/16614
    def int_type_alias(self) -> IntTypeAlias:  # type:ignore[valid-type]
        return 1

    def heap_type_alias(self) -> HeapTypeAlias:  # type:ignore[valid-type]
        # The alias stands for `HeapType`, so return an instance of it
        # (this used to return the integer 1).
        return HeapType()

    def literal(self) -> Literal[False]:
        return False

    def literal_string(self) -> LiteralString:
        return "test"

    def self(self) -> Self:
        return self

    def any_str(self, x: AnyStr) -> AnyStr:
        return x

    def annotated(self) -> Annotated[float, "positive number"]:
        # Return an actual float so the value matches the annotated type.
        return 1.0

    def annotated_self(self) -> Annotated[Self, "self with metadata"]:
        self._metadata = "test"
        return self

    def int_type_guard(self, x) -> TypeGuard[int]:
        return isinstance(x, int)

    def optional_float(self) -> Optional[float]:
        return 1.0

    def union_str_and_int(self) -> Union[str, int]:
        return ""

    def protocol(self) -> TestProtocol:
        return TestProtocolImplementer()

    def typed_dict(self) -> Movie:
        return {"name": "The Matrix", "year": 1999}
382
329 383
@pytest.mark.parametrize(
    "data,code,expected,equality",
    [
        ([1, 2, 3], "data.index(2)", 1, True),
        ({"a": 1}, "data.keys().isdisjoint({})", True, True),
        (StringAnnotation(), "data.heap()", HeapType, False),
        (StringAnnotation(), "data.copy()", StringAnnotation, False),
        # test cases for `__call__`
        (CallCreatesHeapType(), "data()", HeapType, False),
        (CallCreatesBuiltin(), "data()", frozenset, False),
        # Test cases for `__init__`
        (HeapType, "data()", HeapType, False),
        (InitReturnsFrozenset, "data()", frozenset, False),
        (HeapType(), "data.__class__()", HeapType, False),
        # supported special cases for typing
        (SpecialTyping(), "data.custom_int_type()", int, False),
        (SpecialTyping(), "data.custom_heap_type()", HeapType, False),
        (SpecialTyping(), "data.int_type_alias()", int, False),
        (SpecialTyping(), "data.heap_type_alias()", HeapType, False),
        (SpecialTyping(), "data.self()", SpecialTyping, False),
        (SpecialTyping(), "data.literal()", False, True),
        (SpecialTyping(), "data.literal_string()", str, False),
        (SpecialTyping(), "data.any_str('a')", str, False),
        (SpecialTyping(), "data.any_str(b'a')", bytes, False),
        (SpecialTyping(), "data.annotated()", float, False),
        (SpecialTyping(), "data.annotated_self()", SpecialTyping, False),
        (SpecialTyping(), "data.int_type_guard()", int, False),
        # test cases for static methods
        (HasStaticMethod, "data.static_method()", HeapType, False),
    ],
)
def test_evaluates_calls(data, code, expected, equality):
    """Evaluate a call and check the result by equality or by instance type.

    When `equality` is true `expected` is the exact value; otherwise it is
    the type the result must be an instance of.
    """
    ctx = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
    outcome = guarded_eval(code, ctx)
    if not equality:
        assert isinstance(outcome, expected)
        return
    assert outcome == expected
362 422
363 423
@pytest.mark.parametrize(
    "data,code,expected_attributes",
    [
        (SpecialTyping(), "data.optional_float()", ["is_integer"]),
        (
            SpecialTyping(),
            "data.union_str_and_int()",
            ["capitalize", "as_integer_ratio"],
        ),
        (SpecialTyping(), "data.protocol()", ["test_method"]),
        (SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]),
    ],
)
def test_mocks_attributes_of_call_results(data, code, expected_attributes):
    """The call result exposes every expected attribute, also via `dir()`."""
    ctx = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
    outcome = guarded_eval(code, ctx)
    for name in expected_attributes:
        assert hasattr(outcome, name)
        assert name in dir(outcome)
443
444
@pytest.mark.parametrize(
    "data,code,expected_items",
    [
        (SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}),
    ],
)
def test_mocks_items_of_call_results(data, code, expected_items):
    """The call result yields typed items and advertises keys for completion."""
    ctx = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
    outcome = guarded_eval(code, ctx)
    completion_keys = outcome._ipython_key_completions_()
    for item_key, item_type in expected_items.items():
        assert isinstance(outcome[item_key], item_type)
        assert item_key in completion_keys
458
459
@pytest.mark.parametrize(
    "data,bad",
    [
        ([1, 2, 3], "data.append(4)"),
        ({"a": 1}, "data.update()"),
    ],
)
def test_rejects_calls_with_side_effects(data, bad):
    """Calls to mutating methods are rejected under the limited policy."""
    ctx = limited(data=data)

    with pytest.raises(GuardRejection):
        guarded_eval(bad, ctx)
376 472
377 473
@pytest.mark.parametrize(
    "code,expected",
    [
        ("(1\n+\n1)", 2),
        ("list(range(10))[-1:]", [9]),
        ("list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_complex_cases(code, expected, context):
    """Multi-line expressions and slices of call results evaluate correctly."""
    outcome = guarded_eval(code, context())
    assert outcome == expected
389 485
390 486
@pytest.mark.parametrize(
    "code,expected",
    [
        ("1", 1),
        ("1.0", 1.0),
        ("0xdeedbeef", 0xDEEDBEEF),
        ("True", True),
        ("None", None),
        ("{}", {}),
        ("[]", []),
    ],
)
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_evaluates_literals(code, expected, context):
    """Literals evaluate in every mode from minimal upwards."""
    outcome = guarded_eval(code, context())
    assert outcome == expected
406 502
407 503
@pytest.mark.parametrize(
    "code,expected",
    [
        ("-5", -5),
        ("+5", +5),
        ("~5", -6),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_unary_operations(code, expected, context):
    """Unary minus, plus and bitwise inversion evaluate on integers."""
    outcome = guarded_eval(code, context())
    assert outcome == expected
419 515
420 516
@pytest.mark.parametrize(
    "code,expected",
    [
        ("1 + 1", 2),
        ("3 - 1", 2),
        ("2 * 3", 6),
        ("5 // 2", 2),
        ("5 / 2", 2.5),
        ("5**2", 25),
        ("2 >> 1", 1),
        ("2 << 1", 4),
        ("1 | 2", 3),
        ("1 & 1", 1),
        ("1 & 2", 0),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_binary_operations(code, expected, context):
    """Arithmetic, shift and bitwise binary operators evaluate on numbers."""
    outcome = guarded_eval(code, context())
    assert outcome == expected
440 536
441 537
@pytest.mark.parametrize(
    "code,expected",
    [
        ("2 > 1", True),
        ("2 < 1", False),
        ("2 <= 1", False),
        ("2 <= 2", True),
        ("1 >= 2", False),
        ("2 >= 2", True),
        ("2 == 2", True),
        ("1 == 2", False),
        ("1 != 2", True),
        ("1 != 1", False),
        ("1 < 4 < 3", False),
        ("(1 < 4) < 3", True),
        ("4 > 3 > 2 > 1", True),
        ("4 > 3 > 2 > 9", False),
        ("1 < 2 < 3 < 4", True),
        ("9 < 2 < 3 < 4", False),
        ("1 < 2 > 1 > 0 > -1 < 1", True),
        ("1 in [1] in [[1]]", True),
        ("1 in [1] in [[2]]", False),
        ("1 in [1]", True),
        ("0 in [1]", False),
        ("1 not in [1]", False),
        ("0 not in [1]", True),
        ("True is True", True),
        ("False is False", True),
        ("True is False", False),
        ("True is not True", False),
        ("False is not True", True),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_comparisons(code, expected, context):
    """Simple and chained comparisons, membership and identity evaluate."""
    outcome = guarded_eval(code, context())
    assert outcome == expected
478 574
479 575
def test_guards_comparisons():
    """Comparisons involving an overridden `__eq__` are rejected."""

    class GoodEq(int):
        """`int` subclass without overrides."""

    class BadEq(int):
        def __eq__(self, other):
            assert False

    ctx = limited(bad=BadEq(1), good=GoodEq(1))

    # The override is rejected on either side of `==` and `!=`.
    for expression in ("bad == 1", "bad != 1", "1 == bad", "1 != bad"):
        with pytest.raises(GuardRejection):
            guarded_eval(expression, ctx)

    allowed = (
        ("good == 1", True),
        ("good != 1", False),
        ("1 == good", True),
        ("1 != good", False),
    )
    for expression, wanted in allowed:
        assert guarded_eval(expression, ctx) is wanted
506 602
507 603
def test_guards_unary_operations():
    """Bitwise inversion is rejected when the operand overrides its dunder.

    Two spellings are covered: the legacy `__inv__` name and the
    `__invert__` method that Python actually calls for `~x`. Previously
    both classes defined `__inv__`, so `__invert__` was never exercised.
    """

    class GoodOp(int):
        pass

    class BadOpInv(int):
        def __inv__(self):
            assert False

    class BadOpInverse(int):
        def __invert__(self):
            assert False

    context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))

    with pytest.raises(GuardRejection):
        guarded_eval("~bad1", context)

    with pytest.raises(GuardRejection):
        guarded_eval("~bad2", context)
527 623
528 624
def test_guards_binary_operations():
    """Addition is rejected when an operand overrides `__add__`."""

    class GoodOp(int):
        """`int` subclass without overrides."""

    class BadOp(int):
        def __add__(self, other):
            assert False

    ctx = limited(good=GoodOp(1), bad=BadOp(1))

    for expression in ("1 + bad", "bad + 1"):
        with pytest.raises(GuardRejection):
            guarded_eval(expression, ctx)

    for expression in ("good + 1", "1 + good"):
        assert guarded_eval(expression, ctx) == 2
547 643
548 644
def test_guards_attributes():
    """Attribute access is rejected on custom `__getattr__`/`__getattribute__`."""

    class GoodAttr(float):
        """`float` subclass without overrides."""

    class BadAttr1(float):
        def __getattr__(self, key):
            assert False

    class BadAttr2(float):
        def __getattribute__(self, key):
            assert False

    ctx = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))

    for expression in ("bad1.as_integer_ratio", "bad2.as_integer_ratio"):
        with pytest.raises(GuardRejection):
            guarded_eval(expression, ctx)

    ratio = guarded_eval("good.as_integer_ratio()", ctx)
    assert ratio == (1, 2)
570 666
571 667
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_access_builtins(context):
    """Builtin names resolve in every mode from minimal upwards."""
    resolved = guarded_eval("round", context())
    assert resolved == round
575 671
576 672
def test_access_builtins_fails():
    """An unknown name raises `NameError` under the limited policy."""
    ctx = limited()
    with pytest.raises(NameError):
        guarded_eval("this_is_not_builtin", ctx)
581 677
582 678
def test_rejects_forbidden():
    """Even a bare literal is rejected in the "forbidden" mode."""
    ctx = forbidden()
    with pytest.raises(GuardRejection):
        guarded_eval("1", ctx)
587 683
588 684
def test_guards_locals_and_globals():
    """Names from locals and globals are rejected in the "minimal" mode."""
    ctx = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="minimal",
    )

    for name in ("local_a", "global_b"):
        with pytest.raises(GuardRejection):
            guarded_eval(name, ctx)
599 695
600 696
def test_access_locals_and_globals():
    """Names from locals and globals resolve in the "limited" mode."""
    ctx = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="limited",
    )
    for name, wanted in (("local_a", "a"), ("global_b", "b")):
        assert guarded_eval(name, ctx) == wanted
607 703
608 704
@pytest.mark.parametrize(
    "code",
    [
        "def func(): pass",
        "class C: pass",
        "x = 1",
        "x += 1",
        "del x",
        "import ast",
    ],
)
@pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
def test_rejects_side_effect_syntax(code, context):
    """Statements raise `SyntaxError` in the minimal, limited and unsafe modes."""
    with pytest.raises(SyntaxError):
        guarded_eval(code, context)
617 713
618 714
def test_subscript():
    """With `in_subscript=True` the code is parsed as subscript contents."""
    ctx = EvaluationContext(
        locals={},
        globals={},
        evaluation="limited",
        in_subscript=True,
    )
    whole = slice(None, None, None)
    cases = (
        ("", tuple()),
        (":", whole),
        ("1:2:3", slice(1, 2, 3)),
        (':, "a"', (whole, "a")),
    )
    for source, wanted in cases:
        assert guarded_eval(source, ctx) == wanted
628 724
629 725
def test_unbind_method():
    """`_unbind_method` maps a bound method to the function on its class."""

    class X(list):
        def index(self, k):
            return "CUSTOM"

    instance = X()
    # Bound method of a subclass override resolves to the override.
    assert _unbind_method(instance.index) is X.index
    # Bound builtin method resolves to the builtin's descriptor.
    assert _unbind_method([].index) is list.index
    # An already unbound method yields None.
    assert _unbind_method(list.index) is None
639 735
640 736
def test_assumption_instance_attr_do_not_matter():
    """This is semi-specified in Python documentation.

    However, since the specification says 'not guaranteed
    to work' rather than 'is forbidden to work', future
    versions could invalidate this assumption. This test
    is meant to catch such a change if it ever comes true.

    Both item access and attribute access are checked: dunder methods
    assigned on the instance must not replace the ones defined on the class.
    """

    class T:
        def __getitem__(self, k):
            return "a"

        def __getattr__(self, k):
            return "a"

    def f(self):
        return "b"

    t = T()
    t.__getitem__ = f
    t.__getattr__ = f
    # Item access still goes through the class-level `__getitem__`.
    assert t[1] == "a"
    # A missing attribute still goes through the class-level `__getattr__`
    # (this line used to repeat the item-access assertion).
    assert t.a == "a"
666 762
667 763 def test_assumption_named_tuples_share_getitem():
668 764 """Check assumption on named tuples sharing __getitem__"""
669 765 from typing import NamedTuple
670 766
671 767 class A(NamedTuple):
672 768 pass
673 769
674 770 class B(NamedTuple):
675 771 pass
676 772
677 773 assert A.__getitem__ == B.__getitem__
678 774
679 775
@dec.skip_without("numpy")
def test_module_access():
    """Module attributes resolve under "limited" and are rejected under "minimal"."""
    import numpy

    context = limited(numpy=numpy)
    assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm

    context = minimal(numpy=numpy)
    with pytest.raises(GuardRejection):
        # Use the name actually bound in the context (`numpy`, not `np`), so
        # the rejection is tested for an existing local rather than for an
        # unknown name.
        guarded_eval("numpy.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now