##// END OF EJS Templates
Add guards for binary, unary operators and comparators
krassowski -
Show More
@@ -1,6 +1,7 b''
1 from typing import (
1 from typing import (
2 Any,
2 Any,
3 Callable,
3 Callable,
4 Dict,
4 Set,
5 Set,
5 Tuple,
6 Tuple,
6 NamedTuple,
7 NamedTuple,
@@ -9,10 +10,11 b' from typing import ('
9 Union,
10 Union,
10 TYPE_CHECKING,
11 TYPE_CHECKING,
11 )
12 )
13 import ast
12 import builtins
14 import builtins
13 import collections
15 import collections
16 import operator
14 import sys
17 import sys
15 import ast
16 from functools import cached_property
18 from functools import cached_property
17 from dataclasses import dataclass, field
19 from dataclasses import dataclass, field
18
20
@@ -84,6 +86,7 b' class EvaluationPolicy:'
84 allow_item_access: bool = False
86 allow_item_access: bool = False
85 allow_attr_access: bool = False
87 allow_attr_access: bool = False
86 allow_builtins_access: bool = False
88 allow_builtins_access: bool = False
89 allow_all_operations: bool = False
87 allow_any_calls: bool = False
90 allow_any_calls: bool = False
88 allowed_calls: Set[Callable] = field(default_factory=set)
91 allowed_calls: Set[Callable] = field(default_factory=set)
89
92
@@ -93,6 +96,10 b' class EvaluationPolicy:'
93 def can_get_attr(self, value, attr):
96 def can_get_attr(self, value, attr):
94 return self.allow_attr_access
97 return self.allow_attr_access
95
98
99 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
100 if self.allow_all_operations:
101 return True
102
96 def can_call(self, func):
103 def can_call(self, func):
97 if self.allow_any_calls:
104 if self.allow_any_calls:
98 return True
105 return True
@@ -160,9 +167,17 b' def _has_original_dunder('
160 class SelectivePolicy(EvaluationPolicy):
167 class SelectivePolicy(EvaluationPolicy):
161 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
168 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
162 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
169 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
170
163 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
171 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
164 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
172 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
165
173
174 allowed_operations: Set = field(default_factory=set)
175 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
176
177 _operation_methods_cache: Dict[str, Set[Callable]] = field(
178 default_factory=dict, init=False
179 )
180
166 def can_get_attr(self, value, attr):
181 def can_get_attr(self, value, attr):
167 has_original_attribute = _has_original_dunder(
182 has_original_attribute = _has_original_dunder(
168 value,
183 value,
@@ -199,6 +214,27 b' class SelectivePolicy(EvaluationPolicy):'
199 method_name="__getitem__",
214 method_name="__getitem__",
200 )
215 )
201
216
217 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
218 return all(
219 [
220 _has_original_dunder(
221 a,
222 allowed_types=self.allowed_operations,
223 allowed_methods=self._dunder_methods(dunder),
224 allowed_external=self.allowed_operations_external,
225 method_name=dunder,
226 )
227 for dunder in dunders
228 ]
229 )
230
231 def _dunder_methods(self, dunder: str) -> Set[Callable]:
232 if dunder not in self._operation_methods_cache:
233 self._operation_methods_cache[dunder] = self._safe_get_methods(
234 self.allowed_operations, dunder
235 )
236 return self._operation_methods_cache[dunder]
237
202 @cached_property
238 @cached_property
203 def _getitem_methods(self) -> Set[Callable]:
239 def _getitem_methods(self) -> Set[Callable]:
204 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
240 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
@@ -291,6 +327,50 b' def guarded_eval(code: str, context: EvaluationContext):'
291 return eval_node(expression, context)
327 return eval_node(expression, context)
292
328
293
329
330 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
331 ast.Add: ("__add__",),
332 ast.Sub: ("__sub__",),
333 ast.Mult: ("__mul__",),
334 ast.Div: ("__truediv__",),
335 ast.FloorDiv: ("__floordiv__",),
336 ast.Mod: ("__mod__",),
337 ast.Pow: ("__pow__",),
338 ast.LShift: ("__lshift__",),
339 ast.RShift: ("__rshift__",),
340 ast.BitOr: ("__or__",),
341 ast.BitXor: ("__xor__",),
342 ast.BitAnd: ("__and__",),
343 ast.MatMult: ("__matmul__",),
344 }
345
346 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
347 ast.Eq: ("__eq__",),
348 ast.NotEq: ("__ne__", "__eq__"),
349 ast.Lt: ("__lt__", "__gt__"),
350 ast.LtE: ("__le__", "__ge__"),
351 ast.Gt: ("__gt__", "__lt__"),
352 ast.GtE: ("__ge__", "__le__"),
353 ast.In: ("__contains__",),
354 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
355 }
356
357 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
358 ast.USub: ("__neg__",),
359 ast.UAdd: ("__pos__",),
360 # we have to check both __inv__ and __invert__!
361 ast.Invert: ("__invert__", "__inv__"),
362 ast.Not: ("__not__",),
363 }
364
365
366 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
367 dunder = None
368 for op, candidate_dunder in dunders.items():
369 if isinstance(node_op, op):
370 dunder = candidate_dunder
371 return dunder
372
373
294 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
374 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
295 """Evaluate AST node in provided context.
375 """Evaluate AST node in provided context.
296
376
@@ -324,35 +404,55 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
324 if isinstance(node, ast.Expression):
404 if isinstance(node, ast.Expression):
325 return eval_node(node.body, context)
405 return eval_node(node.body, context)
326 if isinstance(node, ast.BinOp):
406 if isinstance(node, ast.BinOp):
327 # TODO: add guards
328 left = eval_node(node.left, context)
407 left = eval_node(node.left, context)
329 right = eval_node(node.right, context)
408 right = eval_node(node.right, context)
330 if isinstance(node.op, ast.Add):
409 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
331 return left + right
410 if dunders:
332 if isinstance(node.op, ast.Sub):
411 if policy.can_operate(dunders, left, right):
333 return left - right
412 return getattr(left, dunders[0])(right)
334 if isinstance(node.op, ast.Mult):
413 else:
335 return left * right
414 raise GuardRejection(
336 if isinstance(node.op, ast.Div):
415 f"Operation (`{dunders}`) for",
337 return left / right
416 type(left),
338 if isinstance(node.op, ast.FloorDiv):
417 f"not allowed in {context.evaluation} mode",
339 return left // right
418 )
340 if isinstance(node.op, ast.Mod):
419 if isinstance(node, ast.Compare):
341 return left % right
420 left = eval_node(node.left, context)
342 if isinstance(node.op, ast.Pow):
421 all_true = True
343 return left**right
422 negate = False
344 if isinstance(node.op, ast.LShift):
423 for op, right in zip(node.ops, node.comparators):
345 return left << right
424 right = eval_node(right, context)
346 if isinstance(node.op, ast.RShift):
425 dunder = None
347 return left >> right
426 dunders = _find_dunder(op, COMP_OP_DUNDERS)
348 if isinstance(node.op, ast.BitOr):
427 if not dunders:
349 return left | right
428 if isinstance(op, ast.NotIn):
350 if isinstance(node.op, ast.BitXor):
429 dunders = COMP_OP_DUNDERS[ast.In]
351 return left ^ right
430 negate = True
352 if isinstance(node.op, ast.BitAnd):
431 if isinstance(op, ast.Is):
353 return left & right
432 dunder = "is_"
354 if isinstance(node.op, ast.MatMult):
433 if isinstance(op, ast.IsNot):
355 return left @ right
434 dunder = "is_"
435 negate = True
436 if not dunder and dunders:
437 dunder = dunders[0]
438 if dunder:
439 a, b = (right, left) if dunder == "__contains__" else (left, right)
440 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
441 result = getattr(operator, dunder)(a, b)
442 if negate:
443 result = not result
444 if not result:
445 all_true = False
446 left = right
447 else:
448 raise GuardRejection(
449 f"Comparison (`{dunder}`) for",
450 type(left),
451 f"not allowed in {context.evaluation} mode",
452 )
453 else:
454 raise ValueError(f"Comparison `{dunder}` not supported")
455 return all_true
356 if isinstance(node, ast.Constant):
456 if isinstance(node, ast.Constant):
357 return node.value
457 return node.value
358 if isinstance(node, ast.Index):
458 if isinstance(node, ast.Index):
@@ -379,16 +479,17 b' def eval_node(node: Union[ast.AST, None], context: EvaluationContext):'
379 if isinstance(node, ast.ExtSlice):
479 if isinstance(node, ast.ExtSlice):
380 return tuple([eval_node(dim, context) for dim in node.dims])
480 return tuple([eval_node(dim, context) for dim in node.dims])
381 if isinstance(node, ast.UnaryOp):
481 if isinstance(node, ast.UnaryOp):
382 # TODO: add guards
383 value = eval_node(node.operand, context)
482 value = eval_node(node.operand, context)
384 if isinstance(node.op, ast.USub):
483 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
385 return -value
484 if dunders:
386 if isinstance(node.op, ast.UAdd):
485 if policy.can_operate(dunders, value):
387 return +value
486 return getattr(value, dunders[0])()
388 if isinstance(node.op, ast.Invert):
487 else:
389 return ~value
488 raise GuardRejection(
390 if isinstance(node.op, ast.Not):
489 f"Operation (`{dunders}`) for",
391 return not value
490 type(value),
491 f"not allowed in {context.evaluation} mode",
492 )
392 raise ValueError("Unhandled unary operation:", node.op)
493 raise ValueError("Unhandled unary operation:", node.op)
393 if isinstance(node, ast.Subscript):
494 if isinstance(node, ast.Subscript):
394 value = eval_node(node.value, context)
495 value = eval_node(node.value, context)
@@ -527,6 +628,9 b' BUILTIN_GETATTR: Set[MayHaveGetattr] = {'
527 method_descriptor,
628 method_descriptor,
528 }
629 }
529
630
631
632 BUILTIN_OPERATIONS = {int, float, complex, *BUILTIN_GETATTR}
633
530 EVALUATION_POLICIES = {
634 EVALUATION_POLICIES = {
531 "minimal": EvaluationPolicy(
635 "minimal": EvaluationPolicy(
532 allow_builtins_access=True,
636 allow_builtins_access=True,
@@ -536,6 +640,7 b' EVALUATION_POLICIES = {'
536 allow_attr_access=False,
640 allow_attr_access=False,
537 allowed_calls=set(),
641 allowed_calls=set(),
538 allow_any_calls=False,
642 allow_any_calls=False,
643 allow_all_operations=False,
539 ),
644 ),
540 "limited": SelectivePolicy(
645 "limited": SelectivePolicy(
541 # TODO:
646 # TODO:
@@ -548,6 +653,7 b' EVALUATION_POLICIES = {'
548 ("pandas", "DataFrame"),
653 ("pandas", "DataFrame"),
549 ("pandas", "Series"),
654 ("pandas", "Series"),
550 },
655 },
656 allowed_operations=BUILTIN_OPERATIONS,
551 allow_builtins_access=True,
657 allow_builtins_access=True,
552 allow_locals_access=True,
658 allow_locals_access=True,
553 allow_globals_access=True,
659 allow_globals_access=True,
@@ -560,6 +666,7 b' EVALUATION_POLICIES = {'
560 allow_attr_access=True,
666 allow_attr_access=True,
561 allow_item_access=True,
667 allow_item_access=True,
562 allow_any_calls=True,
668 allow_any_calls=True,
669 allow_all_operations=True,
563 ),
670 ),
564 }
671 }
565
672
@@ -199,6 +199,76 b' def test_literals(code, expected):'
199 assert guarded_eval(code, context) == expected
199 assert guarded_eval(code, context) == expected
200
200
201
201
202 @pytest.mark.parametrize(
203 "code,expected",
204 [
205 ["-5", -5],
206 ["+5", +5],
207 ["~5", -6],
208 ],
209 )
210 def test_unary_operations(code, expected):
211 context = limited()
212 assert guarded_eval(code, context) == expected
213
214
215 @pytest.mark.parametrize(
216 "code,expected",
217 [
218 ["1 + 1", 2],
219 ["3 - 1", 2],
220 ["2 * 3", 6],
221 ["5 // 2", 2],
222 ["5 / 2", 2.5],
223 ["5**2", 25],
224 ["2 >> 1", 1],
225 ["2 << 1", 4],
226 ["1 | 2", 3],
227 ["1 & 1", 1],
228 ["1 & 2", 0],
229 ],
230 )
231 def test_binary_operations(code, expected):
232 context = limited()
233 assert guarded_eval(code, context) == expected
234
235
236 @pytest.mark.parametrize(
237 "code,expected",
238 [
239 ["2 > 1", True],
240 ["2 < 1", False],
241 ["2 <= 1", False],
242 ["2 <= 2", True],
243 ["1 >= 2", False],
244 ["2 >= 2", True],
245 ["2 == 2", True],
246 ["1 == 2", False],
247 ["1 != 2", True],
248 ["1 != 1", False],
249 ["1 < 4 < 3", False],
250 ["(1 < 4) < 3", True],
251 ["4 > 3 > 2 > 1", True],
252 ["4 > 3 > 2 > 9", False],
253 ["1 < 2 < 3 < 4", True],
254 ["9 < 2 < 3 < 4", False],
255 ["1 < 2 > 1 > 0 > -1 < 1", True],
256 ["1 in [1] in [[1]]", True],
257 ["1 in [1] in [[2]]", False],
258 ["1 in [1]", True],
259 ["0 in [1]", False],
260 ["1 not in [1]", False],
261 ["0 not in [1]", True],
262 ["True is True", True],
263 ["False is False", True],
264 ["True is False", False],
265 ],
266 )
267 def test_comparisons(code, expected):
268 context = limited()
269 assert guarded_eval(code, context) == expected
270
271
202 def test_access_builtins():
272 def test_access_builtins():
203 context = limited()
273 context = limited()
204 assert guarded_eval("round", context) == round
274 assert guarded_eval("round", context) == round
General Comments 0
You need to be logged in to leave comments. Login now