##// END OF EJS Templates
Support stringized return type annotations
krassowski -
Show More
@@ -1,793 +1,803 b''
1 1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 Any,
4 3 Callable,
5 4 Dict,
6 5 Set,
7 6 Sequence,
8 7 Tuple,
9 8 NamedTuple,
10 9 Type,
11 10 Literal,
12 11 Union,
13 12 TYPE_CHECKING,
14 13 )
15 14 import ast
16 15 import builtins
17 16 import collections
18 17 import operator
19 18 import sys
20 19 from functools import cached_property
21 20 from dataclasses import dataclass, field
22 21 from types import MethodDescriptorType, ModuleType
23 22
24 23 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 24 from IPython.utils.decorators import undoc
26 25
27 26
28 27 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
29 28 from typing_extensions import Protocol
30 29 else:
31 30 # do not require on runtime
32 31 Protocol = object # requires Python >=3.8
33 32
34 33
35 34 @undoc
36 35 class HasGetItem(Protocol):
37 36 def __getitem__(self, key) -> None:
38 37 ...
39 38
40 39
41 40 @undoc
42 41 class InstancesHaveGetItem(Protocol):
43 42 def __call__(self, *args, **kwargs) -> HasGetItem:
44 43 ...
45 44
46 45
47 46 @undoc
48 47 class HasGetAttr(Protocol):
49 48 def __getattr__(self, key) -> None:
50 49 ...
51 50
52 51
53 52 @undoc
54 53 class DoesNotHaveGetAttr(Protocol):
55 54 pass
56 55
57 56
58 57 # By default `__getattr__` is not explicitly implemented on most objects
59 58 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
60 59
61 60
62 61 def _unbind_method(func: Callable) -> Union[Callable, None]:
63 62 """Get unbound method for given bound method.
64 63
65 64 Returns None if cannot get unbound method, or method is already unbound.
66 65 """
67 66 owner = getattr(func, "__self__", None)
68 67 owner_class = type(owner)
69 68 name = getattr(func, "__name__", None)
70 69 instance_dict_overrides = getattr(owner, "__dict__", None)
71 70 if (
72 71 owner is not None
73 72 and name
74 73 and (
75 74 not instance_dict_overrides
76 75 or (instance_dict_overrides and name not in instance_dict_overrides)
77 76 )
78 77 ):
79 78 return getattr(owner_class, name)
80 79 return None
81 80
82 81
83 82 @undoc
84 83 @dataclass
85 84 class EvaluationPolicy:
86 85 """Definition of evaluation policy."""
87 86
88 87 allow_locals_access: bool = False
89 88 allow_globals_access: bool = False
90 89 allow_item_access: bool = False
91 90 allow_attr_access: bool = False
92 91 allow_builtins_access: bool = False
93 92 allow_all_operations: bool = False
94 93 allow_any_calls: bool = False
95 94 allowed_calls: Set[Callable] = field(default_factory=set)
96 95
97 96 def can_get_item(self, value, item):
98 97 return self.allow_item_access
99 98
100 99 def can_get_attr(self, value, attr):
101 100 return self.allow_attr_access
102 101
103 102 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
104 103 if self.allow_all_operations:
105 104 return True
106 105
107 106 def can_call(self, func):
108 107 if self.allow_any_calls:
109 108 return True
110 109
111 110 if func in self.allowed_calls:
112 111 return True
113 112
114 113 owner_method = _unbind_method(func)
115 114
116 115 if owner_method and owner_method in self.allowed_calls:
117 116 return True
118 117
119 118
120 119 def _get_external(module_name: str, access_path: Sequence[str]):
121 120 """Get value from external module given a dotted access path.
122 121
123 122 Raises:
124 123 * `KeyError` if module is removed not found, and
125 124 * `AttributeError` if acess path does not match an exported object
126 125 """
127 126 member_type = sys.modules[module_name]
128 127 for attr in access_path:
129 128 member_type = getattr(member_type, attr)
130 129 return member_type
131 130
132 131
133 132 def _has_original_dunder_external(
134 133 value,
135 134 module_name: str,
136 135 access_path: Sequence[str],
137 136 method_name: str,
138 137 ):
139 138 if module_name not in sys.modules:
140 139 # LBYLB as it is faster
141 140 return False
142 141 try:
143 142 member_type = _get_external(module_name, access_path)
144 143 value_type = type(value)
145 144 if type(value) == member_type:
146 145 return True
147 146 if method_name == "__getattribute__":
148 147 # we have to short-circuit here due to an unresolved issue in
149 148 # `isinstance` implementation: https://bugs.python.org/issue32683
150 149 return False
151 150 if isinstance(value, member_type):
152 151 method = getattr(value_type, method_name, None)
153 152 member_method = getattr(member_type, method_name, None)
154 153 if member_method == method:
155 154 return True
156 155 except (AttributeError, KeyError):
157 156 return False
158 157
159 158
160 159 def _has_original_dunder(
161 160 value, allowed_types, allowed_methods, allowed_external, method_name
162 161 ):
163 162 # note: Python ignores `__getattr__`/`__getitem__` on instances,
164 163 # we only need to check at class level
165 164 value_type = type(value)
166 165
167 166 # strict type check passes β†’ no need to check method
168 167 if value_type in allowed_types:
169 168 return True
170 169
171 170 method = getattr(value_type, method_name, None)
172 171
173 172 if method is None:
174 173 return None
175 174
176 175 if method in allowed_methods:
177 176 return True
178 177
179 178 for module_name, *access_path in allowed_external:
180 179 if _has_original_dunder_external(value, module_name, access_path, method_name):
181 180 return True
182 181
183 182 return False
184 183
185 184
186 185 @undoc
187 186 @dataclass
188 187 class SelectivePolicy(EvaluationPolicy):
189 188 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
190 189 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
191 190
192 191 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
193 192 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
194 193
195 194 allowed_operations: Set = field(default_factory=set)
196 195 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
197 196
198 197 _operation_methods_cache: Dict[str, Set[Callable]] = field(
199 198 default_factory=dict, init=False
200 199 )
201 200
202 201 def can_get_attr(self, value, attr):
203 202 has_original_attribute = _has_original_dunder(
204 203 value,
205 204 allowed_types=self.allowed_getattr,
206 205 allowed_methods=self._getattribute_methods,
207 206 allowed_external=self.allowed_getattr_external,
208 207 method_name="__getattribute__",
209 208 )
210 209 has_original_attr = _has_original_dunder(
211 210 value,
212 211 allowed_types=self.allowed_getattr,
213 212 allowed_methods=self._getattr_methods,
214 213 allowed_external=self.allowed_getattr_external,
215 214 method_name="__getattr__",
216 215 )
217 216
218 217 accept = False
219 218
220 219 # Many objects do not have `__getattr__`, this is fine.
221 220 if has_original_attr is None and has_original_attribute:
222 221 accept = True
223 222 else:
224 223 # Accept objects without modifications to `__getattr__` and `__getattribute__`
225 224 accept = has_original_attr and has_original_attribute
226 225
227 226 if accept:
228 227 # We still need to check for overriden properties.
229 228
230 229 value_class = type(value)
231 230 if not hasattr(value_class, attr):
232 231 return True
233 232
234 233 class_attr_val = getattr(value_class, attr)
235 234 is_property = isinstance(class_attr_val, property)
236 235
237 236 if not is_property:
238 237 return True
239 238
240 239 # Properties in allowed types are ok (although we do not include any
241 240 # properties in our default allow list currently).
242 241 if type(value) in self.allowed_getattr:
243 242 return True # pragma: no cover
244 243
245 244 # Properties in subclasses of allowed types may be ok if not changed
246 245 for module_name, *access_path in self.allowed_getattr_external:
247 246 try:
248 247 external_class = _get_external(module_name, access_path)
249 248 external_class_attr_val = getattr(external_class, attr)
250 249 except (KeyError, AttributeError):
251 250 return False # pragma: no cover
252 251 return class_attr_val == external_class_attr_val
253 252
254 253 return False
255 254
256 255 def can_get_item(self, value, item):
257 256 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
258 257 return _has_original_dunder(
259 258 value,
260 259 allowed_types=self.allowed_getitem,
261 260 allowed_methods=self._getitem_methods,
262 261 allowed_external=self.allowed_getitem_external,
263 262 method_name="__getitem__",
264 263 )
265 264
266 265 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
267 266 objects = [a]
268 267 if b is not None:
269 268 objects.append(b)
270 269 return all(
271 270 [
272 271 _has_original_dunder(
273 272 obj,
274 273 allowed_types=self.allowed_operations,
275 274 allowed_methods=self._operator_dunder_methods(dunder),
276 275 allowed_external=self.allowed_operations_external,
277 276 method_name=dunder,
278 277 )
279 278 for dunder in dunders
280 279 for obj in objects
281 280 ]
282 281 )
283 282
284 283 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
285 284 if dunder not in self._operation_methods_cache:
286 285 self._operation_methods_cache[dunder] = self._safe_get_methods(
287 286 self.allowed_operations, dunder
288 287 )
289 288 return self._operation_methods_cache[dunder]
290 289
291 290 @cached_property
292 291 def _getitem_methods(self) -> Set[Callable]:
293 292 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
294 293
295 294 @cached_property
296 295 def _getattr_methods(self) -> Set[Callable]:
297 296 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
298 297
299 298 @cached_property
300 299 def _getattribute_methods(self) -> Set[Callable]:
301 300 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
302 301
303 302 def _safe_get_methods(self, classes, name) -> Set[Callable]:
304 303 return {
305 304 method
306 305 for class_ in classes
307 306 for method in [getattr(class_, name, None)]
308 307 if method
309 308 }
310 309
311 310
312 311 class _DummyNamedTuple(NamedTuple):
313 312 """Used internally to retrieve methods of named tuple instance."""
314 313
315 314
316 315 class EvaluationContext(NamedTuple):
317 316 #: Local namespace
318 317 locals: dict
319 318 #: Global namespace
320 319 globals: dict
321 320 #: Evaluation policy identifier
322 321 evaluation: Literal[
323 322 "forbidden", "minimal", "limited", "unsafe", "dangerous"
324 323 ] = "forbidden"
325 324 #: Whether the evalution of code takes place inside of a subscript.
326 325 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
327 326 in_subscript: bool = False
328 327
329 328
330 329 class _IdentitySubscript:
331 330 """Returns the key itself when item is requested via subscript."""
332 331
333 332 def __getitem__(self, key):
334 333 return key
335 334
336 335
337 336 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 337 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 338 UNKNOWN_SIGNATURE = Signature()
340 339 NOT_EVALUATED = object()
341 340
342 341
343 342 class GuardRejection(Exception):
344 343 """Exception raised when guard rejects evaluation attempt."""
345 344
346 345 pass
347 346
348 347
349 348 def guarded_eval(code: str, context: EvaluationContext):
350 349 """Evaluate provided code in the evaluation context.
351 350
352 351 If evaluation policy given by context is set to ``forbidden``
353 352 no evaluation will be performed; if it is set to ``dangerous``
354 353 standard :func:`eval` will be used; finally, for any other,
355 354 policy :func:`eval_node` will be called on parsed AST.
356 355 """
357 356 locals_ = context.locals
358 357
359 358 if context.evaluation == "forbidden":
360 359 raise GuardRejection("Forbidden mode")
361 360
362 361 # note: not using `ast.literal_eval` as it does not implement
363 362 # getitem at all, for example it fails on simple `[0][1]`
364 363
365 364 if context.in_subscript:
366 365 # syntatic sugar for ellipsis (:) is only available in susbcripts
367 366 # so we need to trick the ast parser into thinking that we have
368 367 # a subscript, but we need to be able to later recognise that we did
369 368 # it so we can ignore the actual __getitem__ operation
370 369 if not code:
371 370 return tuple()
372 371 locals_ = locals_.copy()
373 372 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
374 373 code = SUBSCRIPT_MARKER + "[" + code + "]"
375 374 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
376 375
377 376 if context.evaluation == "dangerous":
378 377 return eval(code, context.globals, context.locals)
379 378
380 379 expression = ast.parse(code, mode="eval")
381 380
382 381 return eval_node(expression, context)
383 382
384 383
385 384 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
386 385 ast.Add: ("__add__",),
387 386 ast.Sub: ("__sub__",),
388 387 ast.Mult: ("__mul__",),
389 388 ast.Div: ("__truediv__",),
390 389 ast.FloorDiv: ("__floordiv__",),
391 390 ast.Mod: ("__mod__",),
392 391 ast.Pow: ("__pow__",),
393 392 ast.LShift: ("__lshift__",),
394 393 ast.RShift: ("__rshift__",),
395 394 ast.BitOr: ("__or__",),
396 395 ast.BitXor: ("__xor__",),
397 396 ast.BitAnd: ("__and__",),
398 397 ast.MatMult: ("__matmul__",),
399 398 }
400 399
401 400 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
402 401 ast.Eq: ("__eq__",),
403 402 ast.NotEq: ("__ne__", "__eq__"),
404 403 ast.Lt: ("__lt__", "__gt__"),
405 404 ast.LtE: ("__le__", "__ge__"),
406 405 ast.Gt: ("__gt__", "__lt__"),
407 406 ast.GtE: ("__ge__", "__le__"),
408 407 ast.In: ("__contains__",),
409 408 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
410 409 }
411 410
412 411 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
413 412 ast.USub: ("__neg__",),
414 413 ast.UAdd: ("__pos__",),
415 414 # we have to check both __inv__ and __invert__!
416 415 ast.Invert: ("__invert__", "__inv__"),
417 416 ast.Not: ("__not__",),
418 417 }
419 418
420 419
421 420 class Duck:
422 421 """A dummy class used to create objects of other classes without calling their ``__init__``"""
423 422
424 423
425 424 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
426 425 dunder = None
427 426 for op, candidate_dunder in dunders.items():
428 427 if isinstance(node_op, op):
429 428 dunder = candidate_dunder
430 429 return dunder
431 430
432 431
433 432 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
434 433 """Evaluate AST node in provided context.
435 434
436 435 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
437 436
438 437 Does not evaluate actions that always have side effects:
439 438
440 439 - class definitions (``class sth: ...``)
441 440 - function definitions (``def sth: ...``)
442 441 - variable assignments (``x = 1``)
443 442 - augmented assignments (``x += 1``)
444 443 - deletions (``del x``)
445 444
446 445 Does not evaluate operations which do not return values:
447 446
448 447 - assertions (``assert x``)
449 448 - pass (``pass``)
450 449 - imports (``import x``)
451 450 - control flow:
452 451
453 452 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
454 453 - loops (``for`` and ``while``)
455 454 - exception handling
456 455
457 456 The purpose of this function is to guard against unwanted side-effects;
458 457 it does not give guarantees on protection from malicious code execution.
459 458 """
460 459 policy = EVALUATION_POLICIES[context.evaluation]
461 460 if node is None:
462 461 return None
463 462 if isinstance(node, ast.Expression):
464 463 return eval_node(node.body, context)
465 464 if isinstance(node, ast.BinOp):
466 465 left = eval_node(node.left, context)
467 466 right = eval_node(node.right, context)
468 467 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
469 468 if dunders:
470 469 if policy.can_operate(dunders, left, right):
471 470 return getattr(left, dunders[0])(right)
472 471 else:
473 472 raise GuardRejection(
474 473 f"Operation (`{dunders}`) for",
475 474 type(left),
476 475 f"not allowed in {context.evaluation} mode",
477 476 )
478 477 if isinstance(node, ast.Compare):
479 478 left = eval_node(node.left, context)
480 479 all_true = True
481 480 negate = False
482 481 for op, right in zip(node.ops, node.comparators):
483 482 right = eval_node(right, context)
484 483 dunder = None
485 484 dunders = _find_dunder(op, COMP_OP_DUNDERS)
486 485 if not dunders:
487 486 if isinstance(op, ast.NotIn):
488 487 dunders = COMP_OP_DUNDERS[ast.In]
489 488 negate = True
490 489 if isinstance(op, ast.Is):
491 490 dunder = "is_"
492 491 if isinstance(op, ast.IsNot):
493 492 dunder = "is_"
494 493 negate = True
495 494 if not dunder and dunders:
496 495 dunder = dunders[0]
497 496 if dunder:
498 497 a, b = (right, left) if dunder == "__contains__" else (left, right)
499 498 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
500 499 result = getattr(operator, dunder)(a, b)
501 500 if negate:
502 501 result = not result
503 502 if not result:
504 503 all_true = False
505 504 left = right
506 505 else:
507 506 raise GuardRejection(
508 507 f"Comparison (`{dunder}`) for",
509 508 type(left),
510 509 f"not allowed in {context.evaluation} mode",
511 510 )
512 511 else:
513 512 raise ValueError(
514 513 f"Comparison `{dunder}` not supported"
515 514 ) # pragma: no cover
516 515 return all_true
517 516 if isinstance(node, ast.Constant):
518 517 return node.value
519 518 if isinstance(node, ast.Tuple):
520 519 return tuple(eval_node(e, context) for e in node.elts)
521 520 if isinstance(node, ast.List):
522 521 return [eval_node(e, context) for e in node.elts]
523 522 if isinstance(node, ast.Set):
524 523 return {eval_node(e, context) for e in node.elts}
525 524 if isinstance(node, ast.Dict):
526 525 return dict(
527 526 zip(
528 527 [eval_node(k, context) for k in node.keys],
529 528 [eval_node(v, context) for v in node.values],
530 529 )
531 530 )
532 531 if isinstance(node, ast.Slice):
533 532 return slice(
534 533 eval_node(node.lower, context),
535 534 eval_node(node.upper, context),
536 535 eval_node(node.step, context),
537 536 )
538 537 if isinstance(node, ast.UnaryOp):
539 538 value = eval_node(node.operand, context)
540 539 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
541 540 if dunders:
542 541 if policy.can_operate(dunders, value):
543 542 return getattr(value, dunders[0])()
544 543 else:
545 544 raise GuardRejection(
546 545 f"Operation (`{dunders}`) for",
547 546 type(value),
548 547 f"not allowed in {context.evaluation} mode",
549 548 )
550 549 if isinstance(node, ast.Subscript):
551 550 value = eval_node(node.value, context)
552 551 slice_ = eval_node(node.slice, context)
553 552 if policy.can_get_item(value, slice_):
554 553 return value[slice_]
555 554 raise GuardRejection(
556 555 "Subscript access (`__getitem__`) for",
557 556 type(value), # not joined to avoid calling `repr`
558 557 f" not allowed in {context.evaluation} mode",
559 558 )
560 559 if isinstance(node, ast.Name):
561 if policy.allow_locals_access and node.id in context.locals:
562 return context.locals[node.id]
563 if policy.allow_globals_access and node.id in context.globals:
564 return context.globals[node.id]
565 if policy.allow_builtins_access and hasattr(builtins, node.id):
566 # note: do not use __builtins__, it is implementation detail of cPython
567 return getattr(builtins, node.id)
568 if not policy.allow_globals_access and not policy.allow_locals_access:
569 raise GuardRejection(
570 f"Namespace access not allowed in {context.evaluation} mode"
571 )
572 else:
573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
560 return _eval_node_name(node.id, policy, context)
574 561 if isinstance(node, ast.Attribute):
575 562 value = eval_node(node.value, context)
576 563 if policy.can_get_attr(value, node.attr):
577 564 return getattr(value, node.attr)
578 565 raise GuardRejection(
579 566 "Attribute access (`__getattr__`) for",
580 567 type(value), # not joined to avoid calling `repr`
581 568 f"not allowed in {context.evaluation} mode",
582 569 )
583 570 if isinstance(node, ast.IfExp):
584 571 test = eval_node(node.test, context)
585 572 if test:
586 573 return eval_node(node.body, context)
587 574 else:
588 575 return eval_node(node.orelse, context)
589 576 if isinstance(node, ast.Call):
590 577 func = eval_node(node.func, context)
591 578 if policy.can_call(func) and not node.keywords:
592 579 args = [eval_node(arg, context) for arg in node.args]
593 580 return func(*args)
594 581 if isclass(func):
595 582 # this code path gets entered when calling class e.g. `MyClass()`
596 583 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
597 584 # Should return `MyClass` if `__new__` is not overridden,
598 585 # otherwise whatever `__new__` return type is.
599 586 overridden_return_type = _eval_return_type(
600 587 func.__new__, policy, node, context
601 588 )
602 589 if overridden_return_type is not NOT_EVALUATED:
603 590 return overridden_return_type
604 591 return _create_duck_for_type(func)
605 592 else:
606 593 return_type = _eval_return_type(func, policy, node, context)
607 594 if return_type is not NOT_EVALUATED:
608 595 return return_type
609 596 raise GuardRejection(
610 597 "Call for",
611 598 func, # not joined to avoid calling `repr`
612 599 f"not allowed in {context.evaluation} mode",
613 600 )
614 601 raise ValueError("Unhandled node", ast.dump(node))
615 602
616 603
617 def _eval_return_type(func, policy, node, context):
604 def _eval_return_type(
605 func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext
606 ):
618 607 """Evaluate return type of a given callable function.
619 608
620 609 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
621 610 """
622 611 try:
623 612 sig = signature(func)
624 613 except ValueError:
625 614 sig = UNKNOWN_SIGNATURE
626 615 # if annotation was not stringized, or it was stringized
627 616 # but resolved by signature call we know the return type
628 617 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
618 stringized = isinstance(sig.return_annotation, str)
619 if not_empty:
620 return_type = (
621 _eval_node_name(sig.return_annotation, policy, context)
622 if stringized
623 else sig.return_annotation
624 )
631 625 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
626 if policy.can_call(return_type) and not node.keywords:
633 627 args = [eval_node(arg, context) for arg in node.args]
634 628 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
629 return return_type(*args)
630 return _create_duck_for_type(return_type)
637 631 return NOT_EVALUATED
638 632
639 633
634 def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext):
635 if policy.allow_locals_access and node_id in context.locals:
636 return context.locals[node_id]
637 if policy.allow_globals_access and node_id in context.globals:
638 return context.globals[node_id]
639 if policy.allow_builtins_access and hasattr(builtins, node_id):
640 # note: do not use __builtins__, it is implementation detail of cPython
641 return getattr(builtins, node_id)
642 if not policy.allow_globals_access and not policy.allow_locals_access:
643 raise GuardRejection(
644 f"Namespace access not allowed in {context.evaluation} mode"
645 )
646 else:
647 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
648
649
640 650 def _create_duck_for_type(duck_type):
641 651 """Create an imitation of an object of a given type (a duck).
642 652
643 653 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
644 654 """
645 655 duck = Duck()
646 656 try:
647 657 # this only works for heap types, not builtins
648 658 duck.__class__ = duck_type
649 659 return duck
650 660 except TypeError:
651 661 pass
652 662 return NOT_EVALUATED
653 663
654 664
655 665 SUPPORTED_EXTERNAL_GETITEM = {
656 666 ("pandas", "core", "indexing", "_iLocIndexer"),
657 667 ("pandas", "core", "indexing", "_LocIndexer"),
658 668 ("pandas", "DataFrame"),
659 669 ("pandas", "Series"),
660 670 ("numpy", "ndarray"),
661 671 ("numpy", "void"),
662 672 }
663 673
664 674
665 675 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
666 676 dict,
667 677 str, # type: ignore[arg-type]
668 678 bytes, # type: ignore[arg-type]
669 679 list,
670 680 tuple,
671 681 collections.defaultdict,
672 682 collections.deque,
673 683 collections.OrderedDict,
674 684 collections.ChainMap,
675 685 collections.UserDict,
676 686 collections.UserList,
677 687 collections.UserString, # type: ignore[arg-type]
678 688 _DummyNamedTuple,
679 689 _IdentitySubscript,
680 690 }
681 691
682 692
683 693 def _list_methods(cls, source=None):
684 694 """For use on immutable objects or with methods returning a copy"""
685 695 return [getattr(cls, k) for k in (source if source else dir(cls))]
686 696
687 697
688 698 dict_non_mutating_methods = ("copy", "keys", "values", "items")
689 699 list_non_mutating_methods = ("copy", "index", "count")
690 700 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
691 701
692 702
693 703 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
694 704
695 705 NUMERICS = {int, float, complex}
696 706
697 707 ALLOWED_CALLS = {
698 708 bytes,
699 709 *_list_methods(bytes),
700 710 dict,
701 711 *_list_methods(dict, dict_non_mutating_methods),
702 712 dict_keys.isdisjoint,
703 713 list,
704 714 *_list_methods(list, list_non_mutating_methods),
705 715 set,
706 716 *_list_methods(set, set_non_mutating_methods),
707 717 frozenset,
708 718 *_list_methods(frozenset),
709 719 range,
710 720 str,
711 721 *_list_methods(str),
712 722 tuple,
713 723 *_list_methods(tuple),
714 724 *NUMERICS,
715 725 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
716 726 collections.deque,
717 727 *_list_methods(collections.deque, list_non_mutating_methods),
718 728 collections.defaultdict,
719 729 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
720 730 collections.OrderedDict,
721 731 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
722 732 collections.UserDict,
723 733 *_list_methods(collections.UserDict, dict_non_mutating_methods),
724 734 collections.UserList,
725 735 *_list_methods(collections.UserList, list_non_mutating_methods),
726 736 collections.UserString,
727 737 *_list_methods(collections.UserString, dir(str)),
728 738 collections.Counter,
729 739 *_list_methods(collections.Counter, dict_non_mutating_methods),
730 740 collections.Counter.elements,
731 741 collections.Counter.most_common,
732 742 }
733 743
734 744 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
735 745 *BUILTIN_GETITEM,
736 746 set,
737 747 frozenset,
738 748 object,
739 749 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
740 750 *NUMERICS,
741 751 dict_keys,
742 752 MethodDescriptorType,
743 753 ModuleType,
744 754 }
745 755
746 756
747 757 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
748 758
749 759 EVALUATION_POLICIES = {
750 760 "minimal": EvaluationPolicy(
751 761 allow_builtins_access=True,
752 762 allow_locals_access=False,
753 763 allow_globals_access=False,
754 764 allow_item_access=False,
755 765 allow_attr_access=False,
756 766 allowed_calls=set(),
757 767 allow_any_calls=False,
758 768 allow_all_operations=False,
759 769 ),
760 770 "limited": SelectivePolicy(
761 771 allowed_getitem=BUILTIN_GETITEM,
762 772 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
763 773 allowed_getattr=BUILTIN_GETATTR,
764 774 allowed_getattr_external={
765 775 # pandas Series/Frame implements custom `__getattr__`
766 776 ("pandas", "DataFrame"),
767 777 ("pandas", "Series"),
768 778 },
769 779 allowed_operations=BUILTIN_OPERATIONS,
770 780 allow_builtins_access=True,
771 781 allow_locals_access=True,
772 782 allow_globals_access=True,
773 783 allowed_calls=ALLOWED_CALLS,
774 784 ),
775 785 "unsafe": EvaluationPolicy(
776 786 allow_builtins_access=True,
777 787 allow_locals_access=True,
778 788 allow_globals_access=True,
779 789 allow_attr_access=True,
780 790 allow_item_access=True,
781 791 allow_any_calls=True,
782 792 allow_all_operations=True,
783 793 ),
784 794 }
785 795
786 796
787 797 __all__ = [
788 798 "guarded_eval",
789 799 "eval_node",
790 800 "GuardRejection",
791 801 "EvaluationContext",
792 802 "_unbind_method",
793 803 ]
@@ -1,631 +1,641 b''
1 1 from contextlib import contextmanager
2 2 from typing import NamedTuple
3 3 from functools import partial
4 4 from IPython.core.guarded_eval import (
5 5 EvaluationContext,
6 6 GuardRejection,
7 7 guarded_eval,
8 8 _unbind_method,
9 9 )
10 10 from IPython.testing import decorators as dec
11 11 import pytest
12 12
13 13
14 14 def create_context(evaluation: str, **kwargs):
15 15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16 16
17 17
18 18 forbidden = partial(create_context, "forbidden")
19 19 minimal = partial(create_context, "minimal")
20 20 limited = partial(create_context, "limited")
21 21 unsafe = partial(create_context, "unsafe")
22 22 dangerous = partial(create_context, "dangerous")
23 23
24 24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26 26
27 27
28 28 @contextmanager
29 29 def module_not_installed(module: str):
30 30 import sys
31 31
32 32 try:
33 33 to_restore = sys.modules[module]
34 34 del sys.modules[module]
35 35 except KeyError:
36 36 to_restore = None
37 37 try:
38 38 yield
39 39 finally:
40 40 sys.modules[module] = to_restore
41 41
42 42
43 43 def test_external_not_installed():
44 44 """
45 45 Because attribute check requires checking if object is not of allowed
46 46 external type, this tests logic for absence of external module.
47 47 """
48 48
49 49 class Custom:
50 50 def __init__(self):
51 51 self.test = 1
52 52
53 53 def __getattr__(self, key):
54 54 return key
55 55
56 56 with module_not_installed("pandas"):
57 57 context = limited(x=Custom())
58 58 with pytest.raises(GuardRejection):
59 59 guarded_eval("x.test", context)
60 60
61 61
62 62 @dec.skip_without("pandas")
63 63 def test_external_changed_api(monkeypatch):
64 64 """Check that the execution rejects if external API changed paths"""
65 65 import pandas as pd
66 66
67 67 series = pd.Series([1], index=["a"])
68 68
69 69 with monkeypatch.context() as m:
70 70 m.delattr(pd, "Series")
71 71 context = limited(data=series)
72 72 with pytest.raises(GuardRejection):
73 73 guarded_eval("data.iloc[0]", context)
74 74
75 75
76 76 @dec.skip_without("pandas")
77 77 def test_pandas_series_iloc():
78 78 import pandas as pd
79 79
80 80 series = pd.Series([1], index=["a"])
81 81 context = limited(data=series)
82 82 assert guarded_eval("data.iloc[0]", context) == 1
83 83
84 84
85 85 def test_rejects_custom_properties():
86 86 class BadProperty:
87 87 @property
88 88 def iloc(self):
89 89 return [None]
90 90
91 91 series = BadProperty()
92 92 context = limited(data=series)
93 93
94 94 with pytest.raises(GuardRejection):
95 95 guarded_eval("data.iloc[0]", context)
96 96
97 97
98 98 @dec.skip_without("pandas")
99 99 def test_accepts_non_overriden_properties():
100 100 import pandas as pd
101 101
102 102 class GoodProperty(pd.Series):
103 103 pass
104 104
105 105 series = GoodProperty([1], index=["a"])
106 106 context = limited(data=series)
107 107
108 108 assert guarded_eval("data.iloc[0]", context) == 1
109 109
110 110
111 111 @dec.skip_without("pandas")
112 112 def test_pandas_series():
113 113 import pandas as pd
114 114
115 115 context = limited(data=pd.Series([1], index=["a"]))
116 116 assert guarded_eval('data["a"]', context) == 1
117 117 with pytest.raises(KeyError):
118 118 guarded_eval('data["c"]', context)
119 119
120 120
121 121 @dec.skip_without("pandas")
122 122 def test_pandas_bad_series():
123 123 import pandas as pd
124 124
125 125 class BadItemSeries(pd.Series):
126 126 def __getitem__(self, key):
127 127 return "CUSTOM_ITEM"
128 128
129 129 class BadAttrSeries(pd.Series):
130 130 def __getattr__(self, key):
131 131 return "CUSTOM_ATTR"
132 132
133 133 bad_series = BadItemSeries([1], index=["a"])
134 134 context = limited(data=bad_series)
135 135
136 136 with pytest.raises(GuardRejection):
137 137 guarded_eval('data["a"]', context)
138 138 with pytest.raises(GuardRejection):
139 139 guarded_eval('data["c"]', context)
140 140
141 141 # note: here result is a bit unexpected because
142 142 # pandas `__getattr__` calls `__getitem__`;
143 143 # FIXME - special case to handle it?
144 144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145 145
146 146 context = unsafe(data=bad_series)
147 147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148 148
149 149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 150 context = limited(data=bad_attr_series)
151 151 assert guarded_eval('data["a"]', context) == 1
152 152 with pytest.raises(GuardRejection):
153 153 guarded_eval("data.a", context)
154 154
155 155
156 156 @dec.skip_without("pandas")
157 157 def test_pandas_dataframe_loc():
158 158 import pandas as pd
159 159 from pandas.testing import assert_series_equal
160 160
161 161 data = pd.DataFrame([{"a": 1}])
162 162 context = limited(data=data)
163 163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164 164
165 165
166 166 def test_named_tuple():
167 167 class GoodNamedTuple(NamedTuple):
168 168 a: str
169 169 pass
170 170
171 171 class BadNamedTuple(NamedTuple):
172 172 a: str
173 173
174 174 def __getitem__(self, key):
175 175 return None
176 176
177 177 good = GoodNamedTuple(a="x")
178 178 bad = BadNamedTuple(a="x")
179 179
180 180 context = limited(data=good)
181 181 assert guarded_eval("data[0]", context) == "x"
182 182
183 183 context = limited(data=bad)
184 184 with pytest.raises(GuardRejection):
185 185 guarded_eval("data[0]", context)
186 186
187 187
188 188 def test_dict():
189 189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 190 assert guarded_eval('data["a"]', context) == 1
191 191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 192 assert guarded_eval('data["b"]["x"]', context) == 2
193 193 assert guarded_eval('data["x", "y"]', context) == 3
194 194
195 195 assert guarded_eval("data.keys", context)
196 196
197 197
198 198 def test_set():
199 199 context = limited(data={"a", "b"})
200 200 assert guarded_eval("data.difference", context)
201 201
202 202
203 203 def test_list():
204 204 context = limited(data=[1, 2, 3])
205 205 assert guarded_eval("data[1]", context) == 2
206 206 assert guarded_eval("data.copy", context)
207 207
208 208
209 209 def test_dict_literal():
210 210 context = limited()
211 211 assert guarded_eval("{}", context) == {}
212 212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213 213
214 214
215 215 def test_list_literal():
216 216 context = limited()
217 217 assert guarded_eval("[]", context) == []
218 218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219 219
220 220
221 221 def test_set_literal():
222 222 context = limited()
223 223 assert guarded_eval("set()", context) == set()
224 224 assert guarded_eval('{"a"}', context) == {"a"}
225 225
226 226
227 227 def test_evaluates_if_expression():
228 228 context = limited()
229 229 assert guarded_eval("2 if True else 3", context) == 2
230 230 assert guarded_eval("4 if False else 5", context) == 5
231 231
232 232
233 233 def test_object():
234 234 obj = object()
235 235 context = limited(obj=obj)
236 236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237 237
238 238
239 239 @pytest.mark.parametrize(
240 240 "code,expected",
241 241 [
242 242 ["int.numerator", int.numerator],
243 243 ["float.is_integer", float.is_integer],
244 244 ["complex.real", complex.real],
245 245 ],
246 246 )
247 247 def test_number_attributes(code, expected):
248 248 assert guarded_eval(code, limited()) == expected
249 249
250 250
251 251 def test_method_descriptor():
252 252 context = limited()
253 253 assert guarded_eval("list.copy.__name__", context) == "copy"
254 254
255 255
256 256 class HeapType:
257 257 pass
258 258
259 259
260 260 class CallCreatesHeapType:
261 261 def __call__(self) -> HeapType:
262 262 return HeapType()
263 263
264 264
265 265 class CallCreatesBuiltin:
266 266 def __call__(self) -> frozenset:
267 267 return frozenset()
268 268
269 269
270 270 class HasStaticMethod:
271 271 @staticmethod
272 272 def static_method() -> HeapType:
273 273 return HeapType()
274 274
275 275
276 276 class InitReturnsFrozenset:
277 277 def __new__(self) -> frozenset: # type:ignore[misc]
278 278 return frozenset()
279 279
280 280
281 class StringAnnotation:
282 def heap(self) -> "HeapType":
283 return HeapType()
284
285 def copy(self) -> "StringAnnotation":
286 return StringAnnotation()
287
288
281 289 @pytest.mark.parametrize(
282 290 "data,good,expected,equality",
283 291 [
284 292 [[1, 2, 3], "data.index(2)", 1, True],
285 293 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
294 [StringAnnotation(), "data.heap()", HeapType, False],
295 [StringAnnotation(), "data.copy()", StringAnnotation, False],
286 296 # test cases for `__call__`
287 297 [CallCreatesHeapType(), "data()", HeapType, False],
288 298 [CallCreatesBuiltin(), "data()", frozenset, False],
289 299 # Test cases for `__init__`
290 300 [HeapType, "data()", HeapType, False],
291 301 [InitReturnsFrozenset, "data()", frozenset, False],
292 302 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
303 # test cases for static methods
294 304 [HasStaticMethod, "data.static_method()", HeapType, False],
295 305 ],
296 306 )
297 307 def test_evaluates_calls(data, good, expected, equality):
298 context = limited(data=data)
308 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
299 309 value = guarded_eval(good, context)
300 310 if equality:
301 311 assert value == expected
302 312 else:
303 313 assert isinstance(value, expected)
304 314
305 315
306 316 @pytest.mark.parametrize(
307 317 "data,bad",
308 318 [
309 319 [[1, 2, 3], "data.append(4)"],
310 320 [{"a": 1}, "data.update()"],
311 321 ],
312 322 )
313 323 def test_rejects_calls_with_side_effects(data, bad):
314 324 context = limited(data=data)
315 325
316 326 with pytest.raises(GuardRejection):
317 327 guarded_eval(bad, context)
318 328
319 329
320 330 @pytest.mark.parametrize(
321 331 "code,expected",
322 332 [
323 333 ["(1\n+\n1)", 2],
324 334 ["list(range(10))[-1:]", [9]],
325 335 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
326 336 ],
327 337 )
328 338 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
329 339 def test_evaluates_complex_cases(code, expected, context):
330 340 assert guarded_eval(code, context()) == expected
331 341
332 342
333 343 @pytest.mark.parametrize(
334 344 "code,expected",
335 345 [
336 346 ["1", 1],
337 347 ["1.0", 1.0],
338 348 ["0xdeedbeef", 0xDEEDBEEF],
339 349 ["True", True],
340 350 ["None", None],
341 351 ["{}", {}],
342 352 ["[]", []],
343 353 ],
344 354 )
345 355 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
346 356 def test_evaluates_literals(code, expected, context):
347 357 assert guarded_eval(code, context()) == expected
348 358
349 359
350 360 @pytest.mark.parametrize(
351 361 "code,expected",
352 362 [
353 363 ["-5", -5],
354 364 ["+5", +5],
355 365 ["~5", -6],
356 366 ],
357 367 )
358 368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
359 369 def test_evaluates_unary_operations(code, expected, context):
360 370 assert guarded_eval(code, context()) == expected
361 371
362 372
363 373 @pytest.mark.parametrize(
364 374 "code,expected",
365 375 [
366 376 ["1 + 1", 2],
367 377 ["3 - 1", 2],
368 378 ["2 * 3", 6],
369 379 ["5 // 2", 2],
370 380 ["5 / 2", 2.5],
371 381 ["5**2", 25],
372 382 ["2 >> 1", 1],
373 383 ["2 << 1", 4],
374 384 ["1 | 2", 3],
375 385 ["1 & 1", 1],
376 386 ["1 & 2", 0],
377 387 ],
378 388 )
379 389 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
380 390 def test_evaluates_binary_operations(code, expected, context):
381 391 assert guarded_eval(code, context()) == expected
382 392
383 393
384 394 @pytest.mark.parametrize(
385 395 "code,expected",
386 396 [
387 397 ["2 > 1", True],
388 398 ["2 < 1", False],
389 399 ["2 <= 1", False],
390 400 ["2 <= 2", True],
391 401 ["1 >= 2", False],
392 402 ["2 >= 2", True],
393 403 ["2 == 2", True],
394 404 ["1 == 2", False],
395 405 ["1 != 2", True],
396 406 ["1 != 1", False],
397 407 ["1 < 4 < 3", False],
398 408 ["(1 < 4) < 3", True],
399 409 ["4 > 3 > 2 > 1", True],
400 410 ["4 > 3 > 2 > 9", False],
401 411 ["1 < 2 < 3 < 4", True],
402 412 ["9 < 2 < 3 < 4", False],
403 413 ["1 < 2 > 1 > 0 > -1 < 1", True],
404 414 ["1 in [1] in [[1]]", True],
405 415 ["1 in [1] in [[2]]", False],
406 416 ["1 in [1]", True],
407 417 ["0 in [1]", False],
408 418 ["1 not in [1]", False],
409 419 ["0 not in [1]", True],
410 420 ["True is True", True],
411 421 ["False is False", True],
412 422 ["True is False", False],
413 423 ["True is not True", False],
414 424 ["False is not True", True],
415 425 ],
416 426 )
417 427 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
418 428 def test_evaluates_comparisons(code, expected, context):
419 429 assert guarded_eval(code, context()) == expected
420 430
421 431
422 432 def test_guards_comparisons():
423 433 class GoodEq(int):
424 434 pass
425 435
426 436 class BadEq(int):
427 437 def __eq__(self, other):
428 438 assert False
429 439
430 440 context = limited(bad=BadEq(1), good=GoodEq(1))
431 441
432 442 with pytest.raises(GuardRejection):
433 443 guarded_eval("bad == 1", context)
434 444
435 445 with pytest.raises(GuardRejection):
436 446 guarded_eval("bad != 1", context)
437 447
438 448 with pytest.raises(GuardRejection):
439 449 guarded_eval("1 == bad", context)
440 450
441 451 with pytest.raises(GuardRejection):
442 452 guarded_eval("1 != bad", context)
443 453
444 454 assert guarded_eval("good == 1", context) is True
445 455 assert guarded_eval("good != 1", context) is False
446 456 assert guarded_eval("1 == good", context) is True
447 457 assert guarded_eval("1 != good", context) is False
448 458
449 459
450 460 def test_guards_unary_operations():
451 461 class GoodOp(int):
452 462 pass
453 463
454 464 class BadOpInv(int):
455 465 def __inv__(self, other):
456 466 assert False
457 467
458 468 class BadOpInverse(int):
459 469 def __inv__(self, other):
460 470 assert False
461 471
462 472 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
463 473
464 474 with pytest.raises(GuardRejection):
465 475 guarded_eval("~bad1", context)
466 476
467 477 with pytest.raises(GuardRejection):
468 478 guarded_eval("~bad2", context)
469 479
470 480
471 481 def test_guards_binary_operations():
472 482 class GoodOp(int):
473 483 pass
474 484
475 485 class BadOp(int):
476 486 def __add__(self, other):
477 487 assert False
478 488
479 489 context = limited(good=GoodOp(1), bad=BadOp(1))
480 490
481 491 with pytest.raises(GuardRejection):
482 492 guarded_eval("1 + bad", context)
483 493
484 494 with pytest.raises(GuardRejection):
485 495 guarded_eval("bad + 1", context)
486 496
487 497 assert guarded_eval("good + 1", context) == 2
488 498 assert guarded_eval("1 + good", context) == 2
489 499
490 500
491 501 def test_guards_attributes():
492 502 class GoodAttr(float):
493 503 pass
494 504
495 505 class BadAttr1(float):
496 506 def __getattr__(self, key):
497 507 assert False
498 508
499 509 class BadAttr2(float):
500 510 def __getattribute__(self, key):
501 511 assert False
502 512
503 513 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
504 514
505 515 with pytest.raises(GuardRejection):
506 516 guarded_eval("bad1.as_integer_ratio", context)
507 517
508 518 with pytest.raises(GuardRejection):
509 519 guarded_eval("bad2.as_integer_ratio", context)
510 520
511 521 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
512 522
513 523
514 524 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
515 525 def test_access_builtins(context):
516 526 assert guarded_eval("round", context()) == round
517 527
518 528
519 529 def test_access_builtins_fails():
520 530 context = limited()
521 531 with pytest.raises(NameError):
522 532 guarded_eval("this_is_not_builtin", context)
523 533
524 534
525 535 def test_rejects_forbidden():
526 536 context = forbidden()
527 537 with pytest.raises(GuardRejection):
528 538 guarded_eval("1", context)
529 539
530 540
531 541 def test_guards_locals_and_globals():
532 542 context = EvaluationContext(
533 543 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
534 544 )
535 545
536 546 with pytest.raises(GuardRejection):
537 547 guarded_eval("local_a", context)
538 548
539 549 with pytest.raises(GuardRejection):
540 550 guarded_eval("global_b", context)
541 551
542 552
543 553 def test_access_locals_and_globals():
544 554 context = EvaluationContext(
545 555 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
546 556 )
547 557 assert guarded_eval("local_a", context) == "a"
548 558 assert guarded_eval("global_b", context) == "b"
549 559
550 560
551 561 @pytest.mark.parametrize(
552 562 "code",
553 563 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
554 564 )
555 565 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
556 566 def test_rejects_side_effect_syntax(code, context):
557 567 with pytest.raises(SyntaxError):
558 568 guarded_eval(code, context)
559 569
560 570
561 571 def test_subscript():
562 572 context = EvaluationContext(
563 573 locals={}, globals={}, evaluation="limited", in_subscript=True
564 574 )
565 575 empty_slice = slice(None, None, None)
566 576 assert guarded_eval("", context) == tuple()
567 577 assert guarded_eval(":", context) == empty_slice
568 578 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
569 579 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
570 580
571 581
572 582 def test_unbind_method():
573 583 class X(list):
574 584 def index(self, k):
575 585 return "CUSTOM"
576 586
577 587 x = X()
578 588 assert _unbind_method(x.index) is X.index
579 589 assert _unbind_method([].index) is list.index
580 590 assert _unbind_method(list.index) is None
581 591
582 592
583 593 def test_assumption_instance_attr_do_not_matter():
584 594 """This is semi-specified in Python documentation.
585 595
586 596 However, since the specification says 'not guaranteed
587 597 to work' rather than 'is forbidden to work', future
588 598 versions could invalidate this assumptions. This test
589 599 is meant to catch such a change if it ever comes true.
590 600 """
591 601
592 602 class T:
593 603 def __getitem__(self, k):
594 604 return "a"
595 605
596 606 def __getattr__(self, k):
597 607 return "a"
598 608
599 609 def f(self):
600 610 return "b"
601 611
602 612 t = T()
603 613 t.__getitem__ = f
604 614 t.__getattr__ = f
605 615 assert t[1] == "a"
606 616 assert t[1] == "a"
607 617
608 618
609 619 def test_assumption_named_tuples_share_getitem():
610 620 """Check assumption on named tuples sharing __getitem__"""
611 621 from typing import NamedTuple
612 622
613 623 class A(NamedTuple):
614 624 pass
615 625
616 626 class B(NamedTuple):
617 627 pass
618 628
619 629 assert A.__getitem__ == B.__getitem__
620 630
621 631
622 632 @dec.skip_without("numpy")
623 633 def test_module_access():
624 634 import numpy
625 635
626 636 context = limited(numpy=numpy)
627 637 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
628 638
629 639 context = minimal(numpy=numpy)
630 640 with pytest.raises(GuardRejection):
631 641 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now