##// END OF EJS Templates
Backport PR #14029: Allow safe access to the `__getattribute__` method of modules
Matthias Bussonnier -
Show More
@@ -1,738 +1,740 b''
1 1 from typing import (
2 2 Any,
3 3 Callable,
4 4 Dict,
5 5 Set,
6 6 Sequence,
7 7 Tuple,
8 8 NamedTuple,
9 9 Type,
10 10 Literal,
11 11 Union,
12 12 TYPE_CHECKING,
13 13 )
14 14 import ast
15 15 import builtins
16 16 import collections
17 17 import operator
18 18 import sys
19 19 from functools import cached_property
20 20 from dataclasses import dataclass, field
21 21
22 22 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 23 from IPython.utils.decorators import undoc
24 24
25 25
26 26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 27 from typing_extensions import Protocol
28 28 else:
29 29 # do not require on runtime
30 30 Protocol = object # requires Python >=3.8
31 31
32 32
33 33 @undoc
34 34 class HasGetItem(Protocol):
35 35 def __getitem__(self, key) -> None:
36 36 ...
37 37
38 38
39 39 @undoc
40 40 class InstancesHaveGetItem(Protocol):
41 41 def __call__(self, *args, **kwargs) -> HasGetItem:
42 42 ...
43 43
44 44
45 45 @undoc
46 46 class HasGetAttr(Protocol):
47 47 def __getattr__(self, key) -> None:
48 48 ...
49 49
50 50
51 51 @undoc
52 52 class DoesNotHaveGetAttr(Protocol):
53 53 pass
54 54
55 55
56 56 # By default `__getattr__` is not explicitly implemented on most objects
57 57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58 58
59 59
60 60 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 61 """Get unbound method for given bound method.
62 62
63 63 Returns None if cannot get unbound method, or method is already unbound.
64 64 """
65 65 owner = getattr(func, "__self__", None)
66 66 owner_class = type(owner)
67 67 name = getattr(func, "__name__", None)
68 68 instance_dict_overrides = getattr(owner, "__dict__", None)
69 69 if (
70 70 owner is not None
71 71 and name
72 72 and (
73 73 not instance_dict_overrides
74 74 or (instance_dict_overrides and name not in instance_dict_overrides)
75 75 )
76 76 ):
77 77 return getattr(owner_class, name)
78 78 return None
79 79
80 80
81 81 @undoc
82 82 @dataclass
83 83 class EvaluationPolicy:
84 84 """Definition of evaluation policy."""
85 85
86 86 allow_locals_access: bool = False
87 87 allow_globals_access: bool = False
88 88 allow_item_access: bool = False
89 89 allow_attr_access: bool = False
90 90 allow_builtins_access: bool = False
91 91 allow_all_operations: bool = False
92 92 allow_any_calls: bool = False
93 93 allowed_calls: Set[Callable] = field(default_factory=set)
94 94
95 95 def can_get_item(self, value, item):
96 96 return self.allow_item_access
97 97
98 98 def can_get_attr(self, value, attr):
99 99 return self.allow_attr_access
100 100
101 101 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
102 102 if self.allow_all_operations:
103 103 return True
104 104
105 105 def can_call(self, func):
106 106 if self.allow_any_calls:
107 107 return True
108 108
109 109 if func in self.allowed_calls:
110 110 return True
111 111
112 112 owner_method = _unbind_method(func)
113 113
114 114 if owner_method and owner_method in self.allowed_calls:
115 115 return True
116 116
117 117
118 118 def _get_external(module_name: str, access_path: Sequence[str]):
119 119 """Get value from external module given a dotted access path.
120 120
121 121 Raises:
122 122 * `KeyError` if module is removed not found, and
123 123 * `AttributeError` if acess path does not match an exported object
124 124 """
125 125 member_type = sys.modules[module_name]
126 126 for attr in access_path:
127 127 member_type = getattr(member_type, attr)
128 128 return member_type
129 129
130 130
131 131 def _has_original_dunder_external(
132 132 value,
133 133 module_name: str,
134 134 access_path: Sequence[str],
135 135 method_name: str,
136 136 ):
137 137 if module_name not in sys.modules:
138 138 # LBYLB as it is faster
139 139 return False
140 140 try:
141 141 member_type = _get_external(module_name, access_path)
142 142 value_type = type(value)
143 143 if type(value) == member_type:
144 144 return True
145 145 if method_name == "__getattribute__":
146 146 # we have to short-circuit here due to an unresolved issue in
147 147 # `isinstance` implementation: https://bugs.python.org/issue32683
148 148 return False
149 149 if isinstance(value, member_type):
150 150 method = getattr(value_type, method_name, None)
151 151 member_method = getattr(member_type, method_name, None)
152 152 if member_method == method:
153 153 return True
154 154 except (AttributeError, KeyError):
155 155 return False
156 156
157 157
158 158 def _has_original_dunder(
159 159 value, allowed_types, allowed_methods, allowed_external, method_name
160 160 ):
161 161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
162 162 # we only need to check at class level
163 163 value_type = type(value)
164 164
165 165 # strict type check passes β†’ no need to check method
166 166 if value_type in allowed_types:
167 167 return True
168 168
169 169 method = getattr(value_type, method_name, None)
170 170
171 171 if method is None:
172 172 return None
173 173
174 174 if method in allowed_methods:
175 175 return True
176 176
177 177 for module_name, *access_path in allowed_external:
178 178 if _has_original_dunder_external(value, module_name, access_path, method_name):
179 179 return True
180 180
181 181 return False
182 182
183 183
184 184 @undoc
185 185 @dataclass
186 186 class SelectivePolicy(EvaluationPolicy):
187 187 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
188 188 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
189 189
190 190 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
191 191 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
192 192
193 193 allowed_operations: Set = field(default_factory=set)
194 194 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
195 195
196 196 _operation_methods_cache: Dict[str, Set[Callable]] = field(
197 197 default_factory=dict, init=False
198 198 )
199 199
200 200 def can_get_attr(self, value, attr):
201 201 has_original_attribute = _has_original_dunder(
202 202 value,
203 203 allowed_types=self.allowed_getattr,
204 204 allowed_methods=self._getattribute_methods,
205 205 allowed_external=self.allowed_getattr_external,
206 206 method_name="__getattribute__",
207 207 )
208 208 has_original_attr = _has_original_dunder(
209 209 value,
210 210 allowed_types=self.allowed_getattr,
211 211 allowed_methods=self._getattr_methods,
212 212 allowed_external=self.allowed_getattr_external,
213 213 method_name="__getattr__",
214 214 )
215 215
216 216 accept = False
217 217
218 218 # Many objects do not have `__getattr__`, this is fine.
219 219 if has_original_attr is None and has_original_attribute:
220 220 accept = True
221 221 else:
222 222 # Accept objects without modifications to `__getattr__` and `__getattribute__`
223 223 accept = has_original_attr and has_original_attribute
224 224
225 225 if accept:
226 226 # We still need to check for overriden properties.
227 227
228 228 value_class = type(value)
229 229 if not hasattr(value_class, attr):
230 230 return True
231 231
232 232 class_attr_val = getattr(value_class, attr)
233 233 is_property = isinstance(class_attr_val, property)
234 234
235 235 if not is_property:
236 236 return True
237 237
238 238 # Properties in allowed types are ok (although we do not include any
239 239 # properties in our default allow list currently).
240 240 if type(value) in self.allowed_getattr:
241 241 return True # pragma: no cover
242 242
243 243 # Properties in subclasses of allowed types may be ok if not changed
244 244 for module_name, *access_path in self.allowed_getattr_external:
245 245 try:
246 246 external_class = _get_external(module_name, access_path)
247 247 external_class_attr_val = getattr(external_class, attr)
248 248 except (KeyError, AttributeError):
249 249 return False # pragma: no cover
250 250 return class_attr_val == external_class_attr_val
251 251
252 252 return False
253 253
254 254 def can_get_item(self, value, item):
255 255 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
256 256 return _has_original_dunder(
257 257 value,
258 258 allowed_types=self.allowed_getitem,
259 259 allowed_methods=self._getitem_methods,
260 260 allowed_external=self.allowed_getitem_external,
261 261 method_name="__getitem__",
262 262 )
263 263
264 264 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
265 265 objects = [a]
266 266 if b is not None:
267 267 objects.append(b)
268 268 return all(
269 269 [
270 270 _has_original_dunder(
271 271 obj,
272 272 allowed_types=self.allowed_operations,
273 273 allowed_methods=self._operator_dunder_methods(dunder),
274 274 allowed_external=self.allowed_operations_external,
275 275 method_name=dunder,
276 276 )
277 277 for dunder in dunders
278 278 for obj in objects
279 279 ]
280 280 )
281 281
282 282 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
283 283 if dunder not in self._operation_methods_cache:
284 284 self._operation_methods_cache[dunder] = self._safe_get_methods(
285 285 self.allowed_operations, dunder
286 286 )
287 287 return self._operation_methods_cache[dunder]
288 288
289 289 @cached_property
290 290 def _getitem_methods(self) -> Set[Callable]:
291 291 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
292 292
293 293 @cached_property
294 294 def _getattr_methods(self) -> Set[Callable]:
295 295 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
296 296
297 297 @cached_property
298 298 def _getattribute_methods(self) -> Set[Callable]:
299 299 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
300 300
301 301 def _safe_get_methods(self, classes, name) -> Set[Callable]:
302 302 return {
303 303 method
304 304 for class_ in classes
305 305 for method in [getattr(class_, name, None)]
306 306 if method
307 307 }
308 308
309 309
310 310 class _DummyNamedTuple(NamedTuple):
311 311 """Used internally to retrieve methods of named tuple instance."""
312 312
313 313
314 314 class EvaluationContext(NamedTuple):
315 315 #: Local namespace
316 316 locals: dict
317 317 #: Global namespace
318 318 globals: dict
319 319 #: Evaluation policy identifier
320 320 evaluation: Literal[
321 321 "forbidden", "minimal", "limited", "unsafe", "dangerous"
322 322 ] = "forbidden"
323 323 #: Whether the evalution of code takes place inside of a subscript.
324 324 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
325 325 in_subscript: bool = False
326 326
327 327
328 328 class _IdentitySubscript:
329 329 """Returns the key itself when item is requested via subscript."""
330 330
331 331 def __getitem__(self, key):
332 332 return key
333 333
334 334
335 335 IDENTITY_SUBSCRIPT = _IdentitySubscript()
336 336 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
337 337
338 338
339 339 class GuardRejection(Exception):
340 340 """Exception raised when guard rejects evaluation attempt."""
341 341
342 342 pass
343 343
344 344
345 345 def guarded_eval(code: str, context: EvaluationContext):
346 346 """Evaluate provided code in the evaluation context.
347 347
348 348 If evaluation policy given by context is set to ``forbidden``
349 349 no evaluation will be performed; if it is set to ``dangerous``
350 350 standard :func:`eval` will be used; finally, for any other,
351 351 policy :func:`eval_node` will be called on parsed AST.
352 352 """
353 353 locals_ = context.locals
354 354
355 355 if context.evaluation == "forbidden":
356 356 raise GuardRejection("Forbidden mode")
357 357
358 358 # note: not using `ast.literal_eval` as it does not implement
359 359 # getitem at all, for example it fails on simple `[0][1]`
360 360
361 361 if context.in_subscript:
362 362 # syntatic sugar for ellipsis (:) is only available in susbcripts
363 363 # so we need to trick the ast parser into thinking that we have
364 364 # a subscript, but we need to be able to later recognise that we did
365 365 # it so we can ignore the actual __getitem__ operation
366 366 if not code:
367 367 return tuple()
368 368 locals_ = locals_.copy()
369 369 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
370 370 code = SUBSCRIPT_MARKER + "[" + code + "]"
371 371 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
372 372
373 373 if context.evaluation == "dangerous":
374 374 return eval(code, context.globals, context.locals)
375 375
376 376 expression = ast.parse(code, mode="eval")
377 377
378 378 return eval_node(expression, context)
379 379
380 380
381 381 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
382 382 ast.Add: ("__add__",),
383 383 ast.Sub: ("__sub__",),
384 384 ast.Mult: ("__mul__",),
385 385 ast.Div: ("__truediv__",),
386 386 ast.FloorDiv: ("__floordiv__",),
387 387 ast.Mod: ("__mod__",),
388 388 ast.Pow: ("__pow__",),
389 389 ast.LShift: ("__lshift__",),
390 390 ast.RShift: ("__rshift__",),
391 391 ast.BitOr: ("__or__",),
392 392 ast.BitXor: ("__xor__",),
393 393 ast.BitAnd: ("__and__",),
394 394 ast.MatMult: ("__matmul__",),
395 395 }
396 396
397 397 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
398 398 ast.Eq: ("__eq__",),
399 399 ast.NotEq: ("__ne__", "__eq__"),
400 400 ast.Lt: ("__lt__", "__gt__"),
401 401 ast.LtE: ("__le__", "__ge__"),
402 402 ast.Gt: ("__gt__", "__lt__"),
403 403 ast.GtE: ("__ge__", "__le__"),
404 404 ast.In: ("__contains__",),
405 405 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
406 406 }
407 407
408 408 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
409 409 ast.USub: ("__neg__",),
410 410 ast.UAdd: ("__pos__",),
411 411 # we have to check both __inv__ and __invert__!
412 412 ast.Invert: ("__invert__", "__inv__"),
413 413 ast.Not: ("__not__",),
414 414 }
415 415
416 416
417 417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
418 418 dunder = None
419 419 for op, candidate_dunder in dunders.items():
420 420 if isinstance(node_op, op):
421 421 dunder = candidate_dunder
422 422 return dunder
423 423
424 424
425 425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
426 426 """Evaluate AST node in provided context.
427 427
428 428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
429 429
430 430 Does not evaluate actions that always have side effects:
431 431
432 432 - class definitions (``class sth: ...``)
433 433 - function definitions (``def sth: ...``)
434 434 - variable assignments (``x = 1``)
435 435 - augmented assignments (``x += 1``)
436 436 - deletions (``del x``)
437 437
438 438 Does not evaluate operations which do not return values:
439 439
440 440 - assertions (``assert x``)
441 441 - pass (``pass``)
442 442 - imports (``import x``)
443 443 - control flow:
444 444
445 445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
446 446 - loops (``for`` and `while``)
447 447 - exception handling
448 448
449 449 The purpose of this function is to guard against unwanted side-effects;
450 450 it does not give guarantees on protection from malicious code execution.
451 451 """
452 452 policy = EVALUATION_POLICIES[context.evaluation]
453 453 if node is None:
454 454 return None
455 455 if isinstance(node, ast.Expression):
456 456 return eval_node(node.body, context)
457 457 if isinstance(node, ast.BinOp):
458 458 left = eval_node(node.left, context)
459 459 right = eval_node(node.right, context)
460 460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
461 461 if dunders:
462 462 if policy.can_operate(dunders, left, right):
463 463 return getattr(left, dunders[0])(right)
464 464 else:
465 465 raise GuardRejection(
466 466 f"Operation (`{dunders}`) for",
467 467 type(left),
468 468 f"not allowed in {context.evaluation} mode",
469 469 )
470 470 if isinstance(node, ast.Compare):
471 471 left = eval_node(node.left, context)
472 472 all_true = True
473 473 negate = False
474 474 for op, right in zip(node.ops, node.comparators):
475 475 right = eval_node(right, context)
476 476 dunder = None
477 477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
478 478 if not dunders:
479 479 if isinstance(op, ast.NotIn):
480 480 dunders = COMP_OP_DUNDERS[ast.In]
481 481 negate = True
482 482 if isinstance(op, ast.Is):
483 483 dunder = "is_"
484 484 if isinstance(op, ast.IsNot):
485 485 dunder = "is_"
486 486 negate = True
487 487 if not dunder and dunders:
488 488 dunder = dunders[0]
489 489 if dunder:
490 490 a, b = (right, left) if dunder == "__contains__" else (left, right)
491 491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
492 492 result = getattr(operator, dunder)(a, b)
493 493 if negate:
494 494 result = not result
495 495 if not result:
496 496 all_true = False
497 497 left = right
498 498 else:
499 499 raise GuardRejection(
500 500 f"Comparison (`{dunder}`) for",
501 501 type(left),
502 502 f"not allowed in {context.evaluation} mode",
503 503 )
504 504 else:
505 505 raise ValueError(
506 506 f"Comparison `{dunder}` not supported"
507 507 ) # pragma: no cover
508 508 return all_true
509 509 if isinstance(node, ast.Constant):
510 510 return node.value
511 511 if isinstance(node, ast.Index):
512 512 # deprecated since Python 3.9
513 513 return eval_node(node.value, context) # pragma: no cover
514 514 if isinstance(node, ast.Tuple):
515 515 return tuple(eval_node(e, context) for e in node.elts)
516 516 if isinstance(node, ast.List):
517 517 return [eval_node(e, context) for e in node.elts]
518 518 if isinstance(node, ast.Set):
519 519 return {eval_node(e, context) for e in node.elts}
520 520 if isinstance(node, ast.Dict):
521 521 return dict(
522 522 zip(
523 523 [eval_node(k, context) for k in node.keys],
524 524 [eval_node(v, context) for v in node.values],
525 525 )
526 526 )
527 527 if isinstance(node, ast.Slice):
528 528 return slice(
529 529 eval_node(node.lower, context),
530 530 eval_node(node.upper, context),
531 531 eval_node(node.step, context),
532 532 )
533 533 if isinstance(node, ast.ExtSlice):
534 534 # deprecated since Python 3.9
535 535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
536 536 if isinstance(node, ast.UnaryOp):
537 537 value = eval_node(node.operand, context)
538 538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
539 539 if dunders:
540 540 if policy.can_operate(dunders, value):
541 541 return getattr(value, dunders[0])()
542 542 else:
543 543 raise GuardRejection(
544 544 f"Operation (`{dunders}`) for",
545 545 type(value),
546 546 f"not allowed in {context.evaluation} mode",
547 547 )
548 548 if isinstance(node, ast.Subscript):
549 549 value = eval_node(node.value, context)
550 550 slice_ = eval_node(node.slice, context)
551 551 if policy.can_get_item(value, slice_):
552 552 return value[slice_]
553 553 raise GuardRejection(
554 554 "Subscript access (`__getitem__`) for",
555 555 type(value), # not joined to avoid calling `repr`
556 556 f" not allowed in {context.evaluation} mode",
557 557 )
558 558 if isinstance(node, ast.Name):
559 559 if policy.allow_locals_access and node.id in context.locals:
560 560 return context.locals[node.id]
561 561 if policy.allow_globals_access and node.id in context.globals:
562 562 return context.globals[node.id]
563 563 if policy.allow_builtins_access and hasattr(builtins, node.id):
564 564 # note: do not use __builtins__, it is implementation detail of cPython
565 565 return getattr(builtins, node.id)
566 566 if not policy.allow_globals_access and not policy.allow_locals_access:
567 567 raise GuardRejection(
568 568 f"Namespace access not allowed in {context.evaluation} mode"
569 569 )
570 570 else:
571 571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
572 572 if isinstance(node, ast.Attribute):
573 573 value = eval_node(node.value, context)
574 574 if policy.can_get_attr(value, node.attr):
575 575 return getattr(value, node.attr)
576 576 raise GuardRejection(
577 577 "Attribute access (`__getattr__`) for",
578 578 type(value), # not joined to avoid calling `repr`
579 579 f"not allowed in {context.evaluation} mode",
580 580 )
581 581 if isinstance(node, ast.IfExp):
582 582 test = eval_node(node.test, context)
583 583 if test:
584 584 return eval_node(node.body, context)
585 585 else:
586 586 return eval_node(node.orelse, context)
587 587 if isinstance(node, ast.Call):
588 588 func = eval_node(node.func, context)
589 589 if policy.can_call(func) and not node.keywords:
590 590 args = [eval_node(arg, context) for arg in node.args]
591 591 return func(*args)
592 592 raise GuardRejection(
593 593 "Call for",
594 594 func, # not joined to avoid calling `repr`
595 595 f"not allowed in {context.evaluation} mode",
596 596 )
597 597 raise ValueError("Unhandled node", ast.dump(node))
598 598
599 599
600 600 SUPPORTED_EXTERNAL_GETITEM = {
601 601 ("pandas", "core", "indexing", "_iLocIndexer"),
602 602 ("pandas", "core", "indexing", "_LocIndexer"),
603 603 ("pandas", "DataFrame"),
604 604 ("pandas", "Series"),
605 605 ("numpy", "ndarray"),
606 606 ("numpy", "void"),
607 607 }
608 608
609 609
610 610 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
611 611 dict,
612 612 str, # type: ignore[arg-type]
613 613 bytes, # type: ignore[arg-type]
614 614 list,
615 615 tuple,
616 616 collections.defaultdict,
617 617 collections.deque,
618 618 collections.OrderedDict,
619 619 collections.ChainMap,
620 620 collections.UserDict,
621 621 collections.UserList,
622 622 collections.UserString, # type: ignore[arg-type]
623 623 _DummyNamedTuple,
624 624 _IdentitySubscript,
625 625 }
626 626
627 627
628 628 def _list_methods(cls, source=None):
629 629 """For use on immutable objects or with methods returning a copy"""
630 630 return [getattr(cls, k) for k in (source if source else dir(cls))]
631 631
632 632
633 633 dict_non_mutating_methods = ("copy", "keys", "values", "items")
634 634 list_non_mutating_methods = ("copy", "index", "count")
635 635 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
636 636
637 637
638 638 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
639 639 method_descriptor: Any = type(list.copy)
640 module = type(builtins)
640 641
641 642 NUMERICS = {int, float, complex}
642 643
643 644 ALLOWED_CALLS = {
644 645 bytes,
645 646 *_list_methods(bytes),
646 647 dict,
647 648 *_list_methods(dict, dict_non_mutating_methods),
648 649 dict_keys.isdisjoint,
649 650 list,
650 651 *_list_methods(list, list_non_mutating_methods),
651 652 set,
652 653 *_list_methods(set, set_non_mutating_methods),
653 654 frozenset,
654 655 *_list_methods(frozenset),
655 656 range,
656 657 str,
657 658 *_list_methods(str),
658 659 tuple,
659 660 *_list_methods(tuple),
660 661 *NUMERICS,
661 662 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
662 663 collections.deque,
663 664 *_list_methods(collections.deque, list_non_mutating_methods),
664 665 collections.defaultdict,
665 666 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
666 667 collections.OrderedDict,
667 668 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
668 669 collections.UserDict,
669 670 *_list_methods(collections.UserDict, dict_non_mutating_methods),
670 671 collections.UserList,
671 672 *_list_methods(collections.UserList, list_non_mutating_methods),
672 673 collections.UserString,
673 674 *_list_methods(collections.UserString, dir(str)),
674 675 collections.Counter,
675 676 *_list_methods(collections.Counter, dict_non_mutating_methods),
676 677 collections.Counter.elements,
677 678 collections.Counter.most_common,
678 679 }
679 680
680 681 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
681 682 *BUILTIN_GETITEM,
682 683 set,
683 684 frozenset,
684 685 object,
685 686 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
686 687 *NUMERICS,
687 688 dict_keys,
688 689 method_descriptor,
690 module,
689 691 }
690 692
691 693
692 694 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
693 695
694 696 EVALUATION_POLICIES = {
695 697 "minimal": EvaluationPolicy(
696 698 allow_builtins_access=True,
697 699 allow_locals_access=False,
698 700 allow_globals_access=False,
699 701 allow_item_access=False,
700 702 allow_attr_access=False,
701 703 allowed_calls=set(),
702 704 allow_any_calls=False,
703 705 allow_all_operations=False,
704 706 ),
705 707 "limited": SelectivePolicy(
706 708 allowed_getitem=BUILTIN_GETITEM,
707 709 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
708 710 allowed_getattr=BUILTIN_GETATTR,
709 711 allowed_getattr_external={
710 712 # pandas Series/Frame implements custom `__getattr__`
711 713 ("pandas", "DataFrame"),
712 714 ("pandas", "Series"),
713 715 },
714 716 allowed_operations=BUILTIN_OPERATIONS,
715 717 allow_builtins_access=True,
716 718 allow_locals_access=True,
717 719 allow_globals_access=True,
718 720 allowed_calls=ALLOWED_CALLS,
719 721 ),
720 722 "unsafe": EvaluationPolicy(
721 723 allow_builtins_access=True,
722 724 allow_locals_access=True,
723 725 allow_globals_access=True,
724 726 allow_attr_access=True,
725 727 allow_item_access=True,
726 728 allow_any_calls=True,
727 729 allow_all_operations=True,
728 730 ),
729 731 }
730 732
731 733
732 734 __all__ = [
733 735 "guarded_eval",
734 736 "eval_node",
735 737 "GuardRejection",
736 738 "EvaluationContext",
737 739 "_unbind_method",
738 740 ]
@@ -1,570 +1,582 b''
1 1 from contextlib import contextmanager
2 2 from typing import NamedTuple
3 3 from functools import partial
4 4 from IPython.core.guarded_eval import (
5 5 EvaluationContext,
6 6 GuardRejection,
7 7 guarded_eval,
8 8 _unbind_method,
9 9 )
10 10 from IPython.testing import decorators as dec
11 11 import pytest
12 12
13 13
14 14 def create_context(evaluation: str, **kwargs):
15 15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16 16
17 17
18 18 forbidden = partial(create_context, "forbidden")
19 19 minimal = partial(create_context, "minimal")
20 20 limited = partial(create_context, "limited")
21 21 unsafe = partial(create_context, "unsafe")
22 22 dangerous = partial(create_context, "dangerous")
23 23
24 24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26 26
27 27
28 28 @contextmanager
29 29 def module_not_installed(module: str):
30 30 import sys
31 31
32 32 try:
33 33 to_restore = sys.modules[module]
34 34 del sys.modules[module]
35 35 except KeyError:
36 36 to_restore = None
37 37 try:
38 38 yield
39 39 finally:
40 40 sys.modules[module] = to_restore
41 41
42 42
43 43 def test_external_not_installed():
44 44 """
45 45 Because attribute check requires checking if object is not of allowed
46 46 external type, this tests logic for absence of external module.
47 47 """
48 48
49 49 class Custom:
50 50 def __init__(self):
51 51 self.test = 1
52 52
53 53 def __getattr__(self, key):
54 54 return key
55 55
56 56 with module_not_installed("pandas"):
57 57 context = limited(x=Custom())
58 58 with pytest.raises(GuardRejection):
59 59 guarded_eval("x.test", context)
60 60
61 61
62 62 @dec.skip_without("pandas")
63 63 def test_external_changed_api(monkeypatch):
64 64 """Check that the execution rejects if external API changed paths"""
65 65 import pandas as pd
66 66
67 67 series = pd.Series([1], index=["a"])
68 68
69 69 with monkeypatch.context() as m:
70 70 m.delattr(pd, "Series")
71 71 context = limited(data=series)
72 72 with pytest.raises(GuardRejection):
73 73 guarded_eval("data.iloc[0]", context)
74 74
75 75
76 76 @dec.skip_without("pandas")
77 77 def test_pandas_series_iloc():
78 78 import pandas as pd
79 79
80 80 series = pd.Series([1], index=["a"])
81 81 context = limited(data=series)
82 82 assert guarded_eval("data.iloc[0]", context) == 1
83 83
84 84
85 85 def test_rejects_custom_properties():
86 86 class BadProperty:
87 87 @property
88 88 def iloc(self):
89 89 return [None]
90 90
91 91 series = BadProperty()
92 92 context = limited(data=series)
93 93
94 94 with pytest.raises(GuardRejection):
95 95 guarded_eval("data.iloc[0]", context)
96 96
97 97
98 98 @dec.skip_without("pandas")
99 99 def test_accepts_non_overriden_properties():
100 100 import pandas as pd
101 101
102 102 class GoodProperty(pd.Series):
103 103 pass
104 104
105 105 series = GoodProperty([1], index=["a"])
106 106 context = limited(data=series)
107 107
108 108 assert guarded_eval("data.iloc[0]", context) == 1
109 109
110 110
111 111 @dec.skip_without("pandas")
112 112 def test_pandas_series():
113 113 import pandas as pd
114 114
115 115 context = limited(data=pd.Series([1], index=["a"]))
116 116 assert guarded_eval('data["a"]', context) == 1
117 117 with pytest.raises(KeyError):
118 118 guarded_eval('data["c"]', context)
119 119
120 120
121 121 @dec.skip_without("pandas")
122 122 def test_pandas_bad_series():
123 123 import pandas as pd
124 124
125 125 class BadItemSeries(pd.Series):
126 126 def __getitem__(self, key):
127 127 return "CUSTOM_ITEM"
128 128
129 129 class BadAttrSeries(pd.Series):
130 130 def __getattr__(self, key):
131 131 return "CUSTOM_ATTR"
132 132
133 133 bad_series = BadItemSeries([1], index=["a"])
134 134 context = limited(data=bad_series)
135 135
136 136 with pytest.raises(GuardRejection):
137 137 guarded_eval('data["a"]', context)
138 138 with pytest.raises(GuardRejection):
139 139 guarded_eval('data["c"]', context)
140 140
141 141 # note: here result is a bit unexpected because
142 142 # pandas `__getattr__` calls `__getitem__`;
143 143 # FIXME - special case to handle it?
144 144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145 145
146 146 context = unsafe(data=bad_series)
147 147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148 148
149 149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 150 context = limited(data=bad_attr_series)
151 151 assert guarded_eval('data["a"]', context) == 1
152 152 with pytest.raises(GuardRejection):
153 153 guarded_eval("data.a", context)
154 154
155 155
156 156 @dec.skip_without("pandas")
157 157 def test_pandas_dataframe_loc():
158 158 import pandas as pd
159 159 from pandas.testing import assert_series_equal
160 160
161 161 data = pd.DataFrame([{"a": 1}])
162 162 context = limited(data=data)
163 163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164 164
165 165
166 166 def test_named_tuple():
167 167 class GoodNamedTuple(NamedTuple):
168 168 a: str
169 169 pass
170 170
171 171 class BadNamedTuple(NamedTuple):
172 172 a: str
173 173
174 174 def __getitem__(self, key):
175 175 return None
176 176
177 177 good = GoodNamedTuple(a="x")
178 178 bad = BadNamedTuple(a="x")
179 179
180 180 context = limited(data=good)
181 181 assert guarded_eval("data[0]", context) == "x"
182 182
183 183 context = limited(data=bad)
184 184 with pytest.raises(GuardRejection):
185 185 guarded_eval("data[0]", context)
186 186
187 187
188 188 def test_dict():
189 189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 190 assert guarded_eval('data["a"]', context) == 1
191 191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 192 assert guarded_eval('data["b"]["x"]', context) == 2
193 193 assert guarded_eval('data["x", "y"]', context) == 3
194 194
195 195 assert guarded_eval("data.keys", context)
196 196
197 197
198 198 def test_set():
199 199 context = limited(data={"a", "b"})
200 200 assert guarded_eval("data.difference", context)
201 201
202 202
203 203 def test_list():
204 204 context = limited(data=[1, 2, 3])
205 205 assert guarded_eval("data[1]", context) == 2
206 206 assert guarded_eval("data.copy", context)
207 207
208 208
209 209 def test_dict_literal():
210 210 context = limited()
211 211 assert guarded_eval("{}", context) == {}
212 212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213 213
214 214
215 215 def test_list_literal():
216 216 context = limited()
217 217 assert guarded_eval("[]", context) == []
218 218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219 219
220 220
221 221 def test_set_literal():
222 222 context = limited()
223 223 assert guarded_eval("set()", context) == set()
224 224 assert guarded_eval('{"a"}', context) == {"a"}
225 225
226 226
227 227 def test_evaluates_if_expression():
228 228 context = limited()
229 229 assert guarded_eval("2 if True else 3", context) == 2
230 230 assert guarded_eval("4 if False else 5", context) == 5
231 231
232 232
233 233 def test_object():
234 234 obj = object()
235 235 context = limited(obj=obj)
236 236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237 237
238 238
239 239 @pytest.mark.parametrize(
240 240 "code,expected",
241 241 [
242 242 ["int.numerator", int.numerator],
243 243 ["float.is_integer", float.is_integer],
244 244 ["complex.real", complex.real],
245 245 ],
246 246 )
247 247 def test_number_attributes(code, expected):
248 248 assert guarded_eval(code, limited()) == expected
249 249
250 250
251 251 def test_method_descriptor():
252 252 context = limited()
253 253 assert guarded_eval("list.copy.__name__", context) == "copy"
254 254
255 255
256 256 @pytest.mark.parametrize(
257 257 "data,good,bad,expected",
258 258 [
259 259 [[1, 2, 3], "data.index(2)", "data.append(4)", 1],
260 260 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True],
261 261 ],
262 262 )
263 263 def test_evaluates_calls(data, good, bad, expected):
264 264 context = limited(data=data)
265 265 assert guarded_eval(good, context) == expected
266 266
267 267 with pytest.raises(GuardRejection):
268 268 guarded_eval(bad, context)
269 269
270 270
271 271 @pytest.mark.parametrize(
272 272 "code,expected",
273 273 [
274 274 ["(1\n+\n1)", 2],
275 275 ["list(range(10))[-1:]", [9]],
276 276 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
277 277 ],
278 278 )
279 279 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
280 280 def test_evaluates_complex_cases(code, expected, context):
281 281 assert guarded_eval(code, context()) == expected
282 282
283 283
284 284 @pytest.mark.parametrize(
285 285 "code,expected",
286 286 [
287 287 ["1", 1],
288 288 ["1.0", 1.0],
289 289 ["0xdeedbeef", 0xDEEDBEEF],
290 290 ["True", True],
291 291 ["None", None],
292 292 ["{}", {}],
293 293 ["[]", []],
294 294 ],
295 295 )
296 296 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
297 297 def test_evaluates_literals(code, expected, context):
298 298 assert guarded_eval(code, context()) == expected
299 299
300 300
301 301 @pytest.mark.parametrize(
302 302 "code,expected",
303 303 [
304 304 ["-5", -5],
305 305 ["+5", +5],
306 306 ["~5", -6],
307 307 ],
308 308 )
309 309 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
310 310 def test_evaluates_unary_operations(code, expected, context):
311 311 assert guarded_eval(code, context()) == expected
312 312
313 313
314 314 @pytest.mark.parametrize(
315 315 "code,expected",
316 316 [
317 317 ["1 + 1", 2],
318 318 ["3 - 1", 2],
319 319 ["2 * 3", 6],
320 320 ["5 // 2", 2],
321 321 ["5 / 2", 2.5],
322 322 ["5**2", 25],
323 323 ["2 >> 1", 1],
324 324 ["2 << 1", 4],
325 325 ["1 | 2", 3],
326 326 ["1 & 1", 1],
327 327 ["1 & 2", 0],
328 328 ],
329 329 )
330 330 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
331 331 def test_evaluates_binary_operations(code, expected, context):
332 332 assert guarded_eval(code, context()) == expected
333 333
334 334
335 335 @pytest.mark.parametrize(
336 336 "code,expected",
337 337 [
338 338 ["2 > 1", True],
339 339 ["2 < 1", False],
340 340 ["2 <= 1", False],
341 341 ["2 <= 2", True],
342 342 ["1 >= 2", False],
343 343 ["2 >= 2", True],
344 344 ["2 == 2", True],
345 345 ["1 == 2", False],
346 346 ["1 != 2", True],
347 347 ["1 != 1", False],
348 348 ["1 < 4 < 3", False],
349 349 ["(1 < 4) < 3", True],
350 350 ["4 > 3 > 2 > 1", True],
351 351 ["4 > 3 > 2 > 9", False],
352 352 ["1 < 2 < 3 < 4", True],
353 353 ["9 < 2 < 3 < 4", False],
354 354 ["1 < 2 > 1 > 0 > -1 < 1", True],
355 355 ["1 in [1] in [[1]]", True],
356 356 ["1 in [1] in [[2]]", False],
357 357 ["1 in [1]", True],
358 358 ["0 in [1]", False],
359 359 ["1 not in [1]", False],
360 360 ["0 not in [1]", True],
361 361 ["True is True", True],
362 362 ["False is False", True],
363 363 ["True is False", False],
364 364 ["True is not True", False],
365 365 ["False is not True", True],
366 366 ],
367 367 )
368 368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
369 369 def test_evaluates_comparisons(code, expected, context):
370 370 assert guarded_eval(code, context()) == expected
371 371
372 372
373 373 def test_guards_comparisons():
374 374 class GoodEq(int):
375 375 pass
376 376
377 377 class BadEq(int):
378 378 def __eq__(self, other):
379 379 assert False
380 380
381 381 context = limited(bad=BadEq(1), good=GoodEq(1))
382 382
383 383 with pytest.raises(GuardRejection):
384 384 guarded_eval("bad == 1", context)
385 385
386 386 with pytest.raises(GuardRejection):
387 387 guarded_eval("bad != 1", context)
388 388
389 389 with pytest.raises(GuardRejection):
390 390 guarded_eval("1 == bad", context)
391 391
392 392 with pytest.raises(GuardRejection):
393 393 guarded_eval("1 != bad", context)
394 394
395 395 assert guarded_eval("good == 1", context) is True
396 396 assert guarded_eval("good != 1", context) is False
397 397 assert guarded_eval("1 == good", context) is True
398 398 assert guarded_eval("1 != good", context) is False
399 399
400 400
401 401 def test_guards_unary_operations():
402 402 class GoodOp(int):
403 403 pass
404 404
405 405 class BadOpInv(int):
406 406 def __inv__(self, other):
407 407 assert False
408 408
409 409 class BadOpInverse(int):
410 410 def __inv__(self, other):
411 411 assert False
412 412
413 413 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
414 414
415 415 with pytest.raises(GuardRejection):
416 416 guarded_eval("~bad1", context)
417 417
418 418 with pytest.raises(GuardRejection):
419 419 guarded_eval("~bad2", context)
420 420
421 421
422 422 def test_guards_binary_operations():
423 423 class GoodOp(int):
424 424 pass
425 425
426 426 class BadOp(int):
427 427 def __add__(self, other):
428 428 assert False
429 429
430 430 context = limited(good=GoodOp(1), bad=BadOp(1))
431 431
432 432 with pytest.raises(GuardRejection):
433 433 guarded_eval("1 + bad", context)
434 434
435 435 with pytest.raises(GuardRejection):
436 436 guarded_eval("bad + 1", context)
437 437
438 438 assert guarded_eval("good + 1", context) == 2
439 439 assert guarded_eval("1 + good", context) == 2
440 440
441 441
442 442 def test_guards_attributes():
443 443 class GoodAttr(float):
444 444 pass
445 445
446 446 class BadAttr1(float):
447 447 def __getattr__(self, key):
448 448 assert False
449 449
450 450 class BadAttr2(float):
451 451 def __getattribute__(self, key):
452 452 assert False
453 453
454 454 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
455 455
456 456 with pytest.raises(GuardRejection):
457 457 guarded_eval("bad1.as_integer_ratio", context)
458 458
459 459 with pytest.raises(GuardRejection):
460 460 guarded_eval("bad2.as_integer_ratio", context)
461 461
462 462 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
463 463
464 464
465 465 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
466 466 def test_access_builtins(context):
467 467 assert guarded_eval("round", context()) == round
468 468
469 469
470 470 def test_access_builtins_fails():
471 471 context = limited()
472 472 with pytest.raises(NameError):
473 473 guarded_eval("this_is_not_builtin", context)
474 474
475 475
476 476 def test_rejects_forbidden():
477 477 context = forbidden()
478 478 with pytest.raises(GuardRejection):
479 479 guarded_eval("1", context)
480 480
481 481
482 482 def test_guards_locals_and_globals():
483 483 context = EvaluationContext(
484 484 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
485 485 )
486 486
487 487 with pytest.raises(GuardRejection):
488 488 guarded_eval("local_a", context)
489 489
490 490 with pytest.raises(GuardRejection):
491 491 guarded_eval("global_b", context)
492 492
493 493
494 494 def test_access_locals_and_globals():
495 495 context = EvaluationContext(
496 496 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
497 497 )
498 498 assert guarded_eval("local_a", context) == "a"
499 499 assert guarded_eval("global_b", context) == "b"
500 500
501 501
502 502 @pytest.mark.parametrize(
503 503 "code",
504 504 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
505 505 )
506 506 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
507 507 def test_rejects_side_effect_syntax(code, context):
508 508 with pytest.raises(SyntaxError):
509 509 guarded_eval(code, context)
510 510
511 511
512 512 def test_subscript():
513 513 context = EvaluationContext(
514 514 locals={}, globals={}, evaluation="limited", in_subscript=True
515 515 )
516 516 empty_slice = slice(None, None, None)
517 517 assert guarded_eval("", context) == tuple()
518 518 assert guarded_eval(":", context) == empty_slice
519 519 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
520 520 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
521 521
522 522
523 523 def test_unbind_method():
524 524 class X(list):
525 525 def index(self, k):
526 526 return "CUSTOM"
527 527
528 528 x = X()
529 529 assert _unbind_method(x.index) is X.index
530 530 assert _unbind_method([].index) is list.index
531 531 assert _unbind_method(list.index) is None
532 532
533 533
534 534 def test_assumption_instance_attr_do_not_matter():
535 535 """This is semi-specified in Python documentation.
536 536
537 537 However, since the specification says 'not guaranted
538 538 to work' rather than 'is forbidden to work', future
539 539 versions could invalidate this assumptions. This test
540 540 is meant to catch such a change if it ever comes true.
541 541 """
542 542
543 543 class T:
544 544 def __getitem__(self, k):
545 545 return "a"
546 546
547 547 def __getattr__(self, k):
548 548 return "a"
549 549
550 550 def f(self):
551 551 return "b"
552 552
553 553 t = T()
554 554 t.__getitem__ = f
555 555 t.__getattr__ = f
556 556 assert t[1] == "a"
557 557 assert t[1] == "a"
558 558
559 559
560 560 def test_assumption_named_tuples_share_getitem():
561 561 """Check assumption on named tuples sharing __getitem__"""
562 562 from typing import NamedTuple
563 563
564 564 class A(NamedTuple):
565 565 pass
566 566
567 567 class B(NamedTuple):
568 568 pass
569 569
570 570 assert A.__getitem__ == B.__getitem__
571
572
573 @dec.skip_without("numpy")
574 def test_module_access():
575 import numpy
576
577 context = limited(numpy=numpy)
578 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
579
580 context = minimal(numpy=numpy)
581 with pytest.raises(GuardRejection):
582 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now