##// END OF EJS Templates
Add version guards
krassowski -
Show More
@@ -1,816 +1,822 b''
1 1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 3 Callable,
4 4 Dict,
5 5 Literal,
6 6 NamedTuple,
7 7 NewType,
8 8 Set,
9 9 Sequence,
10 10 Tuple,
11 11 Type,
12 12 Protocol,
13 13 Union,
14 14 get_args,
15 15 get_origin,
16 16 )
17 from typing_extensions import (
18 Self, # Python >=3.10
19 TypeAliasType, # Python >=3.12
20 )
21 17 import ast
22 18 import builtins
23 19 import collections
24 20 import operator
25 21 import sys
26 22 from functools import cached_property
27 23 from dataclasses import dataclass, field
28 24 from types import MethodDescriptorType, ModuleType
29 25
30
31 26 from IPython.utils.decorators import undoc
32 27
33 28
29 if sys.version_info < (3, 11):
30 from typing_extensions import Self
31 else:
32 from typing import Self
33
34 if sys.version_info < (3, 12):
35 from typing_extensions import TypeAliasType
36 else:
37 from typing import TypeAliasType
38
39
34 40 @undoc
35 41 class HasGetItem(Protocol):
36 42 def __getitem__(self, key) -> None:
37 43 ...
38 44
39 45
40 46 @undoc
41 47 class InstancesHaveGetItem(Protocol):
42 48 def __call__(self, *args, **kwargs) -> HasGetItem:
43 49 ...
44 50
45 51
46 52 @undoc
47 53 class HasGetAttr(Protocol):
48 54 def __getattr__(self, key) -> None:
49 55 ...
50 56
51 57
52 58 @undoc
53 59 class DoesNotHaveGetAttr(Protocol):
54 60 pass
55 61
56 62
57 63 # By default `__getattr__` is not explicitly implemented on most objects
58 64 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
59 65
60 66
61 67 def _unbind_method(func: Callable) -> Union[Callable, None]:
62 68 """Get unbound method for given bound method.
63 69
64 70 Returns None if cannot get unbound method, or method is already unbound.
65 71 """
66 72 owner = getattr(func, "__self__", None)
67 73 owner_class = type(owner)
68 74 name = getattr(func, "__name__", None)
69 75 instance_dict_overrides = getattr(owner, "__dict__", None)
70 76 if (
71 77 owner is not None
72 78 and name
73 79 and (
74 80 not instance_dict_overrides
75 81 or (instance_dict_overrides and name not in instance_dict_overrides)
76 82 )
77 83 ):
78 84 return getattr(owner_class, name)
79 85 return None
80 86
81 87
82 88 @undoc
83 89 @dataclass
84 90 class EvaluationPolicy:
85 91 """Definition of evaluation policy."""
86 92
87 93 allow_locals_access: bool = False
88 94 allow_globals_access: bool = False
89 95 allow_item_access: bool = False
90 96 allow_attr_access: bool = False
91 97 allow_builtins_access: bool = False
92 98 allow_all_operations: bool = False
93 99 allow_any_calls: bool = False
94 100 allowed_calls: Set[Callable] = field(default_factory=set)
95 101
96 102 def can_get_item(self, value, item):
97 103 return self.allow_item_access
98 104
99 105 def can_get_attr(self, value, attr):
100 106 return self.allow_attr_access
101 107
102 108 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
103 109 if self.allow_all_operations:
104 110 return True
105 111
106 112 def can_call(self, func):
107 113 if self.allow_any_calls:
108 114 return True
109 115
110 116 if func in self.allowed_calls:
111 117 return True
112 118
113 119 owner_method = _unbind_method(func)
114 120
115 121 if owner_method and owner_method in self.allowed_calls:
116 122 return True
117 123
118 124
119 125 def _get_external(module_name: str, access_path: Sequence[str]):
120 126 """Get value from external module given a dotted access path.
121 127
122 128 Raises:
123 129 * `KeyError` if module is removed not found, and
124 130 * `AttributeError` if acess path does not match an exported object
125 131 """
126 132 member_type = sys.modules[module_name]
127 133 for attr in access_path:
128 134 member_type = getattr(member_type, attr)
129 135 return member_type
130 136
131 137
132 138 def _has_original_dunder_external(
133 139 value,
134 140 module_name: str,
135 141 access_path: Sequence[str],
136 142 method_name: str,
137 143 ):
138 144 if module_name not in sys.modules:
139 145 # LBYLB as it is faster
140 146 return False
141 147 try:
142 148 member_type = _get_external(module_name, access_path)
143 149 value_type = type(value)
144 150 if type(value) == member_type:
145 151 return True
146 152 if method_name == "__getattribute__":
147 153 # we have to short-circuit here due to an unresolved issue in
148 154 # `isinstance` implementation: https://bugs.python.org/issue32683
149 155 return False
150 156 if isinstance(value, member_type):
151 157 method = getattr(value_type, method_name, None)
152 158 member_method = getattr(member_type, method_name, None)
153 159 if member_method == method:
154 160 return True
155 161 except (AttributeError, KeyError):
156 162 return False
157 163
158 164
159 165 def _has_original_dunder(
160 166 value, allowed_types, allowed_methods, allowed_external, method_name
161 167 ):
162 168 # note: Python ignores `__getattr__`/`__getitem__` on instances,
163 169 # we only need to check at class level
164 170 value_type = type(value)
165 171
166 172 # strict type check passes β†’ no need to check method
167 173 if value_type in allowed_types:
168 174 return True
169 175
170 176 method = getattr(value_type, method_name, None)
171 177
172 178 if method is None:
173 179 return None
174 180
175 181 if method in allowed_methods:
176 182 return True
177 183
178 184 for module_name, *access_path in allowed_external:
179 185 if _has_original_dunder_external(value, module_name, access_path, method_name):
180 186 return True
181 187
182 188 return False
183 189
184 190
185 191 @undoc
186 192 @dataclass
187 193 class SelectivePolicy(EvaluationPolicy):
188 194 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
189 195 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
190 196
191 197 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
192 198 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
193 199
194 200 allowed_operations: Set = field(default_factory=set)
195 201 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
196 202
197 203 _operation_methods_cache: Dict[str, Set[Callable]] = field(
198 204 default_factory=dict, init=False
199 205 )
200 206
201 207 def can_get_attr(self, value, attr):
202 208 has_original_attribute = _has_original_dunder(
203 209 value,
204 210 allowed_types=self.allowed_getattr,
205 211 allowed_methods=self._getattribute_methods,
206 212 allowed_external=self.allowed_getattr_external,
207 213 method_name="__getattribute__",
208 214 )
209 215 has_original_attr = _has_original_dunder(
210 216 value,
211 217 allowed_types=self.allowed_getattr,
212 218 allowed_methods=self._getattr_methods,
213 219 allowed_external=self.allowed_getattr_external,
214 220 method_name="__getattr__",
215 221 )
216 222
217 223 accept = False
218 224
219 225 # Many objects do not have `__getattr__`, this is fine.
220 226 if has_original_attr is None and has_original_attribute:
221 227 accept = True
222 228 else:
223 229 # Accept objects without modifications to `__getattr__` and `__getattribute__`
224 230 accept = has_original_attr and has_original_attribute
225 231
226 232 if accept:
227 233 # We still need to check for overriden properties.
228 234
229 235 value_class = type(value)
230 236 if not hasattr(value_class, attr):
231 237 return True
232 238
233 239 class_attr_val = getattr(value_class, attr)
234 240 is_property = isinstance(class_attr_val, property)
235 241
236 242 if not is_property:
237 243 return True
238 244
239 245 # Properties in allowed types are ok (although we do not include any
240 246 # properties in our default allow list currently).
241 247 if type(value) in self.allowed_getattr:
242 248 return True # pragma: no cover
243 249
244 250 # Properties in subclasses of allowed types may be ok if not changed
245 251 for module_name, *access_path in self.allowed_getattr_external:
246 252 try:
247 253 external_class = _get_external(module_name, access_path)
248 254 external_class_attr_val = getattr(external_class, attr)
249 255 except (KeyError, AttributeError):
250 256 return False # pragma: no cover
251 257 return class_attr_val == external_class_attr_val
252 258
253 259 return False
254 260
255 261 def can_get_item(self, value, item):
256 262 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
257 263 return _has_original_dunder(
258 264 value,
259 265 allowed_types=self.allowed_getitem,
260 266 allowed_methods=self._getitem_methods,
261 267 allowed_external=self.allowed_getitem_external,
262 268 method_name="__getitem__",
263 269 )
264 270
265 271 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
266 272 objects = [a]
267 273 if b is not None:
268 274 objects.append(b)
269 275 return all(
270 276 [
271 277 _has_original_dunder(
272 278 obj,
273 279 allowed_types=self.allowed_operations,
274 280 allowed_methods=self._operator_dunder_methods(dunder),
275 281 allowed_external=self.allowed_operations_external,
276 282 method_name=dunder,
277 283 )
278 284 for dunder in dunders
279 285 for obj in objects
280 286 ]
281 287 )
282 288
283 289 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
284 290 if dunder not in self._operation_methods_cache:
285 291 self._operation_methods_cache[dunder] = self._safe_get_methods(
286 292 self.allowed_operations, dunder
287 293 )
288 294 return self._operation_methods_cache[dunder]
289 295
290 296 @cached_property
291 297 def _getitem_methods(self) -> Set[Callable]:
292 298 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
293 299
294 300 @cached_property
295 301 def _getattr_methods(self) -> Set[Callable]:
296 302 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
297 303
298 304 @cached_property
299 305 def _getattribute_methods(self) -> Set[Callable]:
300 306 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
301 307
302 308 def _safe_get_methods(self, classes, name) -> Set[Callable]:
303 309 return {
304 310 method
305 311 for class_ in classes
306 312 for method in [getattr(class_, name, None)]
307 313 if method
308 314 }
309 315
310 316
311 317 class _DummyNamedTuple(NamedTuple):
312 318 """Used internally to retrieve methods of named tuple instance."""
313 319
314 320
315 321 class EvaluationContext(NamedTuple):
316 322 #: Local namespace
317 323 locals: dict
318 324 #: Global namespace
319 325 globals: dict
320 326 #: Evaluation policy identifier
321 327 evaluation: Literal[
322 328 "forbidden", "minimal", "limited", "unsafe", "dangerous"
323 329 ] = "forbidden"
324 330 #: Whether the evalution of code takes place inside of a subscript.
325 331 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
326 332 in_subscript: bool = False
327 333
328 334
329 335 class _IdentitySubscript:
330 336 """Returns the key itself when item is requested via subscript."""
331 337
332 338 def __getitem__(self, key):
333 339 return key
334 340
335 341
336 342 IDENTITY_SUBSCRIPT = _IdentitySubscript()
337 343 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
338 344 UNKNOWN_SIGNATURE = Signature()
339 345 NOT_EVALUATED = object()
340 346
341 347
342 348 class GuardRejection(Exception):
343 349 """Exception raised when guard rejects evaluation attempt."""
344 350
345 351 pass
346 352
347 353
348 354 def guarded_eval(code: str, context: EvaluationContext):
349 355 """Evaluate provided code in the evaluation context.
350 356
351 357 If evaluation policy given by context is set to ``forbidden``
352 358 no evaluation will be performed; if it is set to ``dangerous``
353 359 standard :func:`eval` will be used; finally, for any other,
354 360 policy :func:`eval_node` will be called on parsed AST.
355 361 """
356 362 locals_ = context.locals
357 363
358 364 if context.evaluation == "forbidden":
359 365 raise GuardRejection("Forbidden mode")
360 366
361 367 # note: not using `ast.literal_eval` as it does not implement
362 368 # getitem at all, for example it fails on simple `[0][1]`
363 369
364 370 if context.in_subscript:
365 371 # syntatic sugar for ellipsis (:) is only available in susbcripts
366 372 # so we need to trick the ast parser into thinking that we have
367 373 # a subscript, but we need to be able to later recognise that we did
368 374 # it so we can ignore the actual __getitem__ operation
369 375 if not code:
370 376 return tuple()
371 377 locals_ = locals_.copy()
372 378 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
373 379 code = SUBSCRIPT_MARKER + "[" + code + "]"
374 380 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
375 381
376 382 if context.evaluation == "dangerous":
377 383 return eval(code, context.globals, context.locals)
378 384
379 385 expression = ast.parse(code, mode="eval")
380 386
381 387 return eval_node(expression, context)
382 388
383 389
384 390 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
385 391 ast.Add: ("__add__",),
386 392 ast.Sub: ("__sub__",),
387 393 ast.Mult: ("__mul__",),
388 394 ast.Div: ("__truediv__",),
389 395 ast.FloorDiv: ("__floordiv__",),
390 396 ast.Mod: ("__mod__",),
391 397 ast.Pow: ("__pow__",),
392 398 ast.LShift: ("__lshift__",),
393 399 ast.RShift: ("__rshift__",),
394 400 ast.BitOr: ("__or__",),
395 401 ast.BitXor: ("__xor__",),
396 402 ast.BitAnd: ("__and__",),
397 403 ast.MatMult: ("__matmul__",),
398 404 }
399 405
400 406 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
401 407 ast.Eq: ("__eq__",),
402 408 ast.NotEq: ("__ne__", "__eq__"),
403 409 ast.Lt: ("__lt__", "__gt__"),
404 410 ast.LtE: ("__le__", "__ge__"),
405 411 ast.Gt: ("__gt__", "__lt__"),
406 412 ast.GtE: ("__ge__", "__le__"),
407 413 ast.In: ("__contains__",),
408 414 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
409 415 }
410 416
411 417 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
412 418 ast.USub: ("__neg__",),
413 419 ast.UAdd: ("__pos__",),
414 420 # we have to check both __inv__ and __invert__!
415 421 ast.Invert: ("__invert__", "__inv__"),
416 422 ast.Not: ("__not__",),
417 423 }
418 424
419 425
420 426 class Duck:
421 427 """A dummy class used to create objects of other classes without calling their ``__init__``"""
422 428
423 429
424 430 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 431 dunder = None
426 432 for op, candidate_dunder in dunders.items():
427 433 if isinstance(node_op, op):
428 434 dunder = candidate_dunder
429 435 return dunder
430 436
431 437
432 438 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 439 """Evaluate AST node in provided context.
434 440
435 441 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436 442
437 443 Does not evaluate actions that always have side effects:
438 444
439 445 - class definitions (``class sth: ...``)
440 446 - function definitions (``def sth: ...``)
441 447 - variable assignments (``x = 1``)
442 448 - augmented assignments (``x += 1``)
443 449 - deletions (``del x``)
444 450
445 451 Does not evaluate operations which do not return values:
446 452
447 453 - assertions (``assert x``)
448 454 - pass (``pass``)
449 455 - imports (``import x``)
450 456 - control flow:
451 457
452 458 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 459 - loops (``for`` and ``while``)
454 460 - exception handling
455 461
456 462 The purpose of this function is to guard against unwanted side-effects;
457 463 it does not give guarantees on protection from malicious code execution.
458 464 """
459 465 policy = EVALUATION_POLICIES[context.evaluation]
460 466 if node is None:
461 467 return None
462 468 if isinstance(node, ast.Expression):
463 469 return eval_node(node.body, context)
464 470 if isinstance(node, ast.BinOp):
465 471 left = eval_node(node.left, context)
466 472 right = eval_node(node.right, context)
467 473 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 474 if dunders:
469 475 if policy.can_operate(dunders, left, right):
470 476 return getattr(left, dunders[0])(right)
471 477 else:
472 478 raise GuardRejection(
473 479 f"Operation (`{dunders}`) for",
474 480 type(left),
475 481 f"not allowed in {context.evaluation} mode",
476 482 )
477 483 if isinstance(node, ast.Compare):
478 484 left = eval_node(node.left, context)
479 485 all_true = True
480 486 negate = False
481 487 for op, right in zip(node.ops, node.comparators):
482 488 right = eval_node(right, context)
483 489 dunder = None
484 490 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 491 if not dunders:
486 492 if isinstance(op, ast.NotIn):
487 493 dunders = COMP_OP_DUNDERS[ast.In]
488 494 negate = True
489 495 if isinstance(op, ast.Is):
490 496 dunder = "is_"
491 497 if isinstance(op, ast.IsNot):
492 498 dunder = "is_"
493 499 negate = True
494 500 if not dunder and dunders:
495 501 dunder = dunders[0]
496 502 if dunder:
497 503 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 504 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 505 result = getattr(operator, dunder)(a, b)
500 506 if negate:
501 507 result = not result
502 508 if not result:
503 509 all_true = False
504 510 left = right
505 511 else:
506 512 raise GuardRejection(
507 513 f"Comparison (`{dunder}`) for",
508 514 type(left),
509 515 f"not allowed in {context.evaluation} mode",
510 516 )
511 517 else:
512 518 raise ValueError(
513 519 f"Comparison `{dunder}` not supported"
514 520 ) # pragma: no cover
515 521 return all_true
516 522 if isinstance(node, ast.Constant):
517 523 return node.value
518 524 if isinstance(node, ast.Tuple):
519 525 return tuple(eval_node(e, context) for e in node.elts)
520 526 if isinstance(node, ast.List):
521 527 return [eval_node(e, context) for e in node.elts]
522 528 if isinstance(node, ast.Set):
523 529 return {eval_node(e, context) for e in node.elts}
524 530 if isinstance(node, ast.Dict):
525 531 return dict(
526 532 zip(
527 533 [eval_node(k, context) for k in node.keys],
528 534 [eval_node(v, context) for v in node.values],
529 535 )
530 536 )
531 537 if isinstance(node, ast.Slice):
532 538 return slice(
533 539 eval_node(node.lower, context),
534 540 eval_node(node.upper, context),
535 541 eval_node(node.step, context),
536 542 )
537 543 if isinstance(node, ast.UnaryOp):
538 544 value = eval_node(node.operand, context)
539 545 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 546 if dunders:
541 547 if policy.can_operate(dunders, value):
542 548 return getattr(value, dunders[0])()
543 549 else:
544 550 raise GuardRejection(
545 551 f"Operation (`{dunders}`) for",
546 552 type(value),
547 553 f"not allowed in {context.evaluation} mode",
548 554 )
549 555 if isinstance(node, ast.Subscript):
550 556 value = eval_node(node.value, context)
551 557 slice_ = eval_node(node.slice, context)
552 558 if policy.can_get_item(value, slice_):
553 559 return value[slice_]
554 560 raise GuardRejection(
555 561 "Subscript access (`__getitem__`) for",
556 562 type(value), # not joined to avoid calling `repr`
557 563 f" not allowed in {context.evaluation} mode",
558 564 )
559 565 if isinstance(node, ast.Name):
560 566 return _eval_node_name(node.id, context)
561 567 if isinstance(node, ast.Attribute):
562 568 value = eval_node(node.value, context)
563 569 if policy.can_get_attr(value, node.attr):
564 570 return getattr(value, node.attr)
565 571 raise GuardRejection(
566 572 "Attribute access (`__getattr__`) for",
567 573 type(value), # not joined to avoid calling `repr`
568 574 f"not allowed in {context.evaluation} mode",
569 575 )
570 576 if isinstance(node, ast.IfExp):
571 577 test = eval_node(node.test, context)
572 578 if test:
573 579 return eval_node(node.body, context)
574 580 else:
575 581 return eval_node(node.orelse, context)
576 582 if isinstance(node, ast.Call):
577 583 func = eval_node(node.func, context)
578 584 if policy.can_call(func) and not node.keywords:
579 585 args = [eval_node(arg, context) for arg in node.args]
580 586 return func(*args)
581 587 if isclass(func):
582 588 # this code path gets entered when calling class e.g. `MyClass()`
583 589 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
584 590 # Should return `MyClass` if `__new__` is not overridden,
585 591 # otherwise whatever `__new__` return type is.
586 592 overridden_return_type = _eval_return_type(func.__new__, node, context)
587 593 if overridden_return_type is not NOT_EVALUATED:
588 594 return overridden_return_type
589 595 return _create_duck_for_heap_type(func)
590 596 else:
591 597 return_type = _eval_return_type(func, node, context)
592 598 if return_type is not NOT_EVALUATED:
593 599 return return_type
594 600 raise GuardRejection(
595 601 "Call for",
596 602 func, # not joined to avoid calling `repr`
597 603 f"not allowed in {context.evaluation} mode",
598 604 )
599 605 raise ValueError("Unhandled node", ast.dump(node))
600 606
601 607
602 608 def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
603 609 """Evaluate return type of a given callable function.
604 610
605 611 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
606 612 """
607 613 try:
608 614 sig = signature(func)
609 615 except ValueError:
610 616 sig = UNKNOWN_SIGNATURE
611 617 # if annotation was not stringized, or it was stringized
612 618 # but resolved by signature call we know the return type
613 619 not_empty = sig.return_annotation is not Signature.empty
614 620 stringized = isinstance(sig.return_annotation, str)
615 621 if not_empty:
616 622 return_type = (
617 623 _eval_node_name(sig.return_annotation, context)
618 624 if stringized
619 625 else sig.return_annotation
620 626 )
621 627 if return_type is Self and hasattr(func, "__self__"):
622 628 return func.__self__
623 629 elif get_origin(return_type) is Literal:
624 630 type_args = get_args(return_type)
625 631 if len(type_args) == 1:
626 632 return type_args[0]
627 633 elif isinstance(return_type, NewType):
628 634 return _eval_or_create_duck(return_type.__supertype__, node, context)
629 635 elif isinstance(return_type, TypeAliasType):
630 636 return _eval_or_create_duck(return_type.__value__, node, context)
631 637 else:
632 638 return _eval_or_create_duck(return_type, node, context)
633 639 return NOT_EVALUATED
634 640
635 641
636 642 def _eval_node_name(node_id: str, context: EvaluationContext):
637 643 policy = EVALUATION_POLICIES[context.evaluation]
638 644 if policy.allow_locals_access and node_id in context.locals:
639 645 return context.locals[node_id]
640 646 if policy.allow_globals_access and node_id in context.globals:
641 647 return context.globals[node_id]
642 648 if policy.allow_builtins_access and hasattr(builtins, node_id):
643 649 # note: do not use __builtins__, it is implementation detail of cPython
644 650 return getattr(builtins, node_id)
645 651 if not policy.allow_globals_access and not policy.allow_locals_access:
646 652 raise GuardRejection(
647 653 f"Namespace access not allowed in {context.evaluation} mode"
648 654 )
649 655 else:
650 656 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
651 657
652 658
653 659 def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
654 660 policy = EVALUATION_POLICIES[context.evaluation]
655 661 # if allow-listed builtin is on type annotation, instantiate it
656 662 if policy.can_call(duck_type) and not node.keywords:
657 663 args = [eval_node(arg, context) for arg in node.args]
658 664 return duck_type(*args)
659 665 # if custom class is in type annotation, mock it
660 666 return _create_duck_for_heap_type(duck_type)
661 667
662 668
663 669 def _create_duck_for_heap_type(duck_type):
664 670 """Create an imitation of an object of a given type (a duck).
665 671
666 672 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
667 673 """
668 674 duck = Duck()
669 675 try:
670 676 # this only works for heap types, not builtins
671 677 duck.__class__ = duck_type
672 678 return duck
673 679 except TypeError:
674 680 pass
675 681 return NOT_EVALUATED
676 682
677 683
678 684 SUPPORTED_EXTERNAL_GETITEM = {
679 685 ("pandas", "core", "indexing", "_iLocIndexer"),
680 686 ("pandas", "core", "indexing", "_LocIndexer"),
681 687 ("pandas", "DataFrame"),
682 688 ("pandas", "Series"),
683 689 ("numpy", "ndarray"),
684 690 ("numpy", "void"),
685 691 }
686 692
687 693
688 694 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
689 695 dict,
690 696 str, # type: ignore[arg-type]
691 697 bytes, # type: ignore[arg-type]
692 698 list,
693 699 tuple,
694 700 collections.defaultdict,
695 701 collections.deque,
696 702 collections.OrderedDict,
697 703 collections.ChainMap,
698 704 collections.UserDict,
699 705 collections.UserList,
700 706 collections.UserString, # type: ignore[arg-type]
701 707 _DummyNamedTuple,
702 708 _IdentitySubscript,
703 709 }
704 710
705 711
706 712 def _list_methods(cls, source=None):
707 713 """For use on immutable objects or with methods returning a copy"""
708 714 return [getattr(cls, k) for k in (source if source else dir(cls))]
709 715
710 716
711 717 dict_non_mutating_methods = ("copy", "keys", "values", "items")
712 718 list_non_mutating_methods = ("copy", "index", "count")
713 719 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
714 720
715 721
716 722 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
717 723
718 724 NUMERICS = {int, float, complex}
719 725
720 726 ALLOWED_CALLS = {
721 727 bytes,
722 728 *_list_methods(bytes),
723 729 dict,
724 730 *_list_methods(dict, dict_non_mutating_methods),
725 731 dict_keys.isdisjoint,
726 732 list,
727 733 *_list_methods(list, list_non_mutating_methods),
728 734 set,
729 735 *_list_methods(set, set_non_mutating_methods),
730 736 frozenset,
731 737 *_list_methods(frozenset),
732 738 range,
733 739 str,
734 740 *_list_methods(str),
735 741 tuple,
736 742 *_list_methods(tuple),
737 743 *NUMERICS,
738 744 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
739 745 collections.deque,
740 746 *_list_methods(collections.deque, list_non_mutating_methods),
741 747 collections.defaultdict,
742 748 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
743 749 collections.OrderedDict,
744 750 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
745 751 collections.UserDict,
746 752 *_list_methods(collections.UserDict, dict_non_mutating_methods),
747 753 collections.UserList,
748 754 *_list_methods(collections.UserList, list_non_mutating_methods),
749 755 collections.UserString,
750 756 *_list_methods(collections.UserString, dir(str)),
751 757 collections.Counter,
752 758 *_list_methods(collections.Counter, dict_non_mutating_methods),
753 759 collections.Counter.elements,
754 760 collections.Counter.most_common,
755 761 }
756 762
757 763 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
758 764 *BUILTIN_GETITEM,
759 765 set,
760 766 frozenset,
761 767 object,
762 768 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
763 769 *NUMERICS,
764 770 dict_keys,
765 771 MethodDescriptorType,
766 772 ModuleType,
767 773 }
768 774
769 775
770 776 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
771 777
772 778 EVALUATION_POLICIES = {
773 779 "minimal": EvaluationPolicy(
774 780 allow_builtins_access=True,
775 781 allow_locals_access=False,
776 782 allow_globals_access=False,
777 783 allow_item_access=False,
778 784 allow_attr_access=False,
779 785 allowed_calls=set(),
780 786 allow_any_calls=False,
781 787 allow_all_operations=False,
782 788 ),
783 789 "limited": SelectivePolicy(
784 790 allowed_getitem=BUILTIN_GETITEM,
785 791 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
786 792 allowed_getattr=BUILTIN_GETATTR,
787 793 allowed_getattr_external={
788 794 # pandas Series/Frame implements custom `__getattr__`
789 795 ("pandas", "DataFrame"),
790 796 ("pandas", "Series"),
791 797 },
792 798 allowed_operations=BUILTIN_OPERATIONS,
793 799 allow_builtins_access=True,
794 800 allow_locals_access=True,
795 801 allow_globals_access=True,
796 802 allowed_calls=ALLOWED_CALLS,
797 803 ),
798 804 "unsafe": EvaluationPolicy(
799 805 allow_builtins_access=True,
800 806 allow_locals_access=True,
801 807 allow_globals_access=True,
802 808 allow_attr_access=True,
803 809 allow_item_access=True,
804 810 allow_any_calls=True,
805 811 allow_all_operations=True,
806 812 ),
807 813 }
808 814
809 815
810 816 __all__ = [
811 817 "guarded_eval",
812 818 "eval_node",
813 819 "GuardRejection",
814 820 "EvaluationContext",
815 821 "_unbind_method",
816 822 ]
@@ -1,681 +1,689 b''
1 import sys
1 2 from contextlib import contextmanager
2 3 from typing import NamedTuple, Literal, NewType
3 4 from functools import partial
4 5 from IPython.core.guarded_eval import (
5 6 EvaluationContext,
6 7 GuardRejection,
7 8 guarded_eval,
8 9 _unbind_method,
9 10 )
10 from typing_extensions import (
11 Self, # Python >=3.10
12 TypeAliasType, # Python >=3.12
13 )
14 11 from IPython.testing import decorators as dec
15 12 import pytest
16 13
17 14
15 if sys.version_info < (3, 11):
16 from typing_extensions import Self
17 else:
18 from typing import Self
19
20 if sys.version_info < (3, 12):
21 from typing_extensions import TypeAliasType
22 else:
23 from typing import TypeAliasType
24
25
18 26 def create_context(evaluation: str, **kwargs):
19 27 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
20 28
21 29
22 30 forbidden = partial(create_context, "forbidden")
23 31 minimal = partial(create_context, "minimal")
24 32 limited = partial(create_context, "limited")
25 33 unsafe = partial(create_context, "unsafe")
26 34 dangerous = partial(create_context, "dangerous")
27 35
28 36 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
29 37 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
30 38
31 39
32 40 @contextmanager
33 41 def module_not_installed(module: str):
34 42 import sys
35 43
36 44 try:
37 45 to_restore = sys.modules[module]
38 46 del sys.modules[module]
39 47 except KeyError:
40 48 to_restore = None
41 49 try:
42 50 yield
43 51 finally:
44 52 sys.modules[module] = to_restore
45 53
46 54
47 55 def test_external_not_installed():
48 56 """
49 57 Because attribute check requires checking if object is not of allowed
50 58 external type, this tests logic for absence of external module.
51 59 """
52 60
53 61 class Custom:
54 62 def __init__(self):
55 63 self.test = 1
56 64
57 65 def __getattr__(self, key):
58 66 return key
59 67
60 68 with module_not_installed("pandas"):
61 69 context = limited(x=Custom())
62 70 with pytest.raises(GuardRejection):
63 71 guarded_eval("x.test", context)
64 72
65 73
66 74 @dec.skip_without("pandas")
67 75 def test_external_changed_api(monkeypatch):
68 76 """Check that the execution rejects if external API changed paths"""
69 77 import pandas as pd
70 78
71 79 series = pd.Series([1], index=["a"])
72 80
73 81 with monkeypatch.context() as m:
74 82 m.delattr(pd, "Series")
75 83 context = limited(data=series)
76 84 with pytest.raises(GuardRejection):
77 85 guarded_eval("data.iloc[0]", context)
78 86
79 87
80 88 @dec.skip_without("pandas")
81 89 def test_pandas_series_iloc():
82 90 import pandas as pd
83 91
84 92 series = pd.Series([1], index=["a"])
85 93 context = limited(data=series)
86 94 assert guarded_eval("data.iloc[0]", context) == 1
87 95
88 96
89 97 def test_rejects_custom_properties():
90 98 class BadProperty:
91 99 @property
92 100 def iloc(self):
93 101 return [None]
94 102
95 103 series = BadProperty()
96 104 context = limited(data=series)
97 105
98 106 with pytest.raises(GuardRejection):
99 107 guarded_eval("data.iloc[0]", context)
100 108
101 109
102 110 @dec.skip_without("pandas")
103 111 def test_accepts_non_overriden_properties():
104 112 import pandas as pd
105 113
106 114 class GoodProperty(pd.Series):
107 115 pass
108 116
109 117 series = GoodProperty([1], index=["a"])
110 118 context = limited(data=series)
111 119
112 120 assert guarded_eval("data.iloc[0]", context) == 1
113 121
114 122
115 123 @dec.skip_without("pandas")
116 124 def test_pandas_series():
117 125 import pandas as pd
118 126
119 127 context = limited(data=pd.Series([1], index=["a"]))
120 128 assert guarded_eval('data["a"]', context) == 1
121 129 with pytest.raises(KeyError):
122 130 guarded_eval('data["c"]', context)
123 131
124 132
125 133 @dec.skip_without("pandas")
126 134 def test_pandas_bad_series():
127 135 import pandas as pd
128 136
129 137 class BadItemSeries(pd.Series):
130 138 def __getitem__(self, key):
131 139 return "CUSTOM_ITEM"
132 140
133 141 class BadAttrSeries(pd.Series):
134 142 def __getattr__(self, key):
135 143 return "CUSTOM_ATTR"
136 144
137 145 bad_series = BadItemSeries([1], index=["a"])
138 146 context = limited(data=bad_series)
139 147
140 148 with pytest.raises(GuardRejection):
141 149 guarded_eval('data["a"]', context)
142 150 with pytest.raises(GuardRejection):
143 151 guarded_eval('data["c"]', context)
144 152
145 153 # note: here result is a bit unexpected because
146 154 # pandas `__getattr__` calls `__getitem__`;
147 155 # FIXME - special case to handle it?
148 156 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
149 157
150 158 context = unsafe(data=bad_series)
151 159 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
152 160
153 161 bad_attr_series = BadAttrSeries([1], index=["a"])
154 162 context = limited(data=bad_attr_series)
155 163 assert guarded_eval('data["a"]', context) == 1
156 164 with pytest.raises(GuardRejection):
157 165 guarded_eval("data.a", context)
158 166
159 167
160 168 @dec.skip_without("pandas")
161 169 def test_pandas_dataframe_loc():
162 170 import pandas as pd
163 171 from pandas.testing import assert_series_equal
164 172
165 173 data = pd.DataFrame([{"a": 1}])
166 174 context = limited(data=data)
167 175 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
168 176
169 177
170 178 def test_named_tuple():
171 179 class GoodNamedTuple(NamedTuple):
172 180 a: str
173 181 pass
174 182
175 183 class BadNamedTuple(NamedTuple):
176 184 a: str
177 185
178 186 def __getitem__(self, key):
179 187 return None
180 188
181 189 good = GoodNamedTuple(a="x")
182 190 bad = BadNamedTuple(a="x")
183 191
184 192 context = limited(data=good)
185 193 assert guarded_eval("data[0]", context) == "x"
186 194
187 195 context = limited(data=bad)
188 196 with pytest.raises(GuardRejection):
189 197 guarded_eval("data[0]", context)
190 198
191 199
192 200 def test_dict():
193 201 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
194 202 assert guarded_eval('data["a"]', context) == 1
195 203 assert guarded_eval('data["b"]', context) == {"x": 2}
196 204 assert guarded_eval('data["b"]["x"]', context) == 2
197 205 assert guarded_eval('data["x", "y"]', context) == 3
198 206
199 207 assert guarded_eval("data.keys", context)
200 208
201 209
202 210 def test_set():
203 211 context = limited(data={"a", "b"})
204 212 assert guarded_eval("data.difference", context)
205 213
206 214
207 215 def test_list():
208 216 context = limited(data=[1, 2, 3])
209 217 assert guarded_eval("data[1]", context) == 2
210 218 assert guarded_eval("data.copy", context)
211 219
212 220
213 221 def test_dict_literal():
214 222 context = limited()
215 223 assert guarded_eval("{}", context) == {}
216 224 assert guarded_eval('{"a": 1}', context) == {"a": 1}
217 225
218 226
219 227 def test_list_literal():
220 228 context = limited()
221 229 assert guarded_eval("[]", context) == []
222 230 assert guarded_eval('[1, "a"]', context) == [1, "a"]
223 231
224 232
225 233 def test_set_literal():
226 234 context = limited()
227 235 assert guarded_eval("set()", context) == set()
228 236 assert guarded_eval('{"a"}', context) == {"a"}
229 237
230 238
231 239 def test_evaluates_if_expression():
232 240 context = limited()
233 241 assert guarded_eval("2 if True else 3", context) == 2
234 242 assert guarded_eval("4 if False else 5", context) == 5
235 243
236 244
237 245 def test_object():
238 246 obj = object()
239 247 context = limited(obj=obj)
240 248 assert guarded_eval("obj.__dir__", context) == obj.__dir__
241 249
242 250
243 251 @pytest.mark.parametrize(
244 252 "code,expected",
245 253 [
246 254 ["int.numerator", int.numerator],
247 255 ["float.is_integer", float.is_integer],
248 256 ["complex.real", complex.real],
249 257 ],
250 258 )
251 259 def test_number_attributes(code, expected):
252 260 assert guarded_eval(code, limited()) == expected
253 261
254 262
255 263 def test_method_descriptor():
256 264 context = limited()
257 265 assert guarded_eval("list.copy.__name__", context) == "copy"
258 266
259 267
260 268 class HeapType:
261 269 pass
262 270
263 271
264 272 class CallCreatesHeapType:
265 273 def __call__(self) -> HeapType:
266 274 return HeapType()
267 275
268 276
269 277 class CallCreatesBuiltin:
270 278 def __call__(self) -> frozenset:
271 279 return frozenset()
272 280
273 281
274 282 class HasStaticMethod:
275 283 @staticmethod
276 284 def static_method() -> HeapType:
277 285 return HeapType()
278 286
279 287
280 288 class InitReturnsFrozenset:
281 289 def __new__(self) -> frozenset: # type:ignore[misc]
282 290 return frozenset()
283 291
284 292
285 293 class StringAnnotation:
286 294 def heap(self) -> "HeapType":
287 295 return HeapType()
288 296
289 297 def copy(self) -> "StringAnnotation":
290 298 return StringAnnotation()
291 299
292 300
293 301 CustomIntType = NewType("CustomIntType", int)
294 302 CustomHeapType = NewType("CustomHeapType", HeapType)
295 303 IntTypeAlias = TypeAliasType("IntTypeAlias", int)
296 304 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
297 305
298 306
299 307 class SpecialTyping:
300 308 def custom_int_type(self) -> CustomIntType:
301 309 return CustomIntType(1)
302 310
303 311 def custom_heap_type(self) -> CustomHeapType:
304 312 return CustomHeapType(HeapType())
305 313
306 314 # TODO: remove type:ignore comment once mypy
307 315 # supports explicit calls to `TypeAliasType`, see:
308 316 # https://github.com/python/mypy/issues/16614
309 317 def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type]
310 318 return 1
311 319
312 320 def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type]
313 321 return 1
314 322
315 323 def literal(self) -> Literal[False]:
316 324 return False
317 325
318 326 def self(self) -> Self:
319 327 return self
320 328
321 329
322 330 @pytest.mark.parametrize(
323 331 "data,good,expected,equality",
324 332 [
325 333 [[1, 2, 3], "data.index(2)", 1, True],
326 334 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
327 335 [StringAnnotation(), "data.heap()", HeapType, False],
328 336 [StringAnnotation(), "data.copy()", StringAnnotation, False],
329 337 # test cases for `__call__`
330 338 [CallCreatesHeapType(), "data()", HeapType, False],
331 339 [CallCreatesBuiltin(), "data()", frozenset, False],
332 340 # Test cases for `__init__`
333 341 [HeapType, "data()", HeapType, False],
334 342 [InitReturnsFrozenset, "data()", frozenset, False],
335 343 [HeapType(), "data.__class__()", HeapType, False],
336 344 # supported special cases for typing
337 345 [SpecialTyping(), "data.custom_int_type()", int, False],
338 346 [SpecialTyping(), "data.custom_heap_type()", HeapType, False],
339 347 [SpecialTyping(), "data.int_type_alias()", int, False],
340 348 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
341 349 [SpecialTyping(), "data.self()", SpecialTyping, False],
342 350 [SpecialTyping(), "data.literal()", False, True],
343 351 # test cases for static methods
344 352 [HasStaticMethod, "data.static_method()", HeapType, False],
345 353 ],
346 354 )
347 355 def test_evaluates_calls(data, good, expected, equality):
348 356 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
349 357 value = guarded_eval(good, context)
350 358 if equality:
351 359 assert value == expected
352 360 else:
353 361 assert isinstance(value, expected)
354 362
355 363
356 364 @pytest.mark.parametrize(
357 365 "data,bad",
358 366 [
359 367 [[1, 2, 3], "data.append(4)"],
360 368 [{"a": 1}, "data.update()"],
361 369 ],
362 370 )
363 371 def test_rejects_calls_with_side_effects(data, bad):
364 372 context = limited(data=data)
365 373
366 374 with pytest.raises(GuardRejection):
367 375 guarded_eval(bad, context)
368 376
369 377
370 378 @pytest.mark.parametrize(
371 379 "code,expected",
372 380 [
373 381 ["(1\n+\n1)", 2],
374 382 ["list(range(10))[-1:]", [9]],
375 383 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
376 384 ],
377 385 )
378 386 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
379 387 def test_evaluates_complex_cases(code, expected, context):
380 388 assert guarded_eval(code, context()) == expected
381 389
382 390
383 391 @pytest.mark.parametrize(
384 392 "code,expected",
385 393 [
386 394 ["1", 1],
387 395 ["1.0", 1.0],
388 396 ["0xdeedbeef", 0xDEEDBEEF],
389 397 ["True", True],
390 398 ["None", None],
391 399 ["{}", {}],
392 400 ["[]", []],
393 401 ],
394 402 )
395 403 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
396 404 def test_evaluates_literals(code, expected, context):
397 405 assert guarded_eval(code, context()) == expected
398 406
399 407
400 408 @pytest.mark.parametrize(
401 409 "code,expected",
402 410 [
403 411 ["-5", -5],
404 412 ["+5", +5],
405 413 ["~5", -6],
406 414 ],
407 415 )
408 416 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
409 417 def test_evaluates_unary_operations(code, expected, context):
410 418 assert guarded_eval(code, context()) == expected
411 419
412 420
413 421 @pytest.mark.parametrize(
414 422 "code,expected",
415 423 [
416 424 ["1 + 1", 2],
417 425 ["3 - 1", 2],
418 426 ["2 * 3", 6],
419 427 ["5 // 2", 2],
420 428 ["5 / 2", 2.5],
421 429 ["5**2", 25],
422 430 ["2 >> 1", 1],
423 431 ["2 << 1", 4],
424 432 ["1 | 2", 3],
425 433 ["1 & 1", 1],
426 434 ["1 & 2", 0],
427 435 ],
428 436 )
429 437 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
430 438 def test_evaluates_binary_operations(code, expected, context):
431 439 assert guarded_eval(code, context()) == expected
432 440
433 441
434 442 @pytest.mark.parametrize(
435 443 "code,expected",
436 444 [
437 445 ["2 > 1", True],
438 446 ["2 < 1", False],
439 447 ["2 <= 1", False],
440 448 ["2 <= 2", True],
441 449 ["1 >= 2", False],
442 450 ["2 >= 2", True],
443 451 ["2 == 2", True],
444 452 ["1 == 2", False],
445 453 ["1 != 2", True],
446 454 ["1 != 1", False],
447 455 ["1 < 4 < 3", False],
448 456 ["(1 < 4) < 3", True],
449 457 ["4 > 3 > 2 > 1", True],
450 458 ["4 > 3 > 2 > 9", False],
451 459 ["1 < 2 < 3 < 4", True],
452 460 ["9 < 2 < 3 < 4", False],
453 461 ["1 < 2 > 1 > 0 > -1 < 1", True],
454 462 ["1 in [1] in [[1]]", True],
455 463 ["1 in [1] in [[2]]", False],
456 464 ["1 in [1]", True],
457 465 ["0 in [1]", False],
458 466 ["1 not in [1]", False],
459 467 ["0 not in [1]", True],
460 468 ["True is True", True],
461 469 ["False is False", True],
462 470 ["True is False", False],
463 471 ["True is not True", False],
464 472 ["False is not True", True],
465 473 ],
466 474 )
467 475 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
468 476 def test_evaluates_comparisons(code, expected, context):
469 477 assert guarded_eval(code, context()) == expected
470 478
471 479
472 480 def test_guards_comparisons():
473 481 class GoodEq(int):
474 482 pass
475 483
476 484 class BadEq(int):
477 485 def __eq__(self, other):
478 486 assert False
479 487
480 488 context = limited(bad=BadEq(1), good=GoodEq(1))
481 489
482 490 with pytest.raises(GuardRejection):
483 491 guarded_eval("bad == 1", context)
484 492
485 493 with pytest.raises(GuardRejection):
486 494 guarded_eval("bad != 1", context)
487 495
488 496 with pytest.raises(GuardRejection):
489 497 guarded_eval("1 == bad", context)
490 498
491 499 with pytest.raises(GuardRejection):
492 500 guarded_eval("1 != bad", context)
493 501
494 502 assert guarded_eval("good == 1", context) is True
495 503 assert guarded_eval("good != 1", context) is False
496 504 assert guarded_eval("1 == good", context) is True
497 505 assert guarded_eval("1 != good", context) is False
498 506
499 507
500 508 def test_guards_unary_operations():
501 509 class GoodOp(int):
502 510 pass
503 511
504 512 class BadOpInv(int):
505 513 def __inv__(self, other):
506 514 assert False
507 515
508 516 class BadOpInverse(int):
509 517 def __inv__(self, other):
510 518 assert False
511 519
512 520 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
513 521
514 522 with pytest.raises(GuardRejection):
515 523 guarded_eval("~bad1", context)
516 524
517 525 with pytest.raises(GuardRejection):
518 526 guarded_eval("~bad2", context)
519 527
520 528
521 529 def test_guards_binary_operations():
522 530 class GoodOp(int):
523 531 pass
524 532
525 533 class BadOp(int):
526 534 def __add__(self, other):
527 535 assert False
528 536
529 537 context = limited(good=GoodOp(1), bad=BadOp(1))
530 538
531 539 with pytest.raises(GuardRejection):
532 540 guarded_eval("1 + bad", context)
533 541
534 542 with pytest.raises(GuardRejection):
535 543 guarded_eval("bad + 1", context)
536 544
537 545 assert guarded_eval("good + 1", context) == 2
538 546 assert guarded_eval("1 + good", context) == 2
539 547
540 548
541 549 def test_guards_attributes():
542 550 class GoodAttr(float):
543 551 pass
544 552
545 553 class BadAttr1(float):
546 554 def __getattr__(self, key):
547 555 assert False
548 556
549 557 class BadAttr2(float):
550 558 def __getattribute__(self, key):
551 559 assert False
552 560
553 561 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
554 562
555 563 with pytest.raises(GuardRejection):
556 564 guarded_eval("bad1.as_integer_ratio", context)
557 565
558 566 with pytest.raises(GuardRejection):
559 567 guarded_eval("bad2.as_integer_ratio", context)
560 568
561 569 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
562 570
563 571
564 572 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
565 573 def test_access_builtins(context):
566 574 assert guarded_eval("round", context()) == round
567 575
568 576
569 577 def test_access_builtins_fails():
570 578 context = limited()
571 579 with pytest.raises(NameError):
572 580 guarded_eval("this_is_not_builtin", context)
573 581
574 582
575 583 def test_rejects_forbidden():
576 584 context = forbidden()
577 585 with pytest.raises(GuardRejection):
578 586 guarded_eval("1", context)
579 587
580 588
581 589 def test_guards_locals_and_globals():
582 590 context = EvaluationContext(
583 591 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
584 592 )
585 593
586 594 with pytest.raises(GuardRejection):
587 595 guarded_eval("local_a", context)
588 596
589 597 with pytest.raises(GuardRejection):
590 598 guarded_eval("global_b", context)
591 599
592 600
593 601 def test_access_locals_and_globals():
594 602 context = EvaluationContext(
595 603 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
596 604 )
597 605 assert guarded_eval("local_a", context) == "a"
598 606 assert guarded_eval("global_b", context) == "b"
599 607
600 608
601 609 @pytest.mark.parametrize(
602 610 "code",
603 611 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
604 612 )
605 613 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
606 614 def test_rejects_side_effect_syntax(code, context):
607 615 with pytest.raises(SyntaxError):
608 616 guarded_eval(code, context)
609 617
610 618
611 619 def test_subscript():
612 620 context = EvaluationContext(
613 621 locals={}, globals={}, evaluation="limited", in_subscript=True
614 622 )
615 623 empty_slice = slice(None, None, None)
616 624 assert guarded_eval("", context) == tuple()
617 625 assert guarded_eval(":", context) == empty_slice
618 626 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
619 627 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
620 628
621 629
622 630 def test_unbind_method():
623 631 class X(list):
624 632 def index(self, k):
625 633 return "CUSTOM"
626 634
627 635 x = X()
628 636 assert _unbind_method(x.index) is X.index
629 637 assert _unbind_method([].index) is list.index
630 638 assert _unbind_method(list.index) is None
631 639
632 640
633 641 def test_assumption_instance_attr_do_not_matter():
634 642 """This is semi-specified in Python documentation.
635 643
636 644 However, since the specification says 'not guaranteed
637 645 to work' rather than 'is forbidden to work', future
638 646 versions could invalidate this assumptions. This test
639 647 is meant to catch such a change if it ever comes true.
640 648 """
641 649
642 650 class T:
643 651 def __getitem__(self, k):
644 652 return "a"
645 653
646 654 def __getattr__(self, k):
647 655 return "a"
648 656
649 657 def f(self):
650 658 return "b"
651 659
652 660 t = T()
653 661 t.__getitem__ = f
654 662 t.__getattr__ = f
655 663 assert t[1] == "a"
656 664 assert t[1] == "a"
657 665
658 666
659 667 def test_assumption_named_tuples_share_getitem():
660 668 """Check assumption on named tuples sharing __getitem__"""
661 669 from typing import NamedTuple
662 670
663 671 class A(NamedTuple):
664 672 pass
665 673
666 674 class B(NamedTuple):
667 675 pass
668 676
669 677 assert A.__getitem__ == B.__getitem__
670 678
671 679
672 680 @dec.skip_without("numpy")
673 681 def test_module_access():
674 682 import numpy
675 683
676 684 context = limited(numpy=numpy)
677 685 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
678 686
679 687 context = minimal(numpy=numpy)
680 688 with pytest.raises(GuardRejection):
681 689 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now