##// END OF EJS Templates
Infer type for `__init__` calls (including `__new__` mods)
krassowski -
Show More
@@ -1,760 +1,793 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 2 from typing import (
3 3 Any,
4 4 Callable,
5 5 Dict,
6 6 Set,
7 7 Sequence,
8 8 Tuple,
9 9 NamedTuple,
10 10 Type,
11 11 Literal,
12 12 Union,
13 13 TYPE_CHECKING,
14 14 )
15 15 import ast
16 16 import builtins
17 17 import collections
18 18 import operator
19 19 import sys
20 20 from functools import cached_property
21 21 from dataclasses import dataclass, field
22 22 from types import MethodDescriptorType, ModuleType
23 23
24 24 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 25 from IPython.utils.decorators import undoc
26 26
27 27
28 28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
29 29 from typing_extensions import Protocol
30 30 else:
31 31 # do not require on runtime
32 32 Protocol = object # requires Python >=3.8
33 33
34 34
35 35 @undoc
36 36 class HasGetItem(Protocol):
37 37 def __getitem__(self, key) -> None:
38 38 ...
39 39
40 40
41 41 @undoc
42 42 class InstancesHaveGetItem(Protocol):
43 43 def __call__(self, *args, **kwargs) -> HasGetItem:
44 44 ...
45 45
46 46
47 47 @undoc
48 48 class HasGetAttr(Protocol):
49 49 def __getattr__(self, key) -> None:
50 50 ...
51 51
52 52
53 53 @undoc
54 54 class DoesNotHaveGetAttr(Protocol):
55 55 pass
56 56
57 57
58 58 # By default `__getattr__` is not explicitly implemented on most objects
59 59 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
60 60
61 61
62 62 def _unbind_method(func: Callable) -> Union[Callable, None]:
63 63 """Get unbound method for given bound method.
64 64
65 65 Returns None if cannot get unbound method, or method is already unbound.
66 66 """
67 67 owner = getattr(func, "__self__", None)
68 68 owner_class = type(owner)
69 69 name = getattr(func, "__name__", None)
70 70 instance_dict_overrides = getattr(owner, "__dict__", None)
71 71 if (
72 72 owner is not None
73 73 and name
74 74 and (
75 75 not instance_dict_overrides
76 76 or (instance_dict_overrides and name not in instance_dict_overrides)
77 77 )
78 78 ):
79 79 return getattr(owner_class, name)
80 80 return None
81 81
82 82
83 83 @undoc
84 84 @dataclass
85 85 class EvaluationPolicy:
86 86 """Definition of evaluation policy."""
87 87
88 88 allow_locals_access: bool = False
89 89 allow_globals_access: bool = False
90 90 allow_item_access: bool = False
91 91 allow_attr_access: bool = False
92 92 allow_builtins_access: bool = False
93 93 allow_all_operations: bool = False
94 94 allow_any_calls: bool = False
95 95 allowed_calls: Set[Callable] = field(default_factory=set)
96 96
97 97 def can_get_item(self, value, item):
98 98 return self.allow_item_access
99 99
100 100 def can_get_attr(self, value, attr):
101 101 return self.allow_attr_access
102 102
103 103 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
104 104 if self.allow_all_operations:
105 105 return True
106 106
107 107 def can_call(self, func):
108 108 if self.allow_any_calls:
109 109 return True
110 110
111 111 if func in self.allowed_calls:
112 112 return True
113 113
114 114 owner_method = _unbind_method(func)
115 115
116 116 if owner_method and owner_method in self.allowed_calls:
117 117 return True
118 118
119 119
120 120 def _get_external(module_name: str, access_path: Sequence[str]):
121 121 """Get value from external module given a dotted access path.
122 122
123 123 Raises:
124 124 * `KeyError` if module is removed not found, and
125 125 * `AttributeError` if acess path does not match an exported object
126 126 """
127 127 member_type = sys.modules[module_name]
128 128 for attr in access_path:
129 129 member_type = getattr(member_type, attr)
130 130 return member_type
131 131
132 132
133 133 def _has_original_dunder_external(
134 134 value,
135 135 module_name: str,
136 136 access_path: Sequence[str],
137 137 method_name: str,
138 138 ):
139 139 if module_name not in sys.modules:
140 140 # LBYLB as it is faster
141 141 return False
142 142 try:
143 143 member_type = _get_external(module_name, access_path)
144 144 value_type = type(value)
145 145 if type(value) == member_type:
146 146 return True
147 147 if method_name == "__getattribute__":
148 148 # we have to short-circuit here due to an unresolved issue in
149 149 # `isinstance` implementation: https://bugs.python.org/issue32683
150 150 return False
151 151 if isinstance(value, member_type):
152 152 method = getattr(value_type, method_name, None)
153 153 member_method = getattr(member_type, method_name, None)
154 154 if member_method == method:
155 155 return True
156 156 except (AttributeError, KeyError):
157 157 return False
158 158
159 159
160 160 def _has_original_dunder(
161 161 value, allowed_types, allowed_methods, allowed_external, method_name
162 162 ):
163 163 # note: Python ignores `__getattr__`/`__getitem__` on instances,
164 164 # we only need to check at class level
165 165 value_type = type(value)
166 166
167 167 # strict type check passes β†’ no need to check method
168 168 if value_type in allowed_types:
169 169 return True
170 170
171 171 method = getattr(value_type, method_name, None)
172 172
173 173 if method is None:
174 174 return None
175 175
176 176 if method in allowed_methods:
177 177 return True
178 178
179 179 for module_name, *access_path in allowed_external:
180 180 if _has_original_dunder_external(value, module_name, access_path, method_name):
181 181 return True
182 182
183 183 return False
184 184
185 185
186 186 @undoc
187 187 @dataclass
188 188 class SelectivePolicy(EvaluationPolicy):
189 189 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
190 190 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
191 191
192 192 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
193 193 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
194 194
195 195 allowed_operations: Set = field(default_factory=set)
196 196 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
197 197
198 198 _operation_methods_cache: Dict[str, Set[Callable]] = field(
199 199 default_factory=dict, init=False
200 200 )
201 201
202 202 def can_get_attr(self, value, attr):
203 203 has_original_attribute = _has_original_dunder(
204 204 value,
205 205 allowed_types=self.allowed_getattr,
206 206 allowed_methods=self._getattribute_methods,
207 207 allowed_external=self.allowed_getattr_external,
208 208 method_name="__getattribute__",
209 209 )
210 210 has_original_attr = _has_original_dunder(
211 211 value,
212 212 allowed_types=self.allowed_getattr,
213 213 allowed_methods=self._getattr_methods,
214 214 allowed_external=self.allowed_getattr_external,
215 215 method_name="__getattr__",
216 216 )
217 217
218 218 accept = False
219 219
220 220 # Many objects do not have `__getattr__`, this is fine.
221 221 if has_original_attr is None and has_original_attribute:
222 222 accept = True
223 223 else:
224 224 # Accept objects without modifications to `__getattr__` and `__getattribute__`
225 225 accept = has_original_attr and has_original_attribute
226 226
227 227 if accept:
228 228 # We still need to check for overriden properties.
229 229
230 230 value_class = type(value)
231 231 if not hasattr(value_class, attr):
232 232 return True
233 233
234 234 class_attr_val = getattr(value_class, attr)
235 235 is_property = isinstance(class_attr_val, property)
236 236
237 237 if not is_property:
238 238 return True
239 239
240 240 # Properties in allowed types are ok (although we do not include any
241 241 # properties in our default allow list currently).
242 242 if type(value) in self.allowed_getattr:
243 243 return True # pragma: no cover
244 244
245 245 # Properties in subclasses of allowed types may be ok if not changed
246 246 for module_name, *access_path in self.allowed_getattr_external:
247 247 try:
248 248 external_class = _get_external(module_name, access_path)
249 249 external_class_attr_val = getattr(external_class, attr)
250 250 except (KeyError, AttributeError):
251 251 return False # pragma: no cover
252 252 return class_attr_val == external_class_attr_val
253 253
254 254 return False
255 255
256 256 def can_get_item(self, value, item):
257 257 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
258 258 return _has_original_dunder(
259 259 value,
260 260 allowed_types=self.allowed_getitem,
261 261 allowed_methods=self._getitem_methods,
262 262 allowed_external=self.allowed_getitem_external,
263 263 method_name="__getitem__",
264 264 )
265 265
266 266 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
267 267 objects = [a]
268 268 if b is not None:
269 269 objects.append(b)
270 270 return all(
271 271 [
272 272 _has_original_dunder(
273 273 obj,
274 274 allowed_types=self.allowed_operations,
275 275 allowed_methods=self._operator_dunder_methods(dunder),
276 276 allowed_external=self.allowed_operations_external,
277 277 method_name=dunder,
278 278 )
279 279 for dunder in dunders
280 280 for obj in objects
281 281 ]
282 282 )
283 283
284 284 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
285 285 if dunder not in self._operation_methods_cache:
286 286 self._operation_methods_cache[dunder] = self._safe_get_methods(
287 287 self.allowed_operations, dunder
288 288 )
289 289 return self._operation_methods_cache[dunder]
290 290
291 291 @cached_property
292 292 def _getitem_methods(self) -> Set[Callable]:
293 293 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
294 294
295 295 @cached_property
296 296 def _getattr_methods(self) -> Set[Callable]:
297 297 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
298 298
299 299 @cached_property
300 300 def _getattribute_methods(self) -> Set[Callable]:
301 301 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
302 302
303 303 def _safe_get_methods(self, classes, name) -> Set[Callable]:
304 304 return {
305 305 method
306 306 for class_ in classes
307 307 for method in [getattr(class_, name, None)]
308 308 if method
309 309 }
310 310
311 311
312 312 class _DummyNamedTuple(NamedTuple):
313 313 """Used internally to retrieve methods of named tuple instance."""
314 314
315 315
316 316 class EvaluationContext(NamedTuple):
317 317 #: Local namespace
318 318 locals: dict
319 319 #: Global namespace
320 320 globals: dict
321 321 #: Evaluation policy identifier
322 322 evaluation: Literal[
323 323 "forbidden", "minimal", "limited", "unsafe", "dangerous"
324 324 ] = "forbidden"
325 325 #: Whether the evalution of code takes place inside of a subscript.
326 326 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
327 327 in_subscript: bool = False
328 328
329 329
330 330 class _IdentitySubscript:
331 331 """Returns the key itself when item is requested via subscript."""
332 332
333 333 def __getitem__(self, key):
334 334 return key
335 335
336 336
337 337 IDENTITY_SUBSCRIPT = _IdentitySubscript()
338 338 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
339 339 UNKNOWN_SIGNATURE = Signature()
340 NOT_EVALUATED = object()
340 341
341 342
342 343 class GuardRejection(Exception):
343 344 """Exception raised when guard rejects evaluation attempt."""
344 345
345 346 pass
346 347
347 348
348 349 def guarded_eval(code: str, context: EvaluationContext):
349 350 """Evaluate provided code in the evaluation context.
350 351
351 352 If evaluation policy given by context is set to ``forbidden``
352 353 no evaluation will be performed; if it is set to ``dangerous``
353 354 standard :func:`eval` will be used; finally, for any other,
354 355 policy :func:`eval_node` will be called on parsed AST.
355 356 """
356 357 locals_ = context.locals
357 358
358 359 if context.evaluation == "forbidden":
359 360 raise GuardRejection("Forbidden mode")
360 361
361 362 # note: not using `ast.literal_eval` as it does not implement
362 363 # getitem at all, for example it fails on simple `[0][1]`
363 364
364 365 if context.in_subscript:
365 366 # syntatic sugar for ellipsis (:) is only available in susbcripts
366 367 # so we need to trick the ast parser into thinking that we have
367 368 # a subscript, but we need to be able to later recognise that we did
368 369 # it so we can ignore the actual __getitem__ operation
369 370 if not code:
370 371 return tuple()
371 372 locals_ = locals_.copy()
372 373 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
373 374 code = SUBSCRIPT_MARKER + "[" + code + "]"
374 375 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
375 376
376 377 if context.evaluation == "dangerous":
377 378 return eval(code, context.globals, context.locals)
378 379
379 380 expression = ast.parse(code, mode="eval")
380 381
381 382 return eval_node(expression, context)
382 383
383 384
384 385 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
385 386 ast.Add: ("__add__",),
386 387 ast.Sub: ("__sub__",),
387 388 ast.Mult: ("__mul__",),
388 389 ast.Div: ("__truediv__",),
389 390 ast.FloorDiv: ("__floordiv__",),
390 391 ast.Mod: ("__mod__",),
391 392 ast.Pow: ("__pow__",),
392 393 ast.LShift: ("__lshift__",),
393 394 ast.RShift: ("__rshift__",),
394 395 ast.BitOr: ("__or__",),
395 396 ast.BitXor: ("__xor__",),
396 397 ast.BitAnd: ("__and__",),
397 398 ast.MatMult: ("__matmul__",),
398 399 }
399 400
400 401 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
401 402 ast.Eq: ("__eq__",),
402 403 ast.NotEq: ("__ne__", "__eq__"),
403 404 ast.Lt: ("__lt__", "__gt__"),
404 405 ast.LtE: ("__le__", "__ge__"),
405 406 ast.Gt: ("__gt__", "__lt__"),
406 407 ast.GtE: ("__ge__", "__le__"),
407 408 ast.In: ("__contains__",),
408 409 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
409 410 }
410 411
411 412 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
412 413 ast.USub: ("__neg__",),
413 414 ast.UAdd: ("__pos__",),
414 415 # we have to check both __inv__ and __invert__!
415 416 ast.Invert: ("__invert__", "__inv__"),
416 417 ast.Not: ("__not__",),
417 418 }
418 419
419 420
420 421 class Duck:
421 422 """A dummy class used to create objects of other classes without calling their ``__init__``"""
422 423
423 424
424 425 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 426 dunder = None
426 427 for op, candidate_dunder in dunders.items():
427 428 if isinstance(node_op, op):
428 429 dunder = candidate_dunder
429 430 return dunder
430 431
431 432
432 433 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 434 """Evaluate AST node in provided context.
434 435
435 436 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436 437
437 438 Does not evaluate actions that always have side effects:
438 439
439 440 - class definitions (``class sth: ...``)
440 441 - function definitions (``def sth: ...``)
441 442 - variable assignments (``x = 1``)
442 443 - augmented assignments (``x += 1``)
443 444 - deletions (``del x``)
444 445
445 446 Does not evaluate operations which do not return values:
446 447
447 448 - assertions (``assert x``)
448 449 - pass (``pass``)
449 450 - imports (``import x``)
450 451 - control flow:
451 452
452 453 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 454 - loops (``for`` and ``while``)
454 455 - exception handling
455 456
456 457 The purpose of this function is to guard against unwanted side-effects;
457 458 it does not give guarantees on protection from malicious code execution.
458 459 """
459 460 policy = EVALUATION_POLICIES[context.evaluation]
460 461 if node is None:
461 462 return None
462 463 if isinstance(node, ast.Expression):
463 464 return eval_node(node.body, context)
464 465 if isinstance(node, ast.BinOp):
465 466 left = eval_node(node.left, context)
466 467 right = eval_node(node.right, context)
467 468 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 469 if dunders:
469 470 if policy.can_operate(dunders, left, right):
470 471 return getattr(left, dunders[0])(right)
471 472 else:
472 473 raise GuardRejection(
473 474 f"Operation (`{dunders}`) for",
474 475 type(left),
475 476 f"not allowed in {context.evaluation} mode",
476 477 )
477 478 if isinstance(node, ast.Compare):
478 479 left = eval_node(node.left, context)
479 480 all_true = True
480 481 negate = False
481 482 for op, right in zip(node.ops, node.comparators):
482 483 right = eval_node(right, context)
483 484 dunder = None
484 485 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 486 if not dunders:
486 487 if isinstance(op, ast.NotIn):
487 488 dunders = COMP_OP_DUNDERS[ast.In]
488 489 negate = True
489 490 if isinstance(op, ast.Is):
490 491 dunder = "is_"
491 492 if isinstance(op, ast.IsNot):
492 493 dunder = "is_"
493 494 negate = True
494 495 if not dunder and dunders:
495 496 dunder = dunders[0]
496 497 if dunder:
497 498 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 499 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 500 result = getattr(operator, dunder)(a, b)
500 501 if negate:
501 502 result = not result
502 503 if not result:
503 504 all_true = False
504 505 left = right
505 506 else:
506 507 raise GuardRejection(
507 508 f"Comparison (`{dunder}`) for",
508 509 type(left),
509 510 f"not allowed in {context.evaluation} mode",
510 511 )
511 512 else:
512 513 raise ValueError(
513 514 f"Comparison `{dunder}` not supported"
514 515 ) # pragma: no cover
515 516 return all_true
516 517 if isinstance(node, ast.Constant):
517 518 return node.value
518 519 if isinstance(node, ast.Tuple):
519 520 return tuple(eval_node(e, context) for e in node.elts)
520 521 if isinstance(node, ast.List):
521 522 return [eval_node(e, context) for e in node.elts]
522 523 if isinstance(node, ast.Set):
523 524 return {eval_node(e, context) for e in node.elts}
524 525 if isinstance(node, ast.Dict):
525 526 return dict(
526 527 zip(
527 528 [eval_node(k, context) for k in node.keys],
528 529 [eval_node(v, context) for v in node.values],
529 530 )
530 531 )
531 532 if isinstance(node, ast.Slice):
532 533 return slice(
533 534 eval_node(node.lower, context),
534 535 eval_node(node.upper, context),
535 536 eval_node(node.step, context),
536 537 )
537 538 if isinstance(node, ast.UnaryOp):
538 539 value = eval_node(node.operand, context)
539 540 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 541 if dunders:
541 542 if policy.can_operate(dunders, value):
542 543 return getattr(value, dunders[0])()
543 544 else:
544 545 raise GuardRejection(
545 546 f"Operation (`{dunders}`) for",
546 547 type(value),
547 548 f"not allowed in {context.evaluation} mode",
548 549 )
549 550 if isinstance(node, ast.Subscript):
550 551 value = eval_node(node.value, context)
551 552 slice_ = eval_node(node.slice, context)
552 553 if policy.can_get_item(value, slice_):
553 554 return value[slice_]
554 555 raise GuardRejection(
555 556 "Subscript access (`__getitem__`) for",
556 557 type(value), # not joined to avoid calling `repr`
557 558 f" not allowed in {context.evaluation} mode",
558 559 )
559 560 if isinstance(node, ast.Name):
560 561 if policy.allow_locals_access and node.id in context.locals:
561 562 return context.locals[node.id]
562 563 if policy.allow_globals_access and node.id in context.globals:
563 564 return context.globals[node.id]
564 565 if policy.allow_builtins_access and hasattr(builtins, node.id):
565 566 # note: do not use __builtins__, it is implementation detail of cPython
566 567 return getattr(builtins, node.id)
567 568 if not policy.allow_globals_access and not policy.allow_locals_access:
568 569 raise GuardRejection(
569 570 f"Namespace access not allowed in {context.evaluation} mode"
570 571 )
571 572 else:
572 573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
573 574 if isinstance(node, ast.Attribute):
574 575 value = eval_node(node.value, context)
575 576 if policy.can_get_attr(value, node.attr):
576 577 return getattr(value, node.attr)
577 578 raise GuardRejection(
578 579 "Attribute access (`__getattr__`) for",
579 580 type(value), # not joined to avoid calling `repr`
580 581 f"not allowed in {context.evaluation} mode",
581 582 )
582 583 if isinstance(node, ast.IfExp):
583 584 test = eval_node(node.test, context)
584 585 if test:
585 586 return eval_node(node.body, context)
586 587 else:
587 588 return eval_node(node.orelse, context)
588 589 if isinstance(node, ast.Call):
589 590 func = eval_node(node.func, context)
590 591 if policy.can_call(func) and not node.keywords:
591 592 args = [eval_node(arg, context) for arg in node.args]
592 593 return func(*args)
593 try:
594 sig = signature(func)
595 except ValueError:
596 sig = UNKNOWN_SIGNATURE
597 # if annotation was not stringized, or it was stringized
598 # but resolved by signature call we know the return type
599 not_empty = sig.return_annotation is not Signature.empty
600 not_stringized = not isinstance(sig.return_annotation, str)
601 if not_empty and not_stringized:
602 duck = Duck()
603 # if allow-listed builtin is on type annotation, instantiate it
604 if policy.can_call(sig.return_annotation) and not node.keywords:
605 args = [eval_node(arg, context) for arg in node.args]
606 return sig.return_annotation(*args)
607 try:
608 # if custom class is in type annotation, mock it;
609 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
611 return duck
612 except TypeError:
613 pass
594 if isclass(func):
595 # this code path gets entered when calling class e.g. `MyClass()`
596 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
597 # Should return `MyClass` if `__new__` is not overridden,
598 # otherwise whatever `__new__` return type is.
599 overridden_return_type = _eval_return_type(
600 func.__new__, policy, node, context
601 )
602 if overridden_return_type is not NOT_EVALUATED:
603 return overridden_return_type
604 return _create_duck_for_type(func)
605 else:
606 return_type = _eval_return_type(func, policy, node, context)
607 if return_type is not NOT_EVALUATED:
608 return return_type
614 609 raise GuardRejection(
615 610 "Call for",
616 611 func, # not joined to avoid calling `repr`
617 612 f"not allowed in {context.evaluation} mode",
618 613 )
619 614 raise ValueError("Unhandled node", ast.dump(node))
620 615
621 616
617 def _eval_return_type(func, policy, node, context):
618 """Evaluate return type of a given callable function.
619
620 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
621 """
622 try:
623 sig = signature(func)
624 except ValueError:
625 sig = UNKNOWN_SIGNATURE
626 # if annotation was not stringized, or it was stringized
627 # but resolved by signature call we know the return type
628 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
631 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
633 args = [eval_node(arg, context) for arg in node.args]
634 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
637 return NOT_EVALUATED
638
639
640 def _create_duck_for_type(duck_type):
641 """Create an imitation of an object of a given type (a duck).
642
643 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
644 """
645 duck = Duck()
646 try:
647 # this only works for heap types, not builtins
648 duck.__class__ = duck_type
649 return duck
650 except TypeError:
651 pass
652 return NOT_EVALUATED
653
654
622 655 SUPPORTED_EXTERNAL_GETITEM = {
623 656 ("pandas", "core", "indexing", "_iLocIndexer"),
624 657 ("pandas", "core", "indexing", "_LocIndexer"),
625 658 ("pandas", "DataFrame"),
626 659 ("pandas", "Series"),
627 660 ("numpy", "ndarray"),
628 661 ("numpy", "void"),
629 662 }
630 663
631 664
632 665 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
633 666 dict,
634 667 str, # type: ignore[arg-type]
635 668 bytes, # type: ignore[arg-type]
636 669 list,
637 670 tuple,
638 671 collections.defaultdict,
639 672 collections.deque,
640 673 collections.OrderedDict,
641 674 collections.ChainMap,
642 675 collections.UserDict,
643 676 collections.UserList,
644 677 collections.UserString, # type: ignore[arg-type]
645 678 _DummyNamedTuple,
646 679 _IdentitySubscript,
647 680 }
648 681
649 682
650 683 def _list_methods(cls, source=None):
651 684 """For use on immutable objects or with methods returning a copy"""
652 685 return [getattr(cls, k) for k in (source if source else dir(cls))]
653 686
654 687
655 688 dict_non_mutating_methods = ("copy", "keys", "values", "items")
656 689 list_non_mutating_methods = ("copy", "index", "count")
657 690 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
658 691
659 692
660 693 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
661 694
662 695 NUMERICS = {int, float, complex}
663 696
664 697 ALLOWED_CALLS = {
665 698 bytes,
666 699 *_list_methods(bytes),
667 700 dict,
668 701 *_list_methods(dict, dict_non_mutating_methods),
669 702 dict_keys.isdisjoint,
670 703 list,
671 704 *_list_methods(list, list_non_mutating_methods),
672 705 set,
673 706 *_list_methods(set, set_non_mutating_methods),
674 707 frozenset,
675 708 *_list_methods(frozenset),
676 709 range,
677 710 str,
678 711 *_list_methods(str),
679 712 tuple,
680 713 *_list_methods(tuple),
681 714 *NUMERICS,
682 715 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
683 716 collections.deque,
684 717 *_list_methods(collections.deque, list_non_mutating_methods),
685 718 collections.defaultdict,
686 719 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
687 720 collections.OrderedDict,
688 721 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
689 722 collections.UserDict,
690 723 *_list_methods(collections.UserDict, dict_non_mutating_methods),
691 724 collections.UserList,
692 725 *_list_methods(collections.UserList, list_non_mutating_methods),
693 726 collections.UserString,
694 727 *_list_methods(collections.UserString, dir(str)),
695 728 collections.Counter,
696 729 *_list_methods(collections.Counter, dict_non_mutating_methods),
697 730 collections.Counter.elements,
698 731 collections.Counter.most_common,
699 732 }
700 733
701 734 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
702 735 *BUILTIN_GETITEM,
703 736 set,
704 737 frozenset,
705 738 object,
706 739 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
707 740 *NUMERICS,
708 741 dict_keys,
709 742 MethodDescriptorType,
710 743 ModuleType,
711 744 }
712 745
713 746
714 747 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
715 748
716 749 EVALUATION_POLICIES = {
717 750 "minimal": EvaluationPolicy(
718 751 allow_builtins_access=True,
719 752 allow_locals_access=False,
720 753 allow_globals_access=False,
721 754 allow_item_access=False,
722 755 allow_attr_access=False,
723 756 allowed_calls=set(),
724 757 allow_any_calls=False,
725 758 allow_all_operations=False,
726 759 ),
727 760 "limited": SelectivePolicy(
728 761 allowed_getitem=BUILTIN_GETITEM,
729 762 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
730 763 allowed_getattr=BUILTIN_GETATTR,
731 764 allowed_getattr_external={
732 765 # pandas Series/Frame implements custom `__getattr__`
733 766 ("pandas", "DataFrame"),
734 767 ("pandas", "Series"),
735 768 },
736 769 allowed_operations=BUILTIN_OPERATIONS,
737 770 allow_builtins_access=True,
738 771 allow_locals_access=True,
739 772 allow_globals_access=True,
740 773 allowed_calls=ALLOWED_CALLS,
741 774 ),
742 775 "unsafe": EvaluationPolicy(
743 776 allow_builtins_access=True,
744 777 allow_locals_access=True,
745 778 allow_globals_access=True,
746 779 allow_attr_access=True,
747 780 allow_item_access=True,
748 781 allow_any_calls=True,
749 782 allow_all_operations=True,
750 783 ),
751 784 }
752 785
753 786
754 787 __all__ = [
755 788 "guarded_eval",
756 789 "eval_node",
757 790 "GuardRejection",
758 791 "EvaluationContext",
759 792 "_unbind_method",
760 793 ]
@@ -1,602 +1,631 b''
1 1 from contextlib import contextmanager
2 2 from typing import NamedTuple
3 3 from functools import partial
4 4 from IPython.core.guarded_eval import (
5 5 EvaluationContext,
6 6 GuardRejection,
7 7 guarded_eval,
8 8 _unbind_method,
9 9 )
10 10 from IPython.testing import decorators as dec
11 11 import pytest
12 12
13 13
14 14 def create_context(evaluation: str, **kwargs):
15 15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16 16
17 17
18 18 forbidden = partial(create_context, "forbidden")
19 19 minimal = partial(create_context, "minimal")
20 20 limited = partial(create_context, "limited")
21 21 unsafe = partial(create_context, "unsafe")
22 22 dangerous = partial(create_context, "dangerous")
23 23
24 24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26 26
27 27
28 28 @contextmanager
29 29 def module_not_installed(module: str):
30 30 import sys
31 31
32 32 try:
33 33 to_restore = sys.modules[module]
34 34 del sys.modules[module]
35 35 except KeyError:
36 36 to_restore = None
37 37 try:
38 38 yield
39 39 finally:
40 40 sys.modules[module] = to_restore
41 41
42 42
43 43 def test_external_not_installed():
44 44 """
45 45 Because attribute check requires checking if object is not of allowed
46 46 external type, this tests logic for absence of external module.
47 47 """
48 48
49 49 class Custom:
50 50 def __init__(self):
51 51 self.test = 1
52 52
53 53 def __getattr__(self, key):
54 54 return key
55 55
56 56 with module_not_installed("pandas"):
57 57 context = limited(x=Custom())
58 58 with pytest.raises(GuardRejection):
59 59 guarded_eval("x.test", context)
60 60
61 61
62 62 @dec.skip_without("pandas")
63 63 def test_external_changed_api(monkeypatch):
64 64 """Check that the execution rejects if external API changed paths"""
65 65 import pandas as pd
66 66
67 67 series = pd.Series([1], index=["a"])
68 68
69 69 with monkeypatch.context() as m:
70 70 m.delattr(pd, "Series")
71 71 context = limited(data=series)
72 72 with pytest.raises(GuardRejection):
73 73 guarded_eval("data.iloc[0]", context)
74 74
75 75
76 76 @dec.skip_without("pandas")
77 77 def test_pandas_series_iloc():
78 78 import pandas as pd
79 79
80 80 series = pd.Series([1], index=["a"])
81 81 context = limited(data=series)
82 82 assert guarded_eval("data.iloc[0]", context) == 1
83 83
84 84
85 85 def test_rejects_custom_properties():
86 86 class BadProperty:
87 87 @property
88 88 def iloc(self):
89 89 return [None]
90 90
91 91 series = BadProperty()
92 92 context = limited(data=series)
93 93
94 94 with pytest.raises(GuardRejection):
95 95 guarded_eval("data.iloc[0]", context)
96 96
97 97
98 98 @dec.skip_without("pandas")
99 99 def test_accepts_non_overriden_properties():
100 100 import pandas as pd
101 101
102 102 class GoodProperty(pd.Series):
103 103 pass
104 104
105 105 series = GoodProperty([1], index=["a"])
106 106 context = limited(data=series)
107 107
108 108 assert guarded_eval("data.iloc[0]", context) == 1
109 109
110 110
111 111 @dec.skip_without("pandas")
112 112 def test_pandas_series():
113 113 import pandas as pd
114 114
115 115 context = limited(data=pd.Series([1], index=["a"]))
116 116 assert guarded_eval('data["a"]', context) == 1
117 117 with pytest.raises(KeyError):
118 118 guarded_eval('data["c"]', context)
119 119
120 120
121 121 @dec.skip_without("pandas")
122 122 def test_pandas_bad_series():
123 123 import pandas as pd
124 124
125 125 class BadItemSeries(pd.Series):
126 126 def __getitem__(self, key):
127 127 return "CUSTOM_ITEM"
128 128
129 129 class BadAttrSeries(pd.Series):
130 130 def __getattr__(self, key):
131 131 return "CUSTOM_ATTR"
132 132
133 133 bad_series = BadItemSeries([1], index=["a"])
134 134 context = limited(data=bad_series)
135 135
136 136 with pytest.raises(GuardRejection):
137 137 guarded_eval('data["a"]', context)
138 138 with pytest.raises(GuardRejection):
139 139 guarded_eval('data["c"]', context)
140 140
141 141 # note: here result is a bit unexpected because
142 142 # pandas `__getattr__` calls `__getitem__`;
143 143 # FIXME - special case to handle it?
144 144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145 145
146 146 context = unsafe(data=bad_series)
147 147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148 148
149 149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 150 context = limited(data=bad_attr_series)
151 151 assert guarded_eval('data["a"]', context) == 1
152 152 with pytest.raises(GuardRejection):
153 153 guarded_eval("data.a", context)
154 154
155 155
156 156 @dec.skip_without("pandas")
157 157 def test_pandas_dataframe_loc():
158 158 import pandas as pd
159 159 from pandas.testing import assert_series_equal
160 160
161 161 data = pd.DataFrame([{"a": 1}])
162 162 context = limited(data=data)
163 163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164 164
165 165
166 166 def test_named_tuple():
167 167 class GoodNamedTuple(NamedTuple):
168 168 a: str
169 169 pass
170 170
171 171 class BadNamedTuple(NamedTuple):
172 172 a: str
173 173
174 174 def __getitem__(self, key):
175 175 return None
176 176
177 177 good = GoodNamedTuple(a="x")
178 178 bad = BadNamedTuple(a="x")
179 179
180 180 context = limited(data=good)
181 181 assert guarded_eval("data[0]", context) == "x"
182 182
183 183 context = limited(data=bad)
184 184 with pytest.raises(GuardRejection):
185 185 guarded_eval("data[0]", context)
186 186
187 187
188 188 def test_dict():
189 189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 190 assert guarded_eval('data["a"]', context) == 1
191 191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 192 assert guarded_eval('data["b"]["x"]', context) == 2
193 193 assert guarded_eval('data["x", "y"]', context) == 3
194 194
195 195 assert guarded_eval("data.keys", context)
196 196
197 197
198 198 def test_set():
199 199 context = limited(data={"a", "b"})
200 200 assert guarded_eval("data.difference", context)
201 201
202 202
203 203 def test_list():
204 204 context = limited(data=[1, 2, 3])
205 205 assert guarded_eval("data[1]", context) == 2
206 206 assert guarded_eval("data.copy", context)
207 207
208 208
209 209 def test_dict_literal():
210 210 context = limited()
211 211 assert guarded_eval("{}", context) == {}
212 212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213 213
214 214
215 215 def test_list_literal():
216 216 context = limited()
217 217 assert guarded_eval("[]", context) == []
218 218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219 219
220 220
221 221 def test_set_literal():
222 222 context = limited()
223 223 assert guarded_eval("set()", context) == set()
224 224 assert guarded_eval('{"a"}', context) == {"a"}
225 225
226 226
227 227 def test_evaluates_if_expression():
228 228 context = limited()
229 229 assert guarded_eval("2 if True else 3", context) == 2
230 230 assert guarded_eval("4 if False else 5", context) == 5
231 231
232 232
233 233 def test_object():
234 234 obj = object()
235 235 context = limited(obj=obj)
236 236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237 237
238 238
239 239 @pytest.mark.parametrize(
240 240 "code,expected",
241 241 [
242 242 ["int.numerator", int.numerator],
243 243 ["float.is_integer", float.is_integer],
244 244 ["complex.real", complex.real],
245 245 ],
246 246 )
247 247 def test_number_attributes(code, expected):
248 248 assert guarded_eval(code, limited()) == expected
249 249
250 250
251 251 def test_method_descriptor():
252 252 context = limited()
253 253 assert guarded_eval("list.copy.__name__", context) == "copy"
254 254
255 255
256 256 class HeapType:
257 257 pass
258 258
259 259
260 260 class CallCreatesHeapType:
261 261 def __call__(self) -> HeapType:
262 262 return HeapType()
263 263
264 264
265 265 class CallCreatesBuiltin:
266 266 def __call__(self) -> frozenset:
267 267 return frozenset()
268 268
269 269
270 class HasStaticMethod:
271 @staticmethod
272 def static_method() -> HeapType:
273 return HeapType()
274
275
276 class InitReturnsFrozenset:
277 def __new__(self) -> frozenset: # type:ignore[misc]
278 return frozenset()
279
280
270 281 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
282 "data,good,expected,equality",
272 283 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
284 [[1, 2, 3], "data.index(2)", 1, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
286 # test cases for `__call__`
287 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
289 # Test cases for `__init__`
290 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
277 295 ],
278 296 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
297 def test_evaluates_calls(data, good, expected, equality):
280 298 context = limited(data=data)
281 299 value = guarded_eval(good, context)
282 300 if equality:
283 301 assert value == expected
284 302 else:
285 303 assert isinstance(value, expected)
286 304
305
306 @pytest.mark.parametrize(
307 "data,bad",
308 [
309 [[1, 2, 3], "data.append(4)"],
310 [{"a": 1}, "data.update()"],
311 ],
312 )
313 def test_rejects_calls_with_side_effects(data, bad):
314 context = limited(data=data)
315
287 316 with pytest.raises(GuardRejection):
288 317 guarded_eval(bad, context)
289 318
290 319
291 320 @pytest.mark.parametrize(
292 321 "code,expected",
293 322 [
294 323 ["(1\n+\n1)", 2],
295 324 ["list(range(10))[-1:]", [9]],
296 325 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
297 326 ],
298 327 )
299 328 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
300 329 def test_evaluates_complex_cases(code, expected, context):
301 330 assert guarded_eval(code, context()) == expected
302 331
303 332
304 333 @pytest.mark.parametrize(
305 334 "code,expected",
306 335 [
307 336 ["1", 1],
308 337 ["1.0", 1.0],
309 338 ["0xdeedbeef", 0xDEEDBEEF],
310 339 ["True", True],
311 340 ["None", None],
312 341 ["{}", {}],
313 342 ["[]", []],
314 343 ],
315 344 )
316 345 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
317 346 def test_evaluates_literals(code, expected, context):
318 347 assert guarded_eval(code, context()) == expected
319 348
320 349
321 350 @pytest.mark.parametrize(
322 351 "code,expected",
323 352 [
324 353 ["-5", -5],
325 354 ["+5", +5],
326 355 ["~5", -6],
327 356 ],
328 357 )
329 358 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
330 359 def test_evaluates_unary_operations(code, expected, context):
331 360 assert guarded_eval(code, context()) == expected
332 361
333 362
334 363 @pytest.mark.parametrize(
335 364 "code,expected",
336 365 [
337 366 ["1 + 1", 2],
338 367 ["3 - 1", 2],
339 368 ["2 * 3", 6],
340 369 ["5 // 2", 2],
341 370 ["5 / 2", 2.5],
342 371 ["5**2", 25],
343 372 ["2 >> 1", 1],
344 373 ["2 << 1", 4],
345 374 ["1 | 2", 3],
346 375 ["1 & 1", 1],
347 376 ["1 & 2", 0],
348 377 ],
349 378 )
350 379 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
351 380 def test_evaluates_binary_operations(code, expected, context):
352 381 assert guarded_eval(code, context()) == expected
353 382
354 383
355 384 @pytest.mark.parametrize(
356 385 "code,expected",
357 386 [
358 387 ["2 > 1", True],
359 388 ["2 < 1", False],
360 389 ["2 <= 1", False],
361 390 ["2 <= 2", True],
362 391 ["1 >= 2", False],
363 392 ["2 >= 2", True],
364 393 ["2 == 2", True],
365 394 ["1 == 2", False],
366 395 ["1 != 2", True],
367 396 ["1 != 1", False],
368 397 ["1 < 4 < 3", False],
369 398 ["(1 < 4) < 3", True],
370 399 ["4 > 3 > 2 > 1", True],
371 400 ["4 > 3 > 2 > 9", False],
372 401 ["1 < 2 < 3 < 4", True],
373 402 ["9 < 2 < 3 < 4", False],
374 403 ["1 < 2 > 1 > 0 > -1 < 1", True],
375 404 ["1 in [1] in [[1]]", True],
376 405 ["1 in [1] in [[2]]", False],
377 406 ["1 in [1]", True],
378 407 ["0 in [1]", False],
379 408 ["1 not in [1]", False],
380 409 ["0 not in [1]", True],
381 410 ["True is True", True],
382 411 ["False is False", True],
383 412 ["True is False", False],
384 413 ["True is not True", False],
385 414 ["False is not True", True],
386 415 ],
387 416 )
388 417 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
389 418 def test_evaluates_comparisons(code, expected, context):
390 419 assert guarded_eval(code, context()) == expected
391 420
392 421
393 422 def test_guards_comparisons():
394 423 class GoodEq(int):
395 424 pass
396 425
397 426 class BadEq(int):
398 427 def __eq__(self, other):
399 428 assert False
400 429
401 430 context = limited(bad=BadEq(1), good=GoodEq(1))
402 431
403 432 with pytest.raises(GuardRejection):
404 433 guarded_eval("bad == 1", context)
405 434
406 435 with pytest.raises(GuardRejection):
407 436 guarded_eval("bad != 1", context)
408 437
409 438 with pytest.raises(GuardRejection):
410 439 guarded_eval("1 == bad", context)
411 440
412 441 with pytest.raises(GuardRejection):
413 442 guarded_eval("1 != bad", context)
414 443
415 444 assert guarded_eval("good == 1", context) is True
416 445 assert guarded_eval("good != 1", context) is False
417 446 assert guarded_eval("1 == good", context) is True
418 447 assert guarded_eval("1 != good", context) is False
419 448
420 449
421 450 def test_guards_unary_operations():
422 451 class GoodOp(int):
423 452 pass
424 453
425 454 class BadOpInv(int):
426 455 def __inv__(self, other):
427 456 assert False
428 457
429 458 class BadOpInverse(int):
430 459 def __inv__(self, other):
431 460 assert False
432 461
433 462 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
434 463
435 464 with pytest.raises(GuardRejection):
436 465 guarded_eval("~bad1", context)
437 466
438 467 with pytest.raises(GuardRejection):
439 468 guarded_eval("~bad2", context)
440 469
441 470
442 471 def test_guards_binary_operations():
443 472 class GoodOp(int):
444 473 pass
445 474
446 475 class BadOp(int):
447 476 def __add__(self, other):
448 477 assert False
449 478
450 479 context = limited(good=GoodOp(1), bad=BadOp(1))
451 480
452 481 with pytest.raises(GuardRejection):
453 482 guarded_eval("1 + bad", context)
454 483
455 484 with pytest.raises(GuardRejection):
456 485 guarded_eval("bad + 1", context)
457 486
458 487 assert guarded_eval("good + 1", context) == 2
459 488 assert guarded_eval("1 + good", context) == 2
460 489
461 490
462 491 def test_guards_attributes():
463 492 class GoodAttr(float):
464 493 pass
465 494
466 495 class BadAttr1(float):
467 496 def __getattr__(self, key):
468 497 assert False
469 498
470 499 class BadAttr2(float):
471 500 def __getattribute__(self, key):
472 501 assert False
473 502
474 503 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
475 504
476 505 with pytest.raises(GuardRejection):
477 506 guarded_eval("bad1.as_integer_ratio", context)
478 507
479 508 with pytest.raises(GuardRejection):
480 509 guarded_eval("bad2.as_integer_ratio", context)
481 510
482 511 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
483 512
484 513
485 514 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
486 515 def test_access_builtins(context):
487 516 assert guarded_eval("round", context()) == round
488 517
489 518
490 519 def test_access_builtins_fails():
491 520 context = limited()
492 521 with pytest.raises(NameError):
493 522 guarded_eval("this_is_not_builtin", context)
494 523
495 524
496 525 def test_rejects_forbidden():
497 526 context = forbidden()
498 527 with pytest.raises(GuardRejection):
499 528 guarded_eval("1", context)
500 529
501 530
502 531 def test_guards_locals_and_globals():
503 532 context = EvaluationContext(
504 533 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
505 534 )
506 535
507 536 with pytest.raises(GuardRejection):
508 537 guarded_eval("local_a", context)
509 538
510 539 with pytest.raises(GuardRejection):
511 540 guarded_eval("global_b", context)
512 541
513 542
514 543 def test_access_locals_and_globals():
515 544 context = EvaluationContext(
516 545 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
517 546 )
518 547 assert guarded_eval("local_a", context) == "a"
519 548 assert guarded_eval("global_b", context) == "b"
520 549
521 550
522 551 @pytest.mark.parametrize(
523 552 "code",
524 553 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
525 554 )
526 555 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
527 556 def test_rejects_side_effect_syntax(code, context):
528 557 with pytest.raises(SyntaxError):
529 558 guarded_eval(code, context)
530 559
531 560
532 561 def test_subscript():
533 562 context = EvaluationContext(
534 563 locals={}, globals={}, evaluation="limited", in_subscript=True
535 564 )
536 565 empty_slice = slice(None, None, None)
537 566 assert guarded_eval("", context) == tuple()
538 567 assert guarded_eval(":", context) == empty_slice
539 568 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
540 569 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
541 570
542 571
543 572 def test_unbind_method():
544 573 class X(list):
545 574 def index(self, k):
546 575 return "CUSTOM"
547 576
548 577 x = X()
549 578 assert _unbind_method(x.index) is X.index
550 579 assert _unbind_method([].index) is list.index
551 580 assert _unbind_method(list.index) is None
552 581
553 582
554 583 def test_assumption_instance_attr_do_not_matter():
555 584 """This is semi-specified in Python documentation.
556 585
557 586 However, since the specification says 'not guaranteed
558 587 to work' rather than 'is forbidden to work', future
559 588 versions could invalidate this assumptions. This test
560 589 is meant to catch such a change if it ever comes true.
561 590 """
562 591
563 592 class T:
564 593 def __getitem__(self, k):
565 594 return "a"
566 595
567 596 def __getattr__(self, k):
568 597 return "a"
569 598
570 599 def f(self):
571 600 return "b"
572 601
573 602 t = T()
574 603 t.__getitem__ = f
575 604 t.__getattr__ = f
576 605 assert t[1] == "a"
577 606 assert t[1] == "a"
578 607
579 608
580 609 def test_assumption_named_tuples_share_getitem():
581 610 """Check assumption on named tuples sharing __getitem__"""
582 611 from typing import NamedTuple
583 612
584 613 class A(NamedTuple):
585 614 pass
586 615
587 616 class B(NamedTuple):
588 617 pass
589 618
590 619 assert A.__getitem__ == B.__getitem__
591 620
592 621
593 622 @dec.skip_without("numpy")
594 623 def test_module_access():
595 624 import numpy
596 625
597 626 context = limited(numpy=numpy)
598 627 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
599 628
600 629 context = minimal(numpy=numpy)
601 630 with pytest.raises(GuardRejection):
602 631 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now