guarded_eval.py
539 lines
| 16.9 KiB
| text/x-python
|
PythonLexer
krassowski
|
r27915 | from typing import ( | ||
Any, | ||||
Callable, | ||||
Set, | ||||
Tuple, | ||||
NamedTuple, | ||||
Type, | ||||
Literal, | ||||
Union, | ||||
TYPE_CHECKING, | ||||
) | ||||
import builtins | ||||
krassowski
|
r27906 | import collections | ||
import sys | ||||
import ast | ||||
from functools import cached_property | ||||
from dataclasses import dataclass, field | ||||
krassowski
|
r27912 | from IPython.utils.docs import GENERATING_DOCUMENTATION | ||
if TYPE_CHECKING or GENERATING_DOCUMENTATION: | ||||
from typing_extensions import Protocol | ||||
else: | ||||
# do not require on runtime | ||||
Protocol = object # requires Python >=3.8 | ||||
krassowski
|
r27906 | |||
class HasGetItem(Protocol):
    """Structural type for objects supporting subscription (``obj[key]``)."""

    def __getitem__(self, key) -> None:
        """Signature-only stub used for type checking."""


class InstancesHaveGetItem(Protocol):
    """Structural type for callables (e.g. classes) producing subscriptable objects."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Signature-only stub used for type checking."""


class HasGetAttr(Protocol):
    """Structural type for objects whose class defines ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Signature-only stub used for type checking."""


class DoesNotHaveGetAttr(Protocol):
    """Structural type for objects relying on default attribute lookup only."""


# Most objects do not implement `__getattr__` explicitly, hence both variants
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
def unbind_method(func: Callable) -> Union[Callable, None]:
    """Get unbound method for given bound method.

    Looks up ``func.__name__`` on the class of ``func.__self__``, so that
    policies can allow-list methods by their class-level definition
    (e.g. ``list.copy``) instead of by every bound instance.

    Returns None if cannot get unbound method, i.e. when ``func`` is not
    bound, has no name, when the instance ``__dict__`` shadows that name
    (the class attribute is then not what was bound), or when the owner's
    class has no attribute of that name.
    """
    owner = getattr(func, "__self__", None)
    name = getattr(func, "__name__", None)
    if owner is None or not name:
        return None
    instance_dict_overrides = getattr(owner, "__dict__", None)
    # an instance attribute with the same name means the class-level
    # attribute is not what got bound, so refuse to unbind
    if instance_dict_overrides and name in instance_dict_overrides:
        return None
    # default to None: a function bound to an unrelated object (for example
    # via `function.__get__`) need not exist on the owner's class
    return getattr(type(owner), name, None)


@dataclass
class EvaluationPolicy:
    """Allow-list of operations permitted while evaluating an expression.

    Each ``allow_*`` flag enables a whole category of operations; subclasses
    may override ``can_get_item``/``can_get_attr`` for finer-grained checks.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_any_calls: bool = False
    # functions, types and unbound methods which may be called
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether ``getattr(value, attr)`` may be evaluated."""
        return self.allow_attr_access

    def can_call(self, func):
        """Return whether ``func`` may be called.

        ``func`` is accepted when any call is allowed, when it is itself in
        ``allowed_calls``, or when its unbound counterpart (as found by
        ``unbind_method``) is. Returns ``False`` otherwise.
        """
        if self.allow_any_calls:
            return True
        if self._is_allow_listed(func):
            return True
        owner_method = unbind_method(func)
        return bool(owner_method) and self._is_allow_listed(owner_method)

    def _is_allow_listed(self, candidate) -> bool:
        """Membership test in ``allowed_calls`` tolerating unhashable objects."""
        try:
            return candidate in self.allowed_calls
        except TypeError:
            # unhashable objects can never be members of the allow-list set
            return False


def has_original_dunder_external(
    value,
    module_name,
    access_path,
    method_name,
):
    """Check whether ``value`` uses the original ``method_name`` of an external type.

    The external type is found by walking ``access_path`` (attribute names)
    from ``sys.modules[module_name]``; the module is never imported here, so
    the check fails when it has not been imported yet.

    Returns:
        True if ``value`` is exactly of the resolved type, or is an instance
        of a subclass whose ``method_name`` is the one of the resolved type;
        False otherwise, including when the module or an attribute on the
        path is missing, or the path does not resolve to a class.
    """
    module = sys.modules.get(module_name)
    if module is None:
        return False
    try:
        member_type = module
        for attr in access_path:
            member_type = getattr(member_type, attr)
        value_type = type(value)
        if value_type is member_type:
            return True
        if isinstance(value, member_type):
            method = getattr(value_type, method_name, None)
            member_method = getattr(member_type, method_name, None)
            if member_method == method:
                return True
    except (AttributeError, KeyError, TypeError):
        # TypeError: the path resolved to something which is not a class,
        # so it cannot be used with `isinstance`
        return False
    return False


def has_original_dunder(
    value, allowed_types, allowed_methods, allowed_external, method_name
):
    """Decide whether ``value`` relies on an allow-listed ``method_name``.

    Returns:
        True if ``type(value)`` is allow-listed, if the class attribute
        ``method_name`` is one of ``allowed_methods``, or if it matches one
        of the ``allowed_external`` types (``(module, *attribute_path)``);
        None if the class has no (truthy) ``method_name`` attribute;
        False otherwise.
    """
    # Python looks dunder methods up on the type, not on the instance,
    # so inspecting the class is sufficient.
    cls = type(value)
    if cls in allowed_types:
        # an exact type match cannot have an overridden method
        return True
    dunder = getattr(cls, method_name, None)
    if not dunder:
        return None
    if dunder in allowed_methods:
        return True
    return any(
        has_original_dunder_external(value, module_name, access_path, method_name)
        for module_name, *access_path in allowed_external
    )


@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy granting item/attribute access based on the type of the value.

    Access is granted when the exact type of the value is allow-listed, or
    when the relevant dunder methods (``__getitem__``, ``__getattr__``,
    ``__getattribute__``) of its class are the implementations found on
    allow-listed types (external types are given as ``(module, *path)``).
    """

    # types whose instances may be subscripted
    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    # external types as ("module", "attr", ...) paths
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
    # types whose instances may have their attributes read
    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    def can_get_attr(self, value, attr):
        """Allow attribute access unless `__getattr__`/`__getattribute__` were modified.

        A missing ``__getattr__`` counts as unmodified. Returns a bool.
        """
        has_original_attribute = has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )
        # Many objects do not have `__getattr__`, this is fine
        if has_original_attr is None and has_original_attribute:
            return True
        # Accept objects without modifications to `__getattr__` and `__getattribute__`
        return bool(has_original_attr and has_original_attribute)

    def get_attr(self, value, attr):
        """Return ``getattr(value, attr)`` if allowed, otherwise ``None``."""
        if self.can_get_attr(value, attr):
            return getattr(value, attr)
        return None

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified.

        Returns a bool; objects without ``__getitem__`` are rejected.
        """
        return bool(
            has_original_dunder(
                value,
                allowed_types=self.allowed_getitem,
                allowed_methods=self._getitem_methods,
                allowed_external=self.allowed_getitem_external,
                method_name="__getitem__",
            )
        )

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        """``__getitem__`` implementations of the allow-listed types."""
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        """``__getattr__`` implementations of the allow-listed types."""
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        """``__getattribute__`` implementations of the allow-listed types."""
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect the truthy ``name`` attribute of each class in ``classes``."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }


class DummyNamedTuple(NamedTuple):
    """Field-less NamedTuple subclass listed in the subscript allow-list."""


class EvaluationContext(NamedTuple):
    """Namespaces and settings for a single guarded evaluation."""

    #: local namespace used to resolve names
    locals_: dict
    #: global namespace used to resolve names
    globals_: dict
    #: evaluation mode; selects how (and whether) code gets evaluated
    evaluation: Literal[
        "forbidden",
        "minimal",
        "limited",
        "unsafe",
        "dangerous",
    ] = "forbidden"
    #: whether the code is the contents of a subscript (``x[<code>]``)
    in_subscript: bool = False


class IdentitySubscript:
    """Object whose subscription evaluates to the subscript key itself.

    ``IdentitySubscript()[1:2, ...]`` gives ``(slice(1, 2, None), Ellipsis)``.
    """

    def __getitem__(self, key):
        return key


IDENTITY_SUBSCRIPT = IdentitySubscript()
# name under which IDENTITY_SUBSCRIPT is injected into the locals when the
# contents of a subscript are evaluated
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"


class GuardRejection(ValueError):
    """Raised when the evaluation policy forbids an operation."""


def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate ``code`` in the mode selected by ``context.evaluation``.

    - ``"forbidden"``: always raises ``GuardRejection``;
    - ``"dangerous"``: plain ``eval`` without any guards;
    - any other mode: ``code`` is parsed as an expression and evaluated
      node by node with ``eval_node``, which enforces the mode's policy.

    When ``context.in_subscript`` is true, ``code`` is treated as the
    contents of a subscript (e.g. ``1:2, ...``) and evaluates to the
    corresponding key; empty ``code`` gives an empty tuple.

    Raises:
        GuardRejection: if the mode or its policy forbids the evaluation.
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # syntactic sugar for slices (`:`) is only available in subscripts,
        # so we trick the parser into thinking that we have a subscript, on
        # a marker object whose `__getitem__` returns the key unchanged
        if not code:
            return ()
        locals_ = context.locals_.copy()
        locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = SUBSCRIPT_MARKER + "[" + code + "]"
        context = context._replace(locals_=locals_)

    if context.evaluation == "dangerous":
        # NOTE: unrestricted `eval` by design; this mode gives no protection
        return eval(code, context.globals_, context.locals_)

    expression = ast.parse(code, mode="eval")

    return eval_node(expression, context)


# Implementations of the supported binary operators, keyed by AST operator type
_BINARY_OPERATIONS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
    ast.FloorDiv: lambda a, b: a // b,
    ast.Mod: lambda a, b: a % b,
    ast.Pow: lambda a, b: a**b,
    ast.LShift: lambda a, b: a << b,
    ast.RShift: lambda a, b: a >> b,
    ast.BitOr: lambda a, b: a | b,
    ast.BitXor: lambda a, b: a ^ b,
    ast.BitAnd: lambda a, b: a & b,
    ast.MatMult: lambda a, b: a @ b,
}

# Implementations of the supported unary operators, keyed by AST operator type
_UNARY_OPERATIONS = {
    ast.USub: lambda operand: -operand,
    ast.UAdd: lambda operand: +operand,
    ast.Invert: lambda operand: ~operand,
    ast.Not: lambda operand: not operand,
}


def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """
    Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context.

    Currently does not support evaluation of functions with keyword arguments.

    Does not evaluate actions which always have side effects:

    - class definitions (``class sth: ...``)
    - function definitions (``def sth: ...``)
    - variable assignments (``x = 1``)
    - augmented assignments (``x += 1``)
    - deletions (``del x``)

    Does not evaluate operations which do not return values:

    - assertions (``assert x``)
    - pass (``pass``)
    - imports (``import x``)
    - control flow

        - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
        - loops (``for`` and ``while``)
        - exception handling

    The purpose of this function is to guard against unwanted side-effects;
    it does not give guarantees on protection from malicious code execution.

    Raises:
        GuardRejection: if the policy of ``context.evaluation`` forbids an operation.
        NameError: if a name is not found although namespace access is allowed.
        ValueError: for unsupported nodes or operators.
        KeyError: if ``context.evaluation`` has no entry in ``EVALUATION_POLICIES``.
    """
    # note: "forbidden" and "dangerous" modes have no policy entry; they are
    # handled by `guarded_eval` before any node gets evaluated
    policy = EVALUATION_POLICIES[context.evaluation]
    if node is None:
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.BinOp):
        # TODO: add guards
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        binary_operation = _BINARY_OPERATIONS.get(type(node.op))
        if binary_operation is None:
            raise ValueError("Unhandled node", node)
        return binary_operation(left, right)
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Index):
        # legacy node type, presumably only emitted by older Python parsers
        return eval_node(node.value, context)
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        keys = [eval_node(k, context) for k in node.keys]
        values = [eval_node(v, context) for v in node.values]
        result = {}
        for key_node, key, value in zip(node.keys, keys, values):
            if key_node is None:
                # `{**mapping}` is represented by a `None` key node; unless item
                # access is unrestricted, only exact `dict`s are merged so that
                # no overridden `keys`/`__getitem__` of a mapping gets invoked
                if type(value) is not dict and not policy.allow_item_access:
                    raise GuardRejection(
                        "Dictionary unpacking for",
                        type(value),  # not joined to avoid calling `repr`
                        f"not allowed in {context.evaluation} mode",
                    )
                result.update(value)
            else:
                result[key] = value
        return result
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context),
        )
    if isinstance(node, ast.ExtSlice):
        # legacy node type, presumably only emitted by older Python parsers
        return tuple([eval_node(dim, context) for dim in node.dims])
    if isinstance(node, ast.UnaryOp):
        # TODO: add guards
        value = eval_node(node.operand, context)
        unary_operation = _UNARY_OPERATIONS.get(type(node.op))
        if unary_operation is None:
            raise ValueError("Unhandled unary operation:", node.op)
        return unary_operation(value)
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            "Subscript access (`__getitem__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.Name):
        if policy.allow_locals_access and node.id in context.locals_:
            return context.locals_[node.id]
        if policy.allow_globals_access and node.id in context.globals_:
            return context.globals_[node.id]
        if policy.allow_builtins_access and hasattr(builtins, node.id):
            # note: do not use __builtins__, it is implementation detail of CPython
            return getattr(builtins, node.id)
        if not policy.allow_globals_access and not policy.allow_locals_access:
            raise GuardRejection(
                f"Namespace access not allowed in {context.evaluation} mode"
            )
        raise NameError(f"{node.id} not found in locals nor globals")
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            "Attribute access (`__getattr__`) for",
            type(value),  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        if policy.can_call(func) and not node.keywords:
            args = [eval_node(arg, context) for arg in node.args]
            return func(*args)
        raise GuardRejection(
            "Call for",
            func,  # not joined to avoid calling `repr`
            f"not allowed in {context.evaluation} mode",
        )
    raise ValueError("Unhandled node", node)


# Third-party types, as a module name followed by an attribute path, whose
# original `__getitem__` is allow-listed; they are looked up in `sys.modules`
# at check time, so listing them here does not import the packages.
SUPPORTED_EXTERNAL_GETITEM = {
    ("numpy", "ndarray"),
    ("numpy", "void"),
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    ("pandas", "core", "indexing", "_LocIndexer"),
    ("pandas", "core", "indexing", "_iLocIndexer"),
}


# Types whose instances may be subscripted in "limited" mode; instances of
# subclasses are accepted too as long as they keep the original `__getitem__`.
BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
    # builtin sequences and mappings
    bytes,
    dict,
    list,
    str,
    tuple,
    # `collections` containers
    collections.ChainMap,
    collections.OrderedDict,
    collections.UserDict,
    collections.UserList,
    collections.UserString,
    collections.defaultdict,
    collections.deque,
    # helpers defined in this module
    DummyNamedTuple,
    IdentitySubscript,
}
def _list_methods(cls, source=None): | ||||
"""For use on immutable objects or with methods returning a copy""" | ||||
krassowski
|
r27913 | return [getattr(cls, k) for k in (source if source else dir(cls))] | ||
krassowski
|
r27906 | |||
krassowski
|
# Attribute names given to `_list_methods` for mutable containers: only
# methods which leave the receiver unchanged (or return a copy).
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# Names shared with the immutable `frozenset`; note that this still includes
# inherited dunders such as `__init__` and `__setattr__`.
set_non_mutating_methods = set(dir(set)).intersection(dir(frozenset))

dict_keys: Type[collections.abc.KeysView] = type(dict().keys())
method_descriptor: Any = type(list.copy)


# Callables which may be called in "limited" mode: constructors of builtin and
# `collections` containers plus their non-mutating methods (bound methods are
# matched against these through `unbind_method`).
ALLOWED_CALLS = {
    # constructors
    bytes,
    dict,
    frozenset,
    list,
    range,
    set,
    str,
    tuple,
    collections.Counter,
    collections.OrderedDict,
    collections.UserDict,
    collections.UserList,
    collections.UserString,
    collections.defaultdict,
    collections.deque,
    # immutable types: attributes listed by `dir()` (see `_list_methods`)
    *_list_methods(bytes),
    *_list_methods(frozenset),
    *_list_methods(str),
    *_list_methods(tuple),
    *_list_methods(collections.UserString, dir(str)),
    # mutable mappings: only the non-mutating methods
    *_list_methods(dict, dict_non_mutating_methods),
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    # mutable sequences and sets: only the non-mutating methods
    *_list_methods(list, list_non_mutating_methods),
    *_list_methods(collections.UserList, list_non_mutating_methods),
    *_list_methods(collections.deque, list_non_mutating_methods),
    *_list_methods(set, set_non_mutating_methods),
    # individual extra methods
    dict_keys.isdisjoint,
    collections.Counter.elements,
    collections.Counter.most_common,
}


# Types whose instances may have their attributes read in "limited" mode;
# every type allowed for subscription is included as well.
BUILTIN_GETATTR: Set[MayHaveGetattr] = {
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    set,
    frozenset,
    dict_keys,
    method_descriptor,
    *BUILTIN_GETITEM,
}


# Policies applied by `eval_node`, keyed by evaluation mode; the "forbidden"
# and "dangerous" modes are handled directly by `guarded_eval` instead.
EVALUATION_POLICIES = {
    # builtins can be referenced, but the user namespaces are hidden and
    # nothing can be called, subscripted or have its attributes read
    "minimal": EvaluationPolicy(
        allow_locals_access=False,
        allow_globals_access=False,
        allow_builtins_access=True,
        allow_attr_access=False,
        allow_item_access=False,
        allow_any_calls=False,
        allowed_calls=set(),
    ),
    # namespaces are readable; item/attribute access and calls are limited
    # to allow-listed types and callables
    "limited": SelectivePolicy(
        # TODO:
        # - should reject binary and unary operations if custom methods would be dispatched
        allow_locals_access=True,
        allow_globals_access=True,
        allow_builtins_access=True,
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_calls=ALLOWED_CALLS,
    ),
    # every supported operation is permitted
    "unsafe": EvaluationPolicy(
        allow_locals_access=True,
        allow_globals_access=True,
        allow_builtins_access=True,
        allow_item_access=True,
        allow_attr_access=True,
        allow_any_calls=True,
    ),
}