##// END OF EJS Templates
Guard against custom properties
krassowski -
Show More
@@ -1,695 +1,738 b''
1 1 from typing import (
2 2 Any,
3 3 Callable,
4 4 Dict,
5 5 Set,
6 Sequence,
6 7 Tuple,
7 8 NamedTuple,
8 9 Type,
9 10 Literal,
10 11 Union,
11 12 TYPE_CHECKING,
12 13 )
13 14 import ast
14 15 import builtins
15 16 import collections
16 17 import operator
17 18 import sys
18 19 from functools import cached_property
19 20 from dataclasses import dataclass, field
20 21
21 22 from IPython.utils.docs import GENERATING_DOCUMENTATION
22 23 from IPython.utils.decorators import undoc
23 24
24 25
25 26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
26 27 from typing_extensions import Protocol
27 28 else:
28 29 # do not require at runtime
29 30 Protocol = object # requires Python >=3.8
30 31
31 32
32 33 @undoc
33 34 class HasGetItem(Protocol):
34 35 def __getitem__(self, key) -> None:
35 36 ...
36 37
37 38
38 39 @undoc
39 40 class InstancesHaveGetItem(Protocol):
40 41 def __call__(self, *args, **kwargs) -> HasGetItem:
41 42 ...
42 43
43 44
44 45 @undoc
45 46 class HasGetAttr(Protocol):
46 47 def __getattr__(self, key) -> None:
47 48 ...
48 49
49 50
50 51 @undoc
51 52 class DoesNotHaveGetAttr(Protocol):
52 53 pass
53 54
54 55
55 56 # By default `__getattr__` is not explicitly implemented on most objects
56 57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
57 58
58 59
59 60 def _unbind_method(func: Callable) -> Union[Callable, None]:
60 61 """Get unbound method for given bound method.
61 62
62 63 Returns None if cannot get unbound method."""
63 64 owner = getattr(func, "__self__", None)
64 65 owner_class = type(owner)
65 66 name = getattr(func, "__name__", None)
66 67 instance_dict_overrides = getattr(owner, "__dict__", None)
67 68 if (
68 69 owner is not None
69 70 and name
70 71 and (
71 72 not instance_dict_overrides
72 73 or (instance_dict_overrides and name not in instance_dict_overrides)
73 74 )
74 75 ):
75 76 return getattr(owner_class, name)
76 77 return None
77 78
78 79
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``can_*`` predicate answers whether the evaluator may perform the
    corresponding action. In this base class the answers come only from the
    boolean switches and the ``allowed_calls`` set.
    """

    # whether names may be resolved from the local namespace
    allow_locals_access: bool = False
    # whether names may be resolved from the global namespace
    allow_globals_access: bool = False
    # whether subscript access (`value[item]`) is permitted
    allow_item_access: bool = False
    # whether attribute access (`value.attr`) is permitted
    allow_attr_access: bool = False
    # whether names may be resolved from the `builtins` module
    allow_builtins_access: bool = False
    # whether every unary/binary/comparison operation is permitted
    allow_all_operations: bool = False
    # whether every call is permitted, regardless of `allowed_calls`
    allow_any_calls: bool = False
    # callables (or unbound methods of their owners) which may be called
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether `value[item]` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether `value.attr` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return whether the operation described by `dunders` is permitted."""
        if self.allow_all_operations:
            return True
        return False

    def _in_allowed_calls(self, candidate) -> bool:
        """Membership test in `allowed_calls` which tolerates unhashable objects."""
        try:
            return candidate in self.allowed_calls
        except TypeError:
            # an unhashable object cannot be a member of a set
            return False

    def can_call(self, func):
        """Return whether `func` may be called.

        A call is permitted when any call is allowed, when `func` itself is
        allow-listed, or when the unbound method behind `func` is allow-listed.
        """
        if self.allow_any_calls:
            return True

        if self._in_allowed_calls(func):
            return True

        owner_method = _unbind_method(func)

        if owner_method and self._in_allowed_calls(owner_method):
            return True

        return False
115 116
117 def _get_external(module_name: str, access_path: Sequence[str]):
118 """Get value from external module given a dotted access path.
119
120 Raises:
121 * `KeyError` if module is removed not found, and
122 * `AttributeError` if acess path does not match an exported object
123 """
124 member_type = sys.modules[module_name]
125 for attr in access_path:
126 member_type = getattr(member_type, attr)
127 return member_type
128
129
def _has_original_dunder_external(
    value,
    module_name: str,
    access_path: Sequence[str],
    method_name: str,
):
    """Check `value` against a type exported by an already imported module.

    Returns True when the type of `value` equals the external type, or when
    `value` is an instance of it and `method_name` resolves to the same
    object on both types. Returns False when the module is not loaded, when
    the lookup fails, or when `method_name` is ``__getattribute__`` and the
    types differ. Otherwise returns None.
    """
    # Look before you leap: the membership test is cheaper than raising
    # and catching `KeyError` for modules which are not loaded.
    if module_name not in sys.modules:
        return False
    try:
        external_type = _get_external(module_name, access_path)
        own_type = type(value)
        if own_type == external_type:
            return True
        if method_name == "__getattribute__":
            # we have to short-circuit here due to an unresolved issue in
            # `isinstance` implementation: https://bugs.python.org/issue32683
            return False
        if isinstance(value, external_type):
            own_method = getattr(own_type, method_name, None)
            external_method = getattr(external_type, method_name, None)
            if external_method == own_method:
                return True
    except (AttributeError, KeyError):
        return False
    return None
142 155
143 156
144 157 def _has_original_dunder(
145 158 value, allowed_types, allowed_methods, allowed_external, method_name
146 159 ):
147 160 # note: Python ignores `__getattr__`/`__getitem__` on instances,
148 161 # we only need to check at class level
149 162 value_type = type(value)
150 163
151 164 # strict type check passes β†’ no need to check method
152 165 if value_type in allowed_types:
153 166 return True
154 167
155 168 method = getattr(value_type, method_name, None)
156 169
157 170 if method is None:
158 171 return None
159 172
160 173 if method in allowed_methods:
161 174 return True
162 175
163 176 for module_name, *access_path in allowed_external:
164 177 if _has_original_dunder_external(value, module_name, access_path, method_name):
165 178 return True
166 179
167 180 return False
168 181
169 182
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy which permits actions only for allow-listed types and methods.

    Types are allow-listed either directly (``allowed_*`` sets) or by
    location (``allowed_*_external`` sets of ``(module, *access_path)``
    tuples, resolved against already imported modules).
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # per-dunder cache of the methods collected from `allowed_operations`
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Allow attribute access unless attribute lookup was customised.

        Access is rejected when `__getattribute__`/`__getattr__` of the type
        of `value` are not the allow-listed ones, and when `attr` is a
        property that neither belongs to an allow-listed type nor is the
        unchanged property of one of the allow-listed external types.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok
        if value_class in self.allowed_getattr:
            return True

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external type has to be consulted: returning after the first
        # one would make the answer depend on set iteration order.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                # module not loaded, or this type has no such attribute
                continue
            if class_attr_val == external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Allow the operation only if every dunder is original on every operand."""
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            [
                _has_original_dunder(
                    obj,
                    allowed_types=self.allowed_operations,
                    allowed_methods=self._operator_dunder_methods(dunder),
                    allowed_external=self.allowed_operations_external,
                    method_name=dunder,
                )
                for dunder in dunders
                for obj in objects
            ]
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return (and cache) implementations of `dunder` on allow-listed operation types."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect the truthy attributes called `name` from `classes`."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
263 306
264 307
265 308 class _DummyNamedTuple(NamedTuple):
266 309 """Used internally to retrieve methods of named tuple instance."""
267 310
268 311
269 312 class EvaluationContext(NamedTuple):
270 313 #: Local namespace
271 314 locals: dict
272 315 #: Global namespace
273 316 globals: dict
274 317 #: Evaluation policy identifier
275 318 evaluation: Literal[
276 319 "forbidden", "minimal", "limited", "unsafe", "dangerous"
277 320 ] = "forbidden"
278 321 #: Whether the evaluation of code takes place inside of a subscript.
279 322 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
280 323 in_subscript: bool = False
281 324
282 325
283 326 class _IdentitySubscript:
284 327 """Returns the key itself when item is requested via subscript."""
285 328
286 329 def __getitem__(self, key):
287 330 return key
288 331
289 332
290 333 IDENTITY_SUBSCRIPT = _IdentitySubscript()
291 334 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
292 335
293 336
294 337 class GuardRejection(Exception):
295 338 """Exception raised when guard rejects evaluation attempt."""
296 339
297 340 pass
298 341
299 342
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    With the ``forbidden`` policy nothing is evaluated and
    :class:`GuardRejection` is raised; with ``dangerous`` the code is passed
    to the standard :func:`eval`; for every other policy the code is parsed
    and the resulting AST is handed to :func:`eval_node`.
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside of subscripts, so the code
        # is wrapped in a subscript of a marker object which hands the key
        # back unchanged; this lets the parser accept it while the
        # `__getitem__` call itself has no effect
        if not code:
            return tuple()
        scoped_locals = context.locals.copy()
        scoped_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = "".join([SUBSCRIPT_MARKER, "[", code, "]"])
        context = EvaluationContext(**{**context._asdict(), "locals": scoped_locals})

    if context.evaluation == "dangerous":
        # unrestricted evaluation, only for the explicitly requested mode
        return eval(code, context.globals, context.locals)

    parsed = ast.parse(code, mode="eval")
    return eval_node(parsed, context)
334 377
335 378
336 379 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
337 380 ast.Add: ("__add__",),
338 381 ast.Sub: ("__sub__",),
339 382 ast.Mult: ("__mul__",),
340 383 ast.Div: ("__truediv__",),
341 384 ast.FloorDiv: ("__floordiv__",),
342 385 ast.Mod: ("__mod__",),
343 386 ast.Pow: ("__pow__",),
344 387 ast.LShift: ("__lshift__",),
345 388 ast.RShift: ("__rshift__",),
346 389 ast.BitOr: ("__or__",),
347 390 ast.BitXor: ("__xor__",),
348 391 ast.BitAnd: ("__and__",),
349 392 ast.MatMult: ("__matmul__",),
350 393 }
351 394
352 395 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
353 396 ast.Eq: ("__eq__",),
354 397 ast.NotEq: ("__ne__", "__eq__"),
355 398 ast.Lt: ("__lt__", "__gt__"),
356 399 ast.LtE: ("__le__", "__ge__"),
357 400 ast.Gt: ("__gt__", "__lt__"),
358 401 ast.GtE: ("__ge__", "__le__"),
359 402 ast.In: ("__contains__",),
360 403 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
361 404 }
362 405
363 406 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
364 407 ast.USub: ("__neg__",),
365 408 ast.UAdd: ("__pos__",),
366 409 # we have to check both __inv__ and __invert__!
367 410 ast.Invert: ("__invert__", "__inv__"),
368 411 ast.Not: ("__not__",),
369 412 }
370 413
371 414
372 415 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
373 416 dunder = None
374 417 for op, candidate_dunder in dunders.items():
375 418 if isinstance(node_op, op):
376 419 dunder = candidate_dunder
377 420 return dunder
378 421
379 422
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.

    Does not evaluate actions that always have side effects:

    - class definitions (``class sth: ...``)
    - function definitions (``def sth: ...``)
    - variable assignments (``x = 1``)
    - augmented assignments (``x += 1``)
    - deletions (``del x``)

    Does not evaluate operations which do not return values:

    - assertions (``assert x``)
    - pass (``pass``)
    - imports (``import x``)
    - control flow:

        - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
        - loops (``for`` and ``while``)
        - exception handling

    The purpose of this function is to guard against unwanted side-effects;
    it does not give guarantees on protection from malicious code execution.

    Raises:
    * `GuardRejection` if the policy does not permit a requested action,
    * `NameError` if a name is not found in any permitted namespace, and
    * `ValueError` if the node type is not handled.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    if node is None:
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.BinOp):
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, left, right):
                return getattr(left, dunders[0])(right)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(left),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Compare):
        left = eval_node(node.left, context)
        all_true = True
        for op, right in zip(node.ops, node.comparators):
            right = eval_node(right, context)
            # Reset for every link of a chained comparison; otherwise the
            # negation of `is not`/`not in` leaks into the following links.
            negate = False
            dunder = None
            dunders = _find_dunder(op, COMP_OP_DUNDERS)
            if not dunders:
                if isinstance(op, ast.NotIn):
                    dunders = COMP_OP_DUNDERS[ast.In]
                    negate = True
                if isinstance(op, ast.Is):
                    dunder = "is_"
                if isinstance(op, ast.IsNot):
                    dunder = "is_"
                    negate = True
            if not dunder and dunders:
                dunder = dunders[0]
            if dunder:
                # `x in y` is dispatched to the container: `y.__contains__(x)`
                a, b = (right, left) if dunder == "__contains__" else (left, right)
                if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
                    result = getattr(operator, dunder)(a, b)
                    if negate:
                        result = not result
                    if not result:
                        all_true = False
                    left = right
                else:
                    raise GuardRejection(
                        f"Comparison (`{dunder}`) for",
                        type(left),
                        f"not allowed in {context.evaluation} mode",
                    )
            else:
                raise ValueError(
                    f"Comparison `{dunder}` not supported"
                )  # pragma: no cover
        return all_true
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Index):
        # deprecated since Python 3.9
        return eval_node(node.value, context)  # pragma: no cover
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        return dict(
            zip(
                [eval_node(k, context) for k in node.keys],
                [eval_node(v, context) for v in node.values],
            )
        )
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context),
        )
    if isinstance(node, ast.ExtSlice):
        # deprecated since Python 3.9
        return tuple([eval_node(dim, context) for dim in node.dims])  # pragma: no cover
    if isinstance(node, ast.UnaryOp):
        value = eval_node(node.operand, context)
        dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
        if dunders:
            if policy.can_operate(dunders, value):
                # Dispatch through `operator`: objects have no `__not__`
                # method, so `getattr(value, "__not__")` would fail for `not x`.
                return getattr(operator, dunders[0])(value)
            else:
                raise GuardRejection(
                    f"Operation (`{dunders}`) for",
                    type(value),
                    f"not allowed in {context.evaluation} mode",
                )
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            "Subscript access (`__getitem__`) for",
            type(value),  # not joined to avoid calling `repr`
            f" not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Name):
        if policy.allow_locals_access and node.id in context.locals:
            return context.locals[node.id]
        if policy.allow_globals_access and node.id in context.globals:
            return context.globals[node.id]
        if policy.allow_builtins_access and hasattr(builtins, node.id):
            # note: do not use __builtins__, it is implementation detail of cPython
            return getattr(builtins, node.id)
        if not policy.allow_globals_access and not policy.allow_locals_access:
            raise GuardRejection(
                f"Namespace access not allowed in {context.evaluation} mode"
            )
        else:
            raise NameError(f"{node.id} not found in locals, globals, nor builtins")
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            "Attribute access (`__getattr__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        else:
            return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        if policy.can_call(func) and not node.keywords:
            args = [eval_node(arg, context) for arg in node.args]
            return func(*args)
        raise GuardRejection(
            "Call for",
            func,  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    raise ValueError("Unhandled node", ast.dump(node))
553 596
554 597
555 598 SUPPORTED_EXTERNAL_GETITEM = {
556 599 ("pandas", "core", "indexing", "_iLocIndexer"),
557 600 ("pandas", "core", "indexing", "_LocIndexer"),
558 601 ("pandas", "DataFrame"),
559 602 ("pandas", "Series"),
560 603 ("numpy", "ndarray"),
561 604 ("numpy", "void"),
562 605 }
563 606
564 607
565 608 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
566 609 dict,
567 610 str,
568 611 bytes,
569 612 list,
570 613 tuple,
571 614 collections.defaultdict,
572 615 collections.deque,
573 616 collections.OrderedDict,
574 617 collections.ChainMap,
575 618 collections.UserDict,
576 619 collections.UserList,
577 620 collections.UserString,
578 621 _DummyNamedTuple,
579 622 _IdentitySubscript,
580 623 }
581 624
582 625
583 626 def _list_methods(cls, source=None):
584 627 """For use on immutable objects or with methods returning a copy"""
585 628 return [getattr(cls, k) for k in (source if source else dir(cls))]
586 629
587 630
588 631 dict_non_mutating_methods = ("copy", "keys", "values", "items")
589 632 list_non_mutating_methods = ("copy", "index", "count")
590 633 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
591 634
592 635
593 636 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
594 637 method_descriptor: Any = type(list.copy)
595 638
596 639 NUMERICS = {int, float, complex}
597 640
598 641 ALLOWED_CALLS = {
599 642 bytes,
600 643 *_list_methods(bytes),
601 644 dict,
602 645 *_list_methods(dict, dict_non_mutating_methods),
603 646 dict_keys.isdisjoint,
604 647 list,
605 648 *_list_methods(list, list_non_mutating_methods),
606 649 set,
607 650 *_list_methods(set, set_non_mutating_methods),
608 651 frozenset,
609 652 *_list_methods(frozenset),
610 653 range,
611 654 str,
612 655 *_list_methods(str),
613 656 tuple,
614 657 *_list_methods(tuple),
615 658 *NUMERICS,
616 659 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
617 660 collections.deque,
618 661 *_list_methods(collections.deque, list_non_mutating_methods),
619 662 collections.defaultdict,
620 663 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
621 664 collections.OrderedDict,
622 665 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
623 666 collections.UserDict,
624 667 *_list_methods(collections.UserDict, dict_non_mutating_methods),
625 668 collections.UserList,
626 669 *_list_methods(collections.UserList, list_non_mutating_methods),
627 670 collections.UserString,
628 671 *_list_methods(collections.UserString, dir(str)),
629 672 collections.Counter,
630 673 *_list_methods(collections.Counter, dict_non_mutating_methods),
631 674 collections.Counter.elements,
632 675 collections.Counter.most_common,
633 676 }
634 677
635 678 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
636 679 *BUILTIN_GETITEM,
637 680 set,
638 681 frozenset,
639 682 object,
640 683 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
641 684 *NUMERICS,
642 685 dict_keys,
643 686 method_descriptor,
644 687 }
645 688
646 689
647 690 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
648 691
649 692 EVALUATION_POLICIES = {
650 693 "minimal": EvaluationPolicy(
651 694 allow_builtins_access=True,
652 695 allow_locals_access=False,
653 696 allow_globals_access=False,
654 697 allow_item_access=False,
655 698 allow_attr_access=False,
656 699 allowed_calls=set(),
657 700 allow_any_calls=False,
658 701 allow_all_operations=False,
659 702 ),
660 703 "limited": SelectivePolicy(
661 704 # TODO:
662 705 # - should reject binary and unary operations if custom methods would be dispatched
663 706 allowed_getitem=BUILTIN_GETITEM,
664 707 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
665 708 allowed_getattr=BUILTIN_GETATTR,
666 709 allowed_getattr_external={
667 710 # pandas Series/Frame implements custom `__getattr__`
668 711 ("pandas", "DataFrame"),
669 712 ("pandas", "Series"),
670 713 },
671 714 allowed_operations=BUILTIN_OPERATIONS,
672 715 allow_builtins_access=True,
673 716 allow_locals_access=True,
674 717 allow_globals_access=True,
675 718 allowed_calls=ALLOWED_CALLS,
676 719 ),
677 720 "unsafe": EvaluationPolicy(
678 721 allow_builtins_access=True,
679 722 allow_locals_access=True,
680 723 allow_globals_access=True,
681 724 allow_attr_access=True,
682 725 allow_item_access=True,
683 726 allow_any_calls=True,
684 727 allow_all_operations=True,
685 728 ),
686 729 }
687 730
688 731
689 732 __all__ = [
690 733 "guarded_eval",
691 734 "eval_node",
692 735 "GuardRejection",
693 736 "EvaluationContext",
694 737 "_unbind_method",
695 738 ]
@@ -1,492 +1,537 b''
1 from contextlib import contextmanager
1 2 from typing import NamedTuple
2 3 from functools import partial
3 4 from IPython.core.guarded_eval import (
4 5 EvaluationContext,
5 6 GuardRejection,
6 7 guarded_eval,
7 8 _unbind_method,
8 9 )
9 10 from IPython.testing import decorators as dec
10 11 import pytest
11 12
12 13
13 14 def create_context(evaluation: str, **kwargs):
14 15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
15 16
16 17
17 18 forbidden = partial(create_context, "forbidden")
18 19 minimal = partial(create_context, "minimal")
19 20 limited = partial(create_context, "limited")
20 21 unsafe = partial(create_context, "unsafe")
21 22 dangerous = partial(create_context, "dangerous")
22 23
23 24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
24 25
25 26 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26 27
27 28
@contextmanager
def module_not_installed(module: str):
    """Temporarily hide `module` from `sys.modules`.

    On exit the previous state is restored: a module which was loaded is put
    back, and a module which was not loaded stays absent.
    """
    import sys

    missing = object()
    to_restore = sys.modules.pop(module, missing)
    try:
        yield
    finally:
        if to_restore is missing:
            # Never store `None` here: a `None` entry in `sys.modules`
            # makes later imports of the module fail.
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = to_restore
42
43
28 44 @dec.skip_without("pandas")
29 45 def test_pandas_series_iloc():
30 46 import pandas as pd
31 47
32 48 series = pd.Series([1], index=["a"])
33 49 context = limited(data=series)
34 50 assert guarded_eval("data.iloc[0]", context) == 1
35 51
36 52
53 def test_rejects_custom_properties():
54 class BadProperty:
55 @property
56 def iloc(self):
57 return [None]
58
59 series = BadProperty()
60 context = limited(data=series)
61
62 with pytest.raises(GuardRejection):
63 guarded_eval("data.iloc[0]", context)
64
65
66 @dec.skip_without("pandas")
67 def test_accepts_non_overriden_properties():
68 import pandas as pd
69
70 class GoodProperty(pd.Series):
71 pass
72
73 series = GoodProperty([1], index=["a"])
74 context = limited(data=series)
75
76 assert guarded_eval("data.iloc[0]", context) == 1
77
78
37 79 @dec.skip_without("pandas")
38 80 def test_pandas_series():
39 81 import pandas as pd
40 82
41 83 context = limited(data=pd.Series([1], index=["a"]))
42 84 assert guarded_eval('data["a"]', context) == 1
43 85 with pytest.raises(KeyError):
44 86 guarded_eval('data["c"]', context)
45 87
46 88
47 89 @dec.skip_without("pandas")
48 90 def test_pandas_bad_series():
49 91 import pandas as pd
50 92
51 93 class BadItemSeries(pd.Series):
52 94 def __getitem__(self, key):
53 95 return "CUSTOM_ITEM"
54 96
55 97 class BadAttrSeries(pd.Series):
56 98 def __getattr__(self, key):
57 99 return "CUSTOM_ATTR"
58 100
59 101 bad_series = BadItemSeries([1], index=["a"])
60 102 context = limited(data=bad_series)
61 103
62 104 with pytest.raises(GuardRejection):
63 105 guarded_eval('data["a"]', context)
64 106 with pytest.raises(GuardRejection):
65 107 guarded_eval('data["c"]', context)
66 108
67 109 # note: here result is a bit unexpected because
68 110 # pandas `__getattr__` calls `__getitem__`;
69 111 # FIXME - special case to handle it?
70 112 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
71 113
72 114 context = unsafe(data=bad_series)
73 115 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
74 116
75 117 bad_attr_series = BadAttrSeries([1], index=["a"])
76 118 context = limited(data=bad_attr_series)
77 119 assert guarded_eval('data["a"]', context) == 1
78 120 with pytest.raises(GuardRejection):
79 121 guarded_eval("data.a", context)
80 122
81 123
82 124 @dec.skip_without("pandas")
83 125 def test_pandas_dataframe_loc():
84 126 import pandas as pd
85 127 from pandas.testing import assert_series_equal
86 128
87 129 data = pd.DataFrame([{"a": 1}])
88 130 context = limited(data=data)
89 131 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
90 132
91 133
92 134 def test_named_tuple():
93 135 class GoodNamedTuple(NamedTuple):
94 136 a: str
95 137 pass
96 138
97 139 class BadNamedTuple(NamedTuple):
98 140 a: str
99 141
100 142 def __getitem__(self, key):
101 143 return None
102 144
103 145 good = GoodNamedTuple(a="x")
104 146 bad = BadNamedTuple(a="x")
105 147
106 148 context = limited(data=good)
107 149 assert guarded_eval("data[0]", context) == "x"
108 150
109 151 context = limited(data=bad)
110 152 with pytest.raises(GuardRejection):
111 153 guarded_eval("data[0]", context)
112 154
113 155
114 156 def test_dict():
115 157 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
116 158 assert guarded_eval('data["a"]', context) == 1
117 159 assert guarded_eval('data["b"]', context) == {"x": 2}
118 160 assert guarded_eval('data["b"]["x"]', context) == 2
119 161 assert guarded_eval('data["x", "y"]', context) == 3
120 162
121 163 assert guarded_eval("data.keys", context)
122 164
123 165
124 166 def test_set():
125 167 context = limited(data={"a", "b"})
126 168 assert guarded_eval("data.difference", context)
127 169
128 170
129 171 def test_list():
130 172 context = limited(data=[1, 2, 3])
131 173 assert guarded_eval("data[1]", context) == 2
132 174 assert guarded_eval("data.copy", context)
133 175
134 176
135 177 def test_dict_literal():
136 178 context = limited()
137 179 assert guarded_eval("{}", context) == {}
138 180 assert guarded_eval('{"a": 1}', context) == {"a": 1}
139 181
140 182
141 183 def test_list_literal():
142 184 context = limited()
143 185 assert guarded_eval("[]", context) == []
144 186 assert guarded_eval('[1, "a"]', context) == [1, "a"]
145 187
146 188
147 189 def test_set_literal():
148 190 context = limited()
149 191 assert guarded_eval("set()", context) == set()
150 192 assert guarded_eval('{"a"}', context) == {"a"}
151 193
152 194
153 195 def test_evaluates_if_expression():
154 196 context = limited()
155 197 assert guarded_eval("2 if True else 3", context) == 2
156 198 assert guarded_eval("4 if False else 5", context) == 5
157 199
158 200
159 201 def test_object():
160 202 obj = object()
161 203 context = limited(obj=obj)
162 204 assert guarded_eval("obj.__dir__", context) == obj.__dir__
163 205
164 206
165 207 @pytest.mark.parametrize(
166 208 "code,expected",
167 209 [
168 210 ["int.numerator", int.numerator],
169 211 ["float.is_integer", float.is_integer],
170 212 ["complex.real", complex.real],
171 213 ],
172 214 )
173 215 def test_number_attributes(code, expected):
174 216 assert guarded_eval(code, limited()) == expected
175 217
176 218
177 219 def test_method_descriptor():
178 220 context = limited()
179 221 assert guarded_eval("list.copy.__name__", context) == "copy"
180 222
181 223
@pytest.mark.parametrize(
    ("data", "good", "bad", "expected"),
    [
        ([1, 2, 3], "data.index(2)", "data.append(4)", 1),
        ({"a": 1}, "data.keys().isdisjoint({})", "data.update()", True),
    ],
)
def test_evaluates_calls(data, good, bad, expected):
    """The ``good`` call is evaluated while the ``bad`` call is rejected."""
    ctx = limited(data=data)

    # The accepted call must produce the expected value ...
    outcome = guarded_eval(good, ctx)
    assert outcome == expected

    # ... whereas the other call must be refused by the guard.
    with pytest.raises(GuardRejection):
        guarded_eval(bad, ctx)
195 237
196 238
@pytest.mark.parametrize(
    ("code", "expected"),
    [
        ("(1\n+\n1)", 2),
        ("list(range(10))[-1:]", [9]),
        ("list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_complex_cases(code, expected, context):
    """Multi-line arithmetic and sliced call results evaluate correctly."""
    evaluated = guarded_eval(code, context())
    assert evaluated == expected
208 250
209 251
@pytest.mark.parametrize(
    ("code", "expected"),
    [
        ("1", 1),
        ("1.0", 1.0),
        ("0xdeedbeef", 0xDEEDBEEF),
        ("True", True),
        ("None", None),
        ("{}", {}),
        ("[]", []),
    ],
)
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_evaluates_literals(code, expected, context):
    """Literals evaluate to their value in every ``MINIMAL_OR_HIGHER`` context."""
    evaluated = guarded_eval(code, context())
    assert evaluated == expected
225 267
226 268
@pytest.mark.parametrize(
    ("code", "expected"),
    [
        ("-5", -5),
        ("+5", +5),
        ("~5", -6),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_unary_operations(code, expected, context):
    """Unary minus, plus and inversion on int literals are evaluated."""
    evaluated = guarded_eval(code, context())
    assert evaluated == expected
238 280
239 281
@pytest.mark.parametrize(
    ("code", "expected"),
    [
        ("1 + 1", 2),
        ("3 - 1", 2),
        ("2 * 3", 6),
        ("5 // 2", 2),
        ("5 / 2", 2.5),
        ("5**2", 25),
        ("2 >> 1", 1),
        ("2 << 1", 4),
        ("1 | 2", 3),
        ("1 & 1", 1),
        ("1 & 2", 0),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_binary_operations(code, expected, context):
    """Arithmetic, shift and bitwise operators on int literals are evaluated."""
    evaluated = guarded_eval(code, context())
    assert evaluated == expected
259 301
260 302
@pytest.mark.parametrize(
    ("code", "expected"),
    [
        # simple ordering and equality
        ("2 > 1", True),
        ("2 < 1", False),
        ("2 <= 1", False),
        ("2 <= 2", True),
        ("1 >= 2", False),
        ("2 >= 2", True),
        ("2 == 2", True),
        ("1 == 2", False),
        ("1 != 2", True),
        ("1 != 1", False),
        # chained comparisons
        ("1 < 4 < 3", False),
        ("(1 < 4) < 3", True),
        ("4 > 3 > 2 > 1", True),
        ("4 > 3 > 2 > 9", False),
        ("1 < 2 < 3 < 4", True),
        ("9 < 2 < 3 < 4", False),
        ("1 < 2 > 1 > 0 > -1 < 1", True),
        # membership
        ("1 in [1] in [[1]]", True),
        ("1 in [1] in [[2]]", False),
        ("1 in [1]", True),
        ("0 in [1]", False),
        ("1 not in [1]", False),
        ("0 not in [1]", True),
        # identity
        ("True is True", True),
        ("False is False", True),
        ("True is False", False),
        ("True is not True", False),
        ("False is not True", True),
    ],
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_comparisons(code, expected, context):
    """Single and chained comparison expressions produce the expected value."""
    evaluated = guarded_eval(code, context())
    assert evaluated == expected
297 339
298 340
def test_guards_comparisons():
    """Comparisons involving a custom ``__eq__`` are rejected.

    An ``int`` subclass without overrides must still compare normally,
    on either side of the operator.
    """

    class GoodEq(int):
        pass

    class BadEq(int):
        def __eq__(self, other):
            # must never run: the guard has to reject before dispatching
            assert False

    ctx = limited(bad=BadEq(1), good=GoodEq(1))

    for rejected in ("bad == 1", "bad != 1", "1 == bad", "1 != bad"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, ctx)

    allowed = (
        ("good == 1", True),
        ("good != 1", False),
        ("1 == good", True),
        ("1 != good", False),
    )
    for expression, outcome in allowed:
        assert guarded_eval(expression, ctx) is outcome
325 367
326 368
def test_guards_unary_operations():
    """Inverting an object that customises inversion must be rejected.

    Two spellings are covered: the legacy ``__inv__`` name (which the
    ``operator`` module still exposes) and ``__invert__``, the special
    method Python actually dispatches ``~x`` to. Previously both helper
    classes overrode ``__inv__``, so the real dunder was never exercised.
    """

    class GoodOp(int):
        pass

    class BadOpInv(int):
        def __inv__(self, other):
            # must never run: the guard has to reject before dispatching
            assert False

    class BadOpInverse(int):
        def __invert__(self):
            # must never run: the guard has to reject before dispatching
            assert False

    context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))

    with pytest.raises(GuardRejection):
        guarded_eval("~bad1", context)

    with pytest.raises(GuardRejection):
        guarded_eval("~bad2", context)
346 388
347 389
def test_guards_binary_operations():
    """Addition involving a custom ``__add__`` is rejected.

    An ``int`` subclass without overrides must still add normally, on
    either side of the operator.
    """

    class GoodOp(int):
        pass

    class BadOp(int):
        def __add__(self, other):
            # must never run: the guard has to reject before dispatching
            assert False

    ctx = limited(good=GoodOp(1), bad=BadOp(1))

    for rejected in ("1 + bad", "bad + 1"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, ctx)

    for accepted in ("good + 1", "1 + good"):
        assert guarded_eval(accepted, ctx) == 2
366 408
367 409
def test_guards_attributes():
    """Attribute access is rejected on objects customising attribute lookup.

    Both ``__getattr__`` and ``__getattribute__`` overrides are covered;
    a ``float`` subclass without overrides keeps working.
    """

    class GoodAttr(float):
        pass

    class BadAttr1(float):
        def __getattr__(self, key):
            # must never run: the guard has to reject before dispatching
            assert False

    class BadAttr2(float):
        def __getattribute__(self, key):
            # must never run: the guard has to reject before dispatching
            assert False

    ctx = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))

    for rejected in ("bad1.as_integer_ratio", "bad2.as_integer_ratio"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, ctx)

    ratio = guarded_eval("good.as_integer_ratio()", ctx)
    assert ratio == (1, 2)
389 431
390 432
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_access_builtins(context):
    """The builtin ``round`` resolves in every ``MINIMAL_OR_HIGHER`` context."""
    resolved = guarded_eval("round", context())
    assert resolved == round
394 436
395 437
def test_access_builtins_fails():
    """A name that is not defined anywhere raises ``NameError``."""
    ctx = limited()
    with pytest.raises(NameError):
        guarded_eval("this_is_not_builtin", ctx)
400 442
401 443
def test_rejects_forbidden():
    """Even a bare integer literal is rejected in the ``forbidden`` context."""
    ctx = forbidden()
    with pytest.raises(GuardRejection):
        guarded_eval("1", ctx)
406 448
407 449
def test_guards_locals_and_globals():
    """With ``evaluation="minimal"`` local and global names are rejected."""
    ctx = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="minimal",
    )

    for name in ("local_a", "global_b"):
        with pytest.raises(GuardRejection):
            guarded_eval(name, ctx)
418 460
419 461
def test_access_locals_and_globals():
    """With ``evaluation="limited"`` local and global names resolve."""
    ctx = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="limited",
    )
    local_value = guarded_eval("local_a", ctx)
    assert local_value == "a"
    global_value = guarded_eval("global_b", ctx)
    assert global_value == "b"
426 468
427 469
@pytest.mark.parametrize(
    "code",
    [
        "def func(): pass",
        "class C: pass",
        "x = 1",
        "x += 1",
        "del x",
        "import ast",
    ],
)
@pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
def test_rejects_side_effect_syntax(code, context):
    """Statements (definitions, assignment, ``del``, import) raise ``SyntaxError``."""
    with pytest.raises(SyntaxError):
        guarded_eval(code, context)
436 478
437 479
def test_subscript():
    """With ``in_subscript=True`` the code is evaluated as subscript content.

    Empty input gives an empty tuple, slice notation gives ``slice``
    objects, and comma-separated items give a tuple.
    """
    ctx = EvaluationContext(
        locals={},
        globals={},
        evaluation="limited",
        in_subscript=True,
    )
    full_slice = slice(None, None, None)
    assert guarded_eval("", ctx) == ()
    assert guarded_eval(":", ctx) == full_slice
    assert guarded_eval("1:2:3", ctx) == slice(1, 2, 3)
    assert guarded_eval(':, "a"', ctx) == (full_slice, "a")
447 489
448 490
def test_unbind_method():
    """``_unbind_method`` returns the function as defined on the owner's class."""

    class X(list):
        def index(self, k):
            return "CUSTOM"

    # A bound method of a subclass instance maps back to the override ...
    instance = X()
    from_subclass = _unbind_method(instance.index)
    assert from_subclass is X.index

    # ... and a bound builtin method maps back to the builtin descriptor.
    from_builtin = _unbind_method([].index)
    assert from_builtin is list.index
457 499
458 500
def test_assumption_instance_attr_do_not_matter():
    """Check that special methods set on an instance are ignored.

    This is semi-specified in Python documentation.

    However, since the specification says 'not guaranteed
    to work' rather than 'is forbidden to work', future
    versions could invalidate these assumptions. This test
    is meant to catch such a change if it ever comes true.
    """

    class T:
        def __getitem__(self, k):
            return "a"

        def __getattr__(self, k):
            return "a"

    def f(self):
        return "b"

    t = T()
    # Shadow both special methods in the instance ``__dict__``; lookups
    # below must still go through the versions defined on the class.
    t.__getitem__ = f
    t.__getattr__ = f
    assert t[1] == "a"
    # Previously ``t[1]`` was asserted twice, leaving the ``__getattr__``
    # override untested; a missing attribute goes through ``__getattr__``.
    assert t.some_missing_attribute == "a"
480 525
481 526
def test_assumption_named_tuples_share_getitem():
    """Verify that distinct named tuple classes expose an equal ``__getitem__``."""
    from typing import NamedTuple

    class First(NamedTuple):
        pass

    class Second(NamedTuple):
        pass

    first_getitem = First.__getitem__
    second_getitem = Second.__getitem__
    assert first_getitem == second_getitem
General Comments 0
You need to be logged in to leave comments. Login now