Show More
@@ -1,6 +1,7 b'' | |||||
1 | from typing import ( |
|
1 | from typing import ( | |
2 | Any, |
|
2 | Any, | |
3 | Callable, |
|
3 | Callable, | |
|
4 | Dict, | |||
4 | Set, |
|
5 | Set, | |
5 | Tuple, |
|
6 | Tuple, | |
6 | NamedTuple, |
|
7 | NamedTuple, | |
@@ -9,10 +10,11 b' from typing import (' | |||||
9 | Union, |
|
10 | Union, | |
10 | TYPE_CHECKING, |
|
11 | TYPE_CHECKING, | |
11 | ) |
|
12 | ) | |
|
13 | import ast | |||
12 | import builtins |
|
14 | import builtins | |
13 | import collections |
|
15 | import collections | |
|
16 | import operator | |||
14 | import sys |
|
17 | import sys | |
15 | import ast |
|
|||
16 | from functools import cached_property |
|
18 | from functools import cached_property | |
17 | from dataclasses import dataclass, field |
|
19 | from dataclasses import dataclass, field | |
18 |
|
20 | |||
@@ -84,6 +86,7 b' class EvaluationPolicy:' | |||||
84 | allow_item_access: bool = False |
|
86 | allow_item_access: bool = False | |
85 | allow_attr_access: bool = False |
|
87 | allow_attr_access: bool = False | |
86 | allow_builtins_access: bool = False |
|
88 | allow_builtins_access: bool = False | |
|
89 | allow_all_operations: bool = False | |||
87 | allow_any_calls: bool = False |
|
90 | allow_any_calls: bool = False | |
88 | allowed_calls: Set[Callable] = field(default_factory=set) |
|
91 | allowed_calls: Set[Callable] = field(default_factory=set) | |
89 |
|
92 | |||
@@ -93,6 +96,10 b' class EvaluationPolicy:' | |||||
93 | def can_get_attr(self, value, attr): |
|
96 | def can_get_attr(self, value, attr): | |
94 | return self.allow_attr_access |
|
97 | return self.allow_attr_access | |
95 |
|
98 | |||
|
99 | def can_operate(self, dunders: Tuple[str, ...], a, b=None): | |||
|
100 | if self.allow_all_operations: | |||
|
101 | return True | |||
|
102 | ||||
96 | def can_call(self, func): |
|
103 | def can_call(self, func): | |
97 | if self.allow_any_calls: |
|
104 | if self.allow_any_calls: | |
98 | return True |
|
105 | return True | |
@@ -160,9 +167,17 b' def _has_original_dunder(' | |||||
160 | class SelectivePolicy(EvaluationPolicy): |
|
167 | class SelectivePolicy(EvaluationPolicy): | |
161 | allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set) |
|
168 | allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set) | |
162 | allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set) |
|
169 | allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set) | |
|
170 | ||||
163 | allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set) |
|
171 | allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set) | |
164 | allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set) |
|
172 | allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set) | |
165 |
|
173 | |||
|
174 | allowed_operations: Set = field(default_factory=set) | |||
|
175 | allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set) | |||
|
176 | ||||
|
177 | _operation_methods_cache: Dict[str, Set[Callable]] = field( | |||
|
178 | default_factory=dict, init=False | |||
|
179 | ) | |||
|
180 | ||||
166 | def can_get_attr(self, value, attr): |
|
181 | def can_get_attr(self, value, attr): | |
167 | has_original_attribute = _has_original_dunder( |
|
182 | has_original_attribute = _has_original_dunder( | |
168 | value, |
|
183 | value, | |
@@ -199,6 +214,27 b' class SelectivePolicy(EvaluationPolicy):' | |||||
199 | method_name="__getitem__", |
|
214 | method_name="__getitem__", | |
200 | ) |
|
215 | ) | |
201 |
|
216 | |||
|
217 | def can_operate(self, dunders: Tuple[str, ...], a, b=None): | |||
|
218 | return all( | |||
|
219 | [ | |||
|
220 | _has_original_dunder( | |||
|
221 | a, | |||
|
222 | allowed_types=self.allowed_operations, | |||
|
223 | allowed_methods=self._dunder_methods(dunder), | |||
|
224 | allowed_external=self.allowed_operations_external, | |||
|
225 | method_name=dunder, | |||
|
226 | ) | |||
|
227 | for dunder in dunders | |||
|
228 | ] | |||
|
229 | ) | |||
|
230 | ||||
|
231 | def _dunder_methods(self, dunder: str) -> Set[Callable]: | |||
|
232 | if dunder not in self._operation_methods_cache: | |||
|
233 | self._operation_methods_cache[dunder] = self._safe_get_methods( | |||
|
234 | self.allowed_operations, dunder | |||
|
235 | ) | |||
|
236 | return self._operation_methods_cache[dunder] | |||
|
237 | ||||
202 | @cached_property |
|
238 | @cached_property | |
203 | def _getitem_methods(self) -> Set[Callable]: |
|
239 | def _getitem_methods(self) -> Set[Callable]: | |
204 | return self._safe_get_methods(self.allowed_getitem, "__getitem__") |
|
240 | return self._safe_get_methods(self.allowed_getitem, "__getitem__") | |
@@ -291,6 +327,50 b' def guarded_eval(code: str, context: EvaluationContext):' | |||||
291 | return eval_node(expression, context) |
|
327 | return eval_node(expression, context) | |
292 |
|
328 | |||
293 |
|
329 | |||
|
330 | BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = { | |||
|
331 | ast.Add: ("__add__",), | |||
|
332 | ast.Sub: ("__sub__",), | |||
|
333 | ast.Mult: ("__mul__",), | |||
|
334 | ast.Div: ("__truediv__",), | |||
|
335 | ast.FloorDiv: ("__floordiv__",), | |||
|
336 | ast.Mod: ("__mod__",), | |||
|
337 | ast.Pow: ("__pow__",), | |||
|
338 | ast.LShift: ("__lshift__",), | |||
|
339 | ast.RShift: ("__rshift__",), | |||
|
340 | ast.BitOr: ("__or__",), | |||
|
341 | ast.BitXor: ("__xor__",), | |||
|
342 | ast.BitAnd: ("__and__",), | |||
|
343 | ast.MatMult: ("__matmul__",), | |||
|
344 | } | |||
|
345 | ||||
|
346 | COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = { | |||
|
347 | ast.Eq: ("__eq__",), | |||
|
348 | ast.NotEq: ("__ne__", "__eq__"), | |||
|
349 | ast.Lt: ("__lt__", "__gt__"), | |||
|
350 | ast.LtE: ("__le__", "__ge__"), | |||
|
351 | ast.Gt: ("__gt__", "__lt__"), | |||
|
352 | ast.GtE: ("__ge__", "__le__"), | |||
|
353 | ast.In: ("__contains__",), | |||
|
354 | # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially | |||
|
355 | } | |||
|
356 | ||||
|
357 | UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = { | |||
|
358 | ast.USub: ("__neg__",), | |||
|
359 | ast.UAdd: ("__pos__",), | |||
|
360 | # we have to check both __inv__ and __invert__! | |||
|
361 | ast.Invert: ("__invert__", "__inv__"), | |||
|
362 | ast.Not: ("__not__",), | |||
|
363 | } | |||
|
364 | ||||
|
365 | ||||
|
366 | def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]: | |||
|
367 | dunder = None | |||
|
368 | for op, candidate_dunder in dunders.items(): | |||
|
369 | if isinstance(node_op, op): | |||
|
370 | dunder = candidate_dunder | |||
|
371 | return dunder | |||
|
372 | ||||
|
373 | ||||
294 | def eval_node(node: Union[ast.AST, None], context: EvaluationContext): |
|
374 | def eval_node(node: Union[ast.AST, None], context: EvaluationContext): | |
295 | """Evaluate AST node in provided context. |
|
375 | """Evaluate AST node in provided context. | |
296 |
|
376 | |||
@@ -324,35 +404,55 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):' | |||||
324 | if isinstance(node, ast.Expression): |
|
404 | if isinstance(node, ast.Expression): | |
325 | return eval_node(node.body, context) |
|
405 | return eval_node(node.body, context) | |
326 | if isinstance(node, ast.BinOp): |
|
406 | if isinstance(node, ast.BinOp): | |
327 | # TODO: add guards |
|
|||
328 | left = eval_node(node.left, context) |
|
407 | left = eval_node(node.left, context) | |
329 | right = eval_node(node.right, context) |
|
408 | right = eval_node(node.right, context) | |
330 | if isinstance(node.op, ast.Add): |
|
409 | dunders = _find_dunder(node.op, BINARY_OP_DUNDERS) | |
331 | return left + right |
|
410 | if dunders: | |
332 | if isinstance(node.op, ast.Sub): |
|
411 | if policy.can_operate(dunders, left, right): | |
333 |
return left |
|
412 | return getattr(left, dunders[0])(right) | |
334 | if isinstance(node.op, ast.Mult): |
|
413 | else: | |
335 | return left * right |
|
414 | raise GuardRejection( | |
336 | if isinstance(node.op, ast.Div): |
|
415 | f"Operation (`{dunders}`) for", | |
337 | return left / right |
|
416 | type(left), | |
338 | if isinstance(node.op, ast.FloorDiv): |
|
417 | f"not allowed in {context.evaluation} mode", | |
339 |
|
|
418 | ) | |
340 |
|
|
419 | if isinstance(node, ast.Compare): | |
341 | return left % right |
|
420 | left = eval_node(node.left, context) | |
342 | if isinstance(node.op, ast.Pow): |
|
421 | all_true = True | |
343 | return left**right |
|
422 | negate = False | |
344 | if isinstance(node.op, ast.LShift): |
|
423 | for op, right in zip(node.ops, node.comparators): | |
345 | return left << right |
|
424 | right = eval_node(right, context) | |
346 | if isinstance(node.op, ast.RShift): |
|
425 | dunder = None | |
347 | return left >> right |
|
426 | dunders = _find_dunder(op, COMP_OP_DUNDERS) | |
348 | if isinstance(node.op, ast.BitOr): |
|
427 | if not dunders: | |
349 | return left | right |
|
428 | if isinstance(op, ast.NotIn): | |
350 | if isinstance(node.op, ast.BitXor): |
|
429 | dunders = COMP_OP_DUNDERS[ast.In] | |
351 | return left ^ right |
|
430 | negate = True | |
352 |
if isinstance( |
|
431 | if isinstance(op, ast.Is): | |
353 | return left & right |
|
432 | dunder = "is_" | |
354 |
if isinstance( |
|
433 | if isinstance(op, ast.IsNot): | |
355 | return left @ right |
|
434 | dunder = "is_" | |
|
435 | negate = True | |||
|
436 | if not dunder and dunders: | |||
|
437 | dunder = dunders[0] | |||
|
438 | if dunder: | |||
|
439 | a, b = (right, left) if dunder == "__contains__" else (left, right) | |||
|
440 | if dunder == "is_" or dunders and policy.can_operate(dunders, a, b): | |||
|
441 | result = getattr(operator, dunder)(a, b) | |||
|
442 | if negate: | |||
|
443 | result = not result | |||
|
444 | if not result: | |||
|
445 | all_true = False | |||
|
446 | left = right | |||
|
447 | else: | |||
|
448 | raise GuardRejection( | |||
|
449 | f"Comparison (`{dunder}`) for", | |||
|
450 | type(left), | |||
|
451 | f"not allowed in {context.evaluation} mode", | |||
|
452 | ) | |||
|
453 | else: | |||
|
454 | raise ValueError(f"Comparison `{dunder}` not supported") | |||
|
455 | return all_true | |||
356 | if isinstance(node, ast.Constant): |
|
456 | if isinstance(node, ast.Constant): | |
357 | return node.value |
|
457 | return node.value | |
358 | if isinstance(node, ast.Index): |
|
458 | if isinstance(node, ast.Index): | |
@@ -379,16 +479,17 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):' | |||||
379 | if isinstance(node, ast.ExtSlice): |
|
479 | if isinstance(node, ast.ExtSlice): | |
380 | return tuple([eval_node(dim, context) for dim in node.dims]) |
|
480 | return tuple([eval_node(dim, context) for dim in node.dims]) | |
381 | if isinstance(node, ast.UnaryOp): |
|
481 | if isinstance(node, ast.UnaryOp): | |
382 | # TODO: add guards |
|
|||
383 | value = eval_node(node.operand, context) |
|
482 | value = eval_node(node.operand, context) | |
384 | if isinstance(node.op, ast.USub): |
|
483 | dunders = _find_dunder(node.op, UNARY_OP_DUNDERS) | |
385 | return -value |
|
484 | if dunders: | |
386 | if isinstance(node.op, ast.UAdd): |
|
485 | if policy.can_operate(dunders, value): | |
387 |
return |
|
486 | return getattr(value, dunders[0])() | |
388 | if isinstance(node.op, ast.Invert): |
|
487 | else: | |
389 | return ~value |
|
488 | raise GuardRejection( | |
390 | if isinstance(node.op, ast.Not): |
|
489 | f"Operation (`{dunders}`) for", | |
391 |
|
|
490 | type(value), | |
|
491 | f"not allowed in {context.evaluation} mode", | |||
|
492 | ) | |||
392 | raise ValueError("Unhandled unary operation:", node.op) |
|
493 | raise ValueError("Unhandled unary operation:", node.op) | |
393 | if isinstance(node, ast.Subscript): |
|
494 | if isinstance(node, ast.Subscript): | |
394 | value = eval_node(node.value, context) |
|
495 | value = eval_node(node.value, context) | |
@@ -527,6 +628,9 b' BUILTIN_GETATTR: Set[MayHaveGetattr] = {' | |||||
527 | method_descriptor, |
|
628 | method_descriptor, | |
528 | } |
|
629 | } | |
529 |
|
630 | |||
|
631 | ||||
|
632 | BUILTIN_OPERATIONS = {int, float, complex, *BUILTIN_GETATTR} | |||
|
633 | ||||
530 | EVALUATION_POLICIES = { |
|
634 | EVALUATION_POLICIES = { | |
531 | "minimal": EvaluationPolicy( |
|
635 | "minimal": EvaluationPolicy( | |
532 | allow_builtins_access=True, |
|
636 | allow_builtins_access=True, | |
@@ -536,6 +640,7 b' EVALUATION_POLICIES = {' | |||||
536 | allow_attr_access=False, |
|
640 | allow_attr_access=False, | |
537 | allowed_calls=set(), |
|
641 | allowed_calls=set(), | |
538 | allow_any_calls=False, |
|
642 | allow_any_calls=False, | |
|
643 | allow_all_operations=False, | |||
539 | ), |
|
644 | ), | |
540 | "limited": SelectivePolicy( |
|
645 | "limited": SelectivePolicy( | |
541 | # TODO: |
|
646 | # TODO: | |
@@ -548,6 +653,7 b' EVALUATION_POLICIES = {' | |||||
548 | ("pandas", "DataFrame"), |
|
653 | ("pandas", "DataFrame"), | |
549 | ("pandas", "Series"), |
|
654 | ("pandas", "Series"), | |
550 | }, |
|
655 | }, | |
|
656 | allowed_operations=BUILTIN_OPERATIONS, | |||
551 | allow_builtins_access=True, |
|
657 | allow_builtins_access=True, | |
552 | allow_locals_access=True, |
|
658 | allow_locals_access=True, | |
553 | allow_globals_access=True, |
|
659 | allow_globals_access=True, | |
@@ -560,6 +666,7 b' EVALUATION_POLICIES = {' | |||||
560 | allow_attr_access=True, |
|
666 | allow_attr_access=True, | |
561 | allow_item_access=True, |
|
667 | allow_item_access=True, | |
562 | allow_any_calls=True, |
|
668 | allow_any_calls=True, | |
|
669 | allow_all_operations=True, | |||
563 | ), |
|
670 | ), | |
564 | } |
|
671 | } | |
565 |
|
672 |
@@ -199,6 +199,76 b' def test_literals(code, expected):' | |||||
199 | assert guarded_eval(code, context) == expected |
|
199 | assert guarded_eval(code, context) == expected | |
200 |
|
200 | |||
201 |
|
201 | |||
|
202 | @pytest.mark.parametrize( | |||
|
203 | "code,expected", | |||
|
204 | [ | |||
|
205 | ["-5", -5], | |||
|
206 | ["+5", +5], | |||
|
207 | ["~5", -6], | |||
|
208 | ], | |||
|
209 | ) | |||
|
210 | def test_unary_operations(code, expected): | |||
|
211 | context = limited() | |||
|
212 | assert guarded_eval(code, context) == expected | |||
|
213 | ||||
|
214 | ||||
|
215 | @pytest.mark.parametrize( | |||
|
216 | "code,expected", | |||
|
217 | [ | |||
|
218 | ["1 + 1", 2], | |||
|
219 | ["3 - 1", 2], | |||
|
220 | ["2 * 3", 6], | |||
|
221 | ["5 // 2", 2], | |||
|
222 | ["5 / 2", 2.5], | |||
|
223 | ["5**2", 25], | |||
|
224 | ["2 >> 1", 1], | |||
|
225 | ["2 << 1", 4], | |||
|
226 | ["1 | 2", 3], | |||
|
227 | ["1 & 1", 1], | |||
|
228 | ["1 & 2", 0], | |||
|
229 | ], | |||
|
230 | ) | |||
|
231 | def test_binary_operations(code, expected): | |||
|
232 | context = limited() | |||
|
233 | assert guarded_eval(code, context) == expected | |||
|
234 | ||||
|
235 | ||||
|
236 | @pytest.mark.parametrize( | |||
|
237 | "code,expected", | |||
|
238 | [ | |||
|
239 | ["2 > 1", True], | |||
|
240 | ["2 < 1", False], | |||
|
241 | ["2 <= 1", False], | |||
|
242 | ["2 <= 2", True], | |||
|
243 | ["1 >= 2", False], | |||
|
244 | ["2 >= 2", True], | |||
|
245 | ["2 == 2", True], | |||
|
246 | ["1 == 2", False], | |||
|
247 | ["1 != 2", True], | |||
|
248 | ["1 != 1", False], | |||
|
249 | ["1 < 4 < 3", False], | |||
|
250 | ["(1 < 4) < 3", True], | |||
|
251 | ["4 > 3 > 2 > 1", True], | |||
|
252 | ["4 > 3 > 2 > 9", False], | |||
|
253 | ["1 < 2 < 3 < 4", True], | |||
|
254 | ["9 < 2 < 3 < 4", False], | |||
|
255 | ["1 < 2 > 1 > 0 > -1 < 1", True], | |||
|
256 | ["1 in [1] in [[1]]", True], | |||
|
257 | ["1 in [1] in [[2]]", False], | |||
|
258 | ["1 in [1]", True], | |||
|
259 | ["0 in [1]", False], | |||
|
260 | ["1 not in [1]", False], | |||
|
261 | ["0 not in [1]", True], | |||
|
262 | ["True is True", True], | |||
|
263 | ["False is False", True], | |||
|
264 | ["True is False", False], | |||
|
265 | ], | |||
|
266 | ) | |||
|
267 | def test_comparisons(code, expected): | |||
|
268 | context = limited() | |||
|
269 | assert guarded_eval(code, context) == expected | |||
|
270 | ||||
|
271 | ||||
202 | def test_access_builtins(): |
|
272 | def test_access_builtins(): | |
203 | context = limited() |
|
273 | context = limited() | |
204 | assert guarded_eval("round", context) == round |
|
274 | assert guarded_eval("round", context) == round |
General Comments 0
You need to be logged in to leave comments.
Login now