##// END OF EJS Templates
Use module and method descriptor types declared in the types module
Carlos Cordoba -
Show More
@@ -1,734 +1,733 b''
1 from typing import (
1 from typing import (
2 Any,
2 Any,
3 Callable,
3 Callable,
4 Dict,
4 Dict,
5 Set,
5 Set,
6 Sequence,
6 Sequence,
7 Tuple,
7 Tuple,
8 NamedTuple,
8 NamedTuple,
9 Type,
9 Type,
10 Literal,
10 Literal,
11 Union,
11 Union,
12 TYPE_CHECKING,
12 TYPE_CHECKING,
13 )
13 )
14 import ast
14 import ast
15 import builtins
15 import builtins
16 import collections
16 import collections
17 import operator
17 import operator
18 import sys
18 import sys
19 from functools import cached_property
19 from functools import cached_property
20 from dataclasses import dataclass, field
20 from dataclasses import dataclass, field
21 from types import MethodDescriptorType, ModuleType
21
22
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.decorators import undoc
24 from IPython.utils.decorators import undoc
24
25
25
26
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 from typing_extensions import Protocol
28 from typing_extensions import Protocol
28 else:
29 else:
29 # do not require on runtime
30 # do not require on runtime
30 Protocol = object # requires Python >=3.8
31 Protocol = object # requires Python >=3.8
31
32
32
33
33 @undoc
34 @undoc
34 class HasGetItem(Protocol):
35 class HasGetItem(Protocol):
35 def __getitem__(self, key) -> None:
36 def __getitem__(self, key) -> None:
36 ...
37 ...
37
38
38
39
39 @undoc
40 @undoc
40 class InstancesHaveGetItem(Protocol):
41 class InstancesHaveGetItem(Protocol):
41 def __call__(self, *args, **kwargs) -> HasGetItem:
42 def __call__(self, *args, **kwargs) -> HasGetItem:
42 ...
43 ...
43
44
44
45
45 @undoc
46 @undoc
46 class HasGetAttr(Protocol):
47 class HasGetAttr(Protocol):
47 def __getattr__(self, key) -> None:
48 def __getattr__(self, key) -> None:
48 ...
49 ...
49
50
50
51
51 @undoc
52 @undoc
52 class DoesNotHaveGetAttr(Protocol):
53 class DoesNotHaveGetAttr(Protocol):
53 pass
54 pass
54
55
55
56
56 # By default `__getattr__` is not explicitly implemented on most objects
57 # By default `__getattr__` is not explicitly implemented on most objects
57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58
59
59
60
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 """Get unbound method for given bound method.
62 """Get unbound method for given bound method.
62
63
63 Returns None if cannot get unbound method, or method is already unbound.
64 Returns None if cannot get unbound method, or method is already unbound.
64 """
65 """
65 owner = getattr(func, "__self__", None)
66 owner = getattr(func, "__self__", None)
66 owner_class = type(owner)
67 owner_class = type(owner)
67 name = getattr(func, "__name__", None)
68 name = getattr(func, "__name__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
69 instance_dict_overrides = getattr(owner, "__dict__", None)
69 if (
70 if (
70 owner is not None
71 owner is not None
71 and name
72 and name
72 and (
73 and (
73 not instance_dict_overrides
74 not instance_dict_overrides
74 or (instance_dict_overrides and name not in instance_dict_overrides)
75 or (instance_dict_overrides and name not in instance_dict_overrides)
75 )
76 )
76 ):
77 ):
77 return getattr(owner_class, name)
78 return getattr(owner_class, name)
78 return None
79 return None
79
80
80
81
81 @undoc
82 @undoc
82 @dataclass
83 @dataclass
83 class EvaluationPolicy:
84 class EvaluationPolicy:
84 """Definition of evaluation policy."""
85 """Definition of evaluation policy."""
85
86
86 allow_locals_access: bool = False
87 allow_locals_access: bool = False
87 allow_globals_access: bool = False
88 allow_globals_access: bool = False
88 allow_item_access: bool = False
89 allow_item_access: bool = False
89 allow_attr_access: bool = False
90 allow_attr_access: bool = False
90 allow_builtins_access: bool = False
91 allow_builtins_access: bool = False
91 allow_all_operations: bool = False
92 allow_all_operations: bool = False
92 allow_any_calls: bool = False
93 allow_any_calls: bool = False
93 allowed_calls: Set[Callable] = field(default_factory=set)
94 allowed_calls: Set[Callable] = field(default_factory=set)
94
95
95 def can_get_item(self, value, item):
96 def can_get_item(self, value, item):
96 return self.allow_item_access
97 return self.allow_item_access
97
98
98 def can_get_attr(self, value, attr):
99 def can_get_attr(self, value, attr):
99 return self.allow_attr_access
100 return self.allow_attr_access
100
101
101 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
102 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
102 if self.allow_all_operations:
103 if self.allow_all_operations:
103 return True
104 return True
104
105
105 def can_call(self, func):
106 def can_call(self, func):
106 if self.allow_any_calls:
107 if self.allow_any_calls:
107 return True
108 return True
108
109
109 if func in self.allowed_calls:
110 if func in self.allowed_calls:
110 return True
111 return True
111
112
112 owner_method = _unbind_method(func)
113 owner_method = _unbind_method(func)
113
114
114 if owner_method and owner_method in self.allowed_calls:
115 if owner_method and owner_method in self.allowed_calls:
115 return True
116 return True
116
117
117
118
118 def _get_external(module_name: str, access_path: Sequence[str]):
119 def _get_external(module_name: str, access_path: Sequence[str]):
119 """Get value from external module given a dotted access path.
120 """Get value from external module given a dotted access path.
120
121
121 Raises:
122 Raises:
122 * `KeyError` if module is removed not found, and
123 * `KeyError` if module is removed not found, and
123 * `AttributeError` if acess path does not match an exported object
124 * `AttributeError` if acess path does not match an exported object
124 """
125 """
125 member_type = sys.modules[module_name]
126 member_type = sys.modules[module_name]
126 for attr in access_path:
127 for attr in access_path:
127 member_type = getattr(member_type, attr)
128 member_type = getattr(member_type, attr)
128 return member_type
129 return member_type
129
130
130
131
131 def _has_original_dunder_external(
132 def _has_original_dunder_external(
132 value,
133 value,
133 module_name: str,
134 module_name: str,
134 access_path: Sequence[str],
135 access_path: Sequence[str],
135 method_name: str,
136 method_name: str,
136 ):
137 ):
137 if module_name not in sys.modules:
138 if module_name not in sys.modules:
138 # LBYLB as it is faster
139 # LBYLB as it is faster
139 return False
140 return False
140 try:
141 try:
141 member_type = _get_external(module_name, access_path)
142 member_type = _get_external(module_name, access_path)
142 value_type = type(value)
143 value_type = type(value)
143 if type(value) == member_type:
144 if type(value) == member_type:
144 return True
145 return True
145 if method_name == "__getattribute__":
146 if method_name == "__getattribute__":
146 # we have to short-circuit here due to an unresolved issue in
147 # we have to short-circuit here due to an unresolved issue in
147 # `isinstance` implementation: https://bugs.python.org/issue32683
148 # `isinstance` implementation: https://bugs.python.org/issue32683
148 return False
149 return False
149 if isinstance(value, member_type):
150 if isinstance(value, member_type):
150 method = getattr(value_type, method_name, None)
151 method = getattr(value_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
152 member_method = getattr(member_type, method_name, None)
152 if member_method == method:
153 if member_method == method:
153 return True
154 return True
154 except (AttributeError, KeyError):
155 except (AttributeError, KeyError):
155 return False
156 return False
156
157
157
158
158 def _has_original_dunder(
159 def _has_original_dunder(
159 value, allowed_types, allowed_methods, allowed_external, method_name
160 value, allowed_types, allowed_methods, allowed_external, method_name
160 ):
161 ):
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
162 # note: Python ignores `__getattr__`/`__getitem__` on instances,
162 # we only need to check at class level
163 # we only need to check at class level
163 value_type = type(value)
164 value_type = type(value)
164
165
165 # strict type check passes β†’ no need to check method
166 # strict type check passes β†’ no need to check method
166 if value_type in allowed_types:
167 if value_type in allowed_types:
167 return True
168 return True
168
169
169 method = getattr(value_type, method_name, None)
170 method = getattr(value_type, method_name, None)
170
171
171 if method is None:
172 if method is None:
172 return None
173 return None
173
174
174 if method in allowed_methods:
175 if method in allowed_methods:
175 return True
176 return True
176
177
177 for module_name, *access_path in allowed_external:
178 for module_name, *access_path in allowed_external:
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
179 if _has_original_dunder_external(value, module_name, access_path, method_name):
179 return True
180 return True
180
181
181 return False
182 return False
182
183
183
184
184 @undoc
185 @undoc
185 @dataclass
186 @dataclass
186 class SelectivePolicy(EvaluationPolicy):
187 class SelectivePolicy(EvaluationPolicy):
187 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
188 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
188 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
189 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
189
190
190 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
191 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
191 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
192 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
192
193
193 allowed_operations: Set = field(default_factory=set)
194 allowed_operations: Set = field(default_factory=set)
194 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
195 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
195
196
196 _operation_methods_cache: Dict[str, Set[Callable]] = field(
197 _operation_methods_cache: Dict[str, Set[Callable]] = field(
197 default_factory=dict, init=False
198 default_factory=dict, init=False
198 )
199 )
199
200
200 def can_get_attr(self, value, attr):
201 def can_get_attr(self, value, attr):
201 has_original_attribute = _has_original_dunder(
202 has_original_attribute = _has_original_dunder(
202 value,
203 value,
203 allowed_types=self.allowed_getattr,
204 allowed_types=self.allowed_getattr,
204 allowed_methods=self._getattribute_methods,
205 allowed_methods=self._getattribute_methods,
205 allowed_external=self.allowed_getattr_external,
206 allowed_external=self.allowed_getattr_external,
206 method_name="__getattribute__",
207 method_name="__getattribute__",
207 )
208 )
208 has_original_attr = _has_original_dunder(
209 has_original_attr = _has_original_dunder(
209 value,
210 value,
210 allowed_types=self.allowed_getattr,
211 allowed_types=self.allowed_getattr,
211 allowed_methods=self._getattr_methods,
212 allowed_methods=self._getattr_methods,
212 allowed_external=self.allowed_getattr_external,
213 allowed_external=self.allowed_getattr_external,
213 method_name="__getattr__",
214 method_name="__getattr__",
214 )
215 )
215
216
216 accept = False
217 accept = False
217
218
218 # Many objects do not have `__getattr__`, this is fine.
219 # Many objects do not have `__getattr__`, this is fine.
219 if has_original_attr is None and has_original_attribute:
220 if has_original_attr is None and has_original_attribute:
220 accept = True
221 accept = True
221 else:
222 else:
222 # Accept objects without modifications to `__getattr__` and `__getattribute__`
223 # Accept objects without modifications to `__getattr__` and `__getattribute__`
223 accept = has_original_attr and has_original_attribute
224 accept = has_original_attr and has_original_attribute
224
225
225 if accept:
226 if accept:
226 # We still need to check for overriden properties.
227 # We still need to check for overriden properties.
227
228
228 value_class = type(value)
229 value_class = type(value)
229 if not hasattr(value_class, attr):
230 if not hasattr(value_class, attr):
230 return True
231 return True
231
232
232 class_attr_val = getattr(value_class, attr)
233 class_attr_val = getattr(value_class, attr)
233 is_property = isinstance(class_attr_val, property)
234 is_property = isinstance(class_attr_val, property)
234
235
235 if not is_property:
236 if not is_property:
236 return True
237 return True
237
238
238 # Properties in allowed types are ok (although we do not include any
239 # Properties in allowed types are ok (although we do not include any
239 # properties in our default allow list currently).
240 # properties in our default allow list currently).
240 if type(value) in self.allowed_getattr:
241 if type(value) in self.allowed_getattr:
241 return True # pragma: no cover
242 return True # pragma: no cover
242
243
243 # Properties in subclasses of allowed types may be ok if not changed
244 # Properties in subclasses of allowed types may be ok if not changed
244 for module_name, *access_path in self.allowed_getattr_external:
245 for module_name, *access_path in self.allowed_getattr_external:
245 try:
246 try:
246 external_class = _get_external(module_name, access_path)
247 external_class = _get_external(module_name, access_path)
247 external_class_attr_val = getattr(external_class, attr)
248 external_class_attr_val = getattr(external_class, attr)
248 except (KeyError, AttributeError):
249 except (KeyError, AttributeError):
249 return False # pragma: no cover
250 return False # pragma: no cover
250 return class_attr_val == external_class_attr_val
251 return class_attr_val == external_class_attr_val
251
252
252 return False
253 return False
253
254
254 def can_get_item(self, value, item):
255 def can_get_item(self, value, item):
255 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
256 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
256 return _has_original_dunder(
257 return _has_original_dunder(
257 value,
258 value,
258 allowed_types=self.allowed_getitem,
259 allowed_types=self.allowed_getitem,
259 allowed_methods=self._getitem_methods,
260 allowed_methods=self._getitem_methods,
260 allowed_external=self.allowed_getitem_external,
261 allowed_external=self.allowed_getitem_external,
261 method_name="__getitem__",
262 method_name="__getitem__",
262 )
263 )
263
264
264 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
265 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
265 objects = [a]
266 objects = [a]
266 if b is not None:
267 if b is not None:
267 objects.append(b)
268 objects.append(b)
268 return all(
269 return all(
269 [
270 [
270 _has_original_dunder(
271 _has_original_dunder(
271 obj,
272 obj,
272 allowed_types=self.allowed_operations,
273 allowed_types=self.allowed_operations,
273 allowed_methods=self._operator_dunder_methods(dunder),
274 allowed_methods=self._operator_dunder_methods(dunder),
274 allowed_external=self.allowed_operations_external,
275 allowed_external=self.allowed_operations_external,
275 method_name=dunder,
276 method_name=dunder,
276 )
277 )
277 for dunder in dunders
278 for dunder in dunders
278 for obj in objects
279 for obj in objects
279 ]
280 ]
280 )
281 )
281
282
282 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
283 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
283 if dunder not in self._operation_methods_cache:
284 if dunder not in self._operation_methods_cache:
284 self._operation_methods_cache[dunder] = self._safe_get_methods(
285 self._operation_methods_cache[dunder] = self._safe_get_methods(
285 self.allowed_operations, dunder
286 self.allowed_operations, dunder
286 )
287 )
287 return self._operation_methods_cache[dunder]
288 return self._operation_methods_cache[dunder]
288
289
289 @cached_property
290 @cached_property
290 def _getitem_methods(self) -> Set[Callable]:
291 def _getitem_methods(self) -> Set[Callable]:
291 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
292 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
292
293
293 @cached_property
294 @cached_property
294 def _getattr_methods(self) -> Set[Callable]:
295 def _getattr_methods(self) -> Set[Callable]:
295 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
296 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
296
297
297 @cached_property
298 @cached_property
298 def _getattribute_methods(self) -> Set[Callable]:
299 def _getattribute_methods(self) -> Set[Callable]:
299 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
300 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
300
301
301 def _safe_get_methods(self, classes, name) -> Set[Callable]:
302 def _safe_get_methods(self, classes, name) -> Set[Callable]:
302 return {
303 return {
303 method
304 method
304 for class_ in classes
305 for class_ in classes
305 for method in [getattr(class_, name, None)]
306 for method in [getattr(class_, name, None)]
306 if method
307 if method
307 }
308 }
308
309
309
310
310 class _DummyNamedTuple(NamedTuple):
311 class _DummyNamedTuple(NamedTuple):
311 """Used internally to retrieve methods of named tuple instance."""
312 """Used internally to retrieve methods of named tuple instance."""
312
313
313
314
314 class EvaluationContext(NamedTuple):
315 class EvaluationContext(NamedTuple):
315 #: Local namespace
316 #: Local namespace
316 locals: dict
317 locals: dict
317 #: Global namespace
318 #: Global namespace
318 globals: dict
319 globals: dict
319 #: Evaluation policy identifier
320 #: Evaluation policy identifier
320 evaluation: Literal[
321 evaluation: Literal[
321 "forbidden", "minimal", "limited", "unsafe", "dangerous"
322 "forbidden", "minimal", "limited", "unsafe", "dangerous"
322 ] = "forbidden"
323 ] = "forbidden"
323 #: Whether the evalution of code takes place inside of a subscript.
324 #: Whether the evalution of code takes place inside of a subscript.
324 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
325 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
325 in_subscript: bool = False
326 in_subscript: bool = False
326
327
327
328
328 class _IdentitySubscript:
329 class _IdentitySubscript:
329 """Returns the key itself when item is requested via subscript."""
330 """Returns the key itself when item is requested via subscript."""
330
331
331 def __getitem__(self, key):
332 def __getitem__(self, key):
332 return key
333 return key
333
334
334
335
335 IDENTITY_SUBSCRIPT = _IdentitySubscript()
336 IDENTITY_SUBSCRIPT = _IdentitySubscript()
336 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
337 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
337
338
338
339
339 class GuardRejection(Exception):
340 class GuardRejection(Exception):
340 """Exception raised when guard rejects evaluation attempt."""
341 """Exception raised when guard rejects evaluation attempt."""
341
342
342 pass
343 pass
343
344
344
345
345 def guarded_eval(code: str, context: EvaluationContext):
346 def guarded_eval(code: str, context: EvaluationContext):
346 """Evaluate provided code in the evaluation context.
347 """Evaluate provided code in the evaluation context.
347
348
348 If evaluation policy given by context is set to ``forbidden``
349 If evaluation policy given by context is set to ``forbidden``
349 no evaluation will be performed; if it is set to ``dangerous``
350 no evaluation will be performed; if it is set to ``dangerous``
350 standard :func:`eval` will be used; finally, for any other,
351 standard :func:`eval` will be used; finally, for any other,
351 policy :func:`eval_node` will be called on parsed AST.
352 policy :func:`eval_node` will be called on parsed AST.
352 """
353 """
353 locals_ = context.locals
354 locals_ = context.locals
354
355
355 if context.evaluation == "forbidden":
356 if context.evaluation == "forbidden":
356 raise GuardRejection("Forbidden mode")
357 raise GuardRejection("Forbidden mode")
357
358
358 # note: not using `ast.literal_eval` as it does not implement
359 # note: not using `ast.literal_eval` as it does not implement
359 # getitem at all, for example it fails on simple `[0][1]`
360 # getitem at all, for example it fails on simple `[0][1]`
360
361
361 if context.in_subscript:
362 if context.in_subscript:
362 # syntatic sugar for ellipsis (:) is only available in susbcripts
363 # syntatic sugar for ellipsis (:) is only available in susbcripts
363 # so we need to trick the ast parser into thinking that we have
364 # so we need to trick the ast parser into thinking that we have
364 # a subscript, but we need to be able to later recognise that we did
365 # a subscript, but we need to be able to later recognise that we did
365 # it so we can ignore the actual __getitem__ operation
366 # it so we can ignore the actual __getitem__ operation
366 if not code:
367 if not code:
367 return tuple()
368 return tuple()
368 locals_ = locals_.copy()
369 locals_ = locals_.copy()
369 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
370 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
370 code = SUBSCRIPT_MARKER + "[" + code + "]"
371 code = SUBSCRIPT_MARKER + "[" + code + "]"
371 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
372 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
372
373
373 if context.evaluation == "dangerous":
374 if context.evaluation == "dangerous":
374 return eval(code, context.globals, context.locals)
375 return eval(code, context.globals, context.locals)
375
376
376 expression = ast.parse(code, mode="eval")
377 expression = ast.parse(code, mode="eval")
377
378
378 return eval_node(expression, context)
379 return eval_node(expression, context)
379
380
380
381
381 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
382 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
382 ast.Add: ("__add__",),
383 ast.Add: ("__add__",),
383 ast.Sub: ("__sub__",),
384 ast.Sub: ("__sub__",),
384 ast.Mult: ("__mul__",),
385 ast.Mult: ("__mul__",),
385 ast.Div: ("__truediv__",),
386 ast.Div: ("__truediv__",),
386 ast.FloorDiv: ("__floordiv__",),
387 ast.FloorDiv: ("__floordiv__",),
387 ast.Mod: ("__mod__",),
388 ast.Mod: ("__mod__",),
388 ast.Pow: ("__pow__",),
389 ast.Pow: ("__pow__",),
389 ast.LShift: ("__lshift__",),
390 ast.LShift: ("__lshift__",),
390 ast.RShift: ("__rshift__",),
391 ast.RShift: ("__rshift__",),
391 ast.BitOr: ("__or__",),
392 ast.BitOr: ("__or__",),
392 ast.BitXor: ("__xor__",),
393 ast.BitXor: ("__xor__",),
393 ast.BitAnd: ("__and__",),
394 ast.BitAnd: ("__and__",),
394 ast.MatMult: ("__matmul__",),
395 ast.MatMult: ("__matmul__",),
395 }
396 }
396
397
397 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
398 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
398 ast.Eq: ("__eq__",),
399 ast.Eq: ("__eq__",),
399 ast.NotEq: ("__ne__", "__eq__"),
400 ast.NotEq: ("__ne__", "__eq__"),
400 ast.Lt: ("__lt__", "__gt__"),
401 ast.Lt: ("__lt__", "__gt__"),
401 ast.LtE: ("__le__", "__ge__"),
402 ast.LtE: ("__le__", "__ge__"),
402 ast.Gt: ("__gt__", "__lt__"),
403 ast.Gt: ("__gt__", "__lt__"),
403 ast.GtE: ("__ge__", "__le__"),
404 ast.GtE: ("__ge__", "__le__"),
404 ast.In: ("__contains__",),
405 ast.In: ("__contains__",),
405 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
406 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
406 }
407 }
407
408
408 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
409 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
409 ast.USub: ("__neg__",),
410 ast.USub: ("__neg__",),
410 ast.UAdd: ("__pos__",),
411 ast.UAdd: ("__pos__",),
411 # we have to check both __inv__ and __invert__!
412 # we have to check both __inv__ and __invert__!
412 ast.Invert: ("__invert__", "__inv__"),
413 ast.Invert: ("__invert__", "__inv__"),
413 ast.Not: ("__not__",),
414 ast.Not: ("__not__",),
414 }
415 }
415
416
416
417
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
418 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
418 dunder = None
419 dunder = None
419 for op, candidate_dunder in dunders.items():
420 for op, candidate_dunder in dunders.items():
420 if isinstance(node_op, op):
421 if isinstance(node_op, op):
421 dunder = candidate_dunder
422 dunder = candidate_dunder
422 return dunder
423 return dunder
423
424
424
425
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
426 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
426 """Evaluate AST node in provided context.
427 """Evaluate AST node in provided context.
427
428
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
429 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
429
430
430 Does not evaluate actions that always have side effects:
431 Does not evaluate actions that always have side effects:
431
432
432 - class definitions (``class sth: ...``)
433 - class definitions (``class sth: ...``)
433 - function definitions (``def sth: ...``)
434 - function definitions (``def sth: ...``)
434 - variable assignments (``x = 1``)
435 - variable assignments (``x = 1``)
435 - augmented assignments (``x += 1``)
436 - augmented assignments (``x += 1``)
436 - deletions (``del x``)
437 - deletions (``del x``)
437
438
438 Does not evaluate operations which do not return values:
439 Does not evaluate operations which do not return values:
439
440
440 - assertions (``assert x``)
441 - assertions (``assert x``)
441 - pass (``pass``)
442 - pass (``pass``)
442 - imports (``import x``)
443 - imports (``import x``)
443 - control flow:
444 - control flow:
444
445
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
446 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
446 - loops (``for`` and `while``)
447 - loops (``for`` and `while``)
447 - exception handling
448 - exception handling
448
449
449 The purpose of this function is to guard against unwanted side-effects;
450 The purpose of this function is to guard against unwanted side-effects;
450 it does not give guarantees on protection from malicious code execution.
451 it does not give guarantees on protection from malicious code execution.
451 """
452 """
452 policy = EVALUATION_POLICIES[context.evaluation]
453 policy = EVALUATION_POLICIES[context.evaluation]
453 if node is None:
454 if node is None:
454 return None
455 return None
455 if isinstance(node, ast.Expression):
456 if isinstance(node, ast.Expression):
456 return eval_node(node.body, context)
457 return eval_node(node.body, context)
457 if isinstance(node, ast.BinOp):
458 if isinstance(node, ast.BinOp):
458 left = eval_node(node.left, context)
459 left = eval_node(node.left, context)
459 right = eval_node(node.right, context)
460 right = eval_node(node.right, context)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
461 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
461 if dunders:
462 if dunders:
462 if policy.can_operate(dunders, left, right):
463 if policy.can_operate(dunders, left, right):
463 return getattr(left, dunders[0])(right)
464 return getattr(left, dunders[0])(right)
464 else:
465 else:
465 raise GuardRejection(
466 raise GuardRejection(
466 f"Operation (`{dunders}`) for",
467 f"Operation (`{dunders}`) for",
467 type(left),
468 type(left),
468 f"not allowed in {context.evaluation} mode",
469 f"not allowed in {context.evaluation} mode",
469 )
470 )
470 if isinstance(node, ast.Compare):
471 if isinstance(node, ast.Compare):
471 left = eval_node(node.left, context)
472 left = eval_node(node.left, context)
472 all_true = True
473 all_true = True
473 negate = False
474 negate = False
474 for op, right in zip(node.ops, node.comparators):
475 for op, right in zip(node.ops, node.comparators):
475 right = eval_node(right, context)
476 right = eval_node(right, context)
476 dunder = None
477 dunder = None
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
478 dunders = _find_dunder(op, COMP_OP_DUNDERS)
478 if not dunders:
479 if not dunders:
479 if isinstance(op, ast.NotIn):
480 if isinstance(op, ast.NotIn):
480 dunders = COMP_OP_DUNDERS[ast.In]
481 dunders = COMP_OP_DUNDERS[ast.In]
481 negate = True
482 negate = True
482 if isinstance(op, ast.Is):
483 if isinstance(op, ast.Is):
483 dunder = "is_"
484 dunder = "is_"
484 if isinstance(op, ast.IsNot):
485 if isinstance(op, ast.IsNot):
485 dunder = "is_"
486 dunder = "is_"
486 negate = True
487 negate = True
487 if not dunder and dunders:
488 if not dunder and dunders:
488 dunder = dunders[0]
489 dunder = dunders[0]
489 if dunder:
490 if dunder:
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
491 a, b = (right, left) if dunder == "__contains__" else (left, right)
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
492 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
492 result = getattr(operator, dunder)(a, b)
493 result = getattr(operator, dunder)(a, b)
493 if negate:
494 if negate:
494 result = not result
495 result = not result
495 if not result:
496 if not result:
496 all_true = False
497 all_true = False
497 left = right
498 left = right
498 else:
499 else:
499 raise GuardRejection(
500 raise GuardRejection(
500 f"Comparison (`{dunder}`) for",
501 f"Comparison (`{dunder}`) for",
501 type(left),
502 type(left),
502 f"not allowed in {context.evaluation} mode",
503 f"not allowed in {context.evaluation} mode",
503 )
504 )
504 else:
505 else:
505 raise ValueError(
506 raise ValueError(
506 f"Comparison `{dunder}` not supported"
507 f"Comparison `{dunder}` not supported"
507 ) # pragma: no cover
508 ) # pragma: no cover
508 return all_true
509 return all_true
509 if isinstance(node, ast.Constant):
510 if isinstance(node, ast.Constant):
510 return node.value
511 return node.value
511 if isinstance(node, ast.Tuple):
512 if isinstance(node, ast.Tuple):
512 return tuple(eval_node(e, context) for e in node.elts)
513 return tuple(eval_node(e, context) for e in node.elts)
513 if isinstance(node, ast.List):
514 if isinstance(node, ast.List):
514 return [eval_node(e, context) for e in node.elts]
515 return [eval_node(e, context) for e in node.elts]
515 if isinstance(node, ast.Set):
516 if isinstance(node, ast.Set):
516 return {eval_node(e, context) for e in node.elts}
517 return {eval_node(e, context) for e in node.elts}
517 if isinstance(node, ast.Dict):
518 if isinstance(node, ast.Dict):
518 return dict(
519 return dict(
519 zip(
520 zip(
520 [eval_node(k, context) for k in node.keys],
521 [eval_node(k, context) for k in node.keys],
521 [eval_node(v, context) for v in node.values],
522 [eval_node(v, context) for v in node.values],
522 )
523 )
523 )
524 )
524 if isinstance(node, ast.Slice):
525 if isinstance(node, ast.Slice):
525 return slice(
526 return slice(
526 eval_node(node.lower, context),
527 eval_node(node.lower, context),
527 eval_node(node.upper, context),
528 eval_node(node.upper, context),
528 eval_node(node.step, context),
529 eval_node(node.step, context),
529 )
530 )
530 if isinstance(node, ast.UnaryOp):
531 if isinstance(node, ast.UnaryOp):
531 value = eval_node(node.operand, context)
532 value = eval_node(node.operand, context)
532 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
533 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
533 if dunders:
534 if dunders:
534 if policy.can_operate(dunders, value):
535 if policy.can_operate(dunders, value):
535 return getattr(value, dunders[0])()
536 return getattr(value, dunders[0])()
536 else:
537 else:
537 raise GuardRejection(
538 raise GuardRejection(
538 f"Operation (`{dunders}`) for",
539 f"Operation (`{dunders}`) for",
539 type(value),
540 type(value),
540 f"not allowed in {context.evaluation} mode",
541 f"not allowed in {context.evaluation} mode",
541 )
542 )
542 if isinstance(node, ast.Subscript):
543 if isinstance(node, ast.Subscript):
543 value = eval_node(node.value, context)
544 value = eval_node(node.value, context)
544 slice_ = eval_node(node.slice, context)
545 slice_ = eval_node(node.slice, context)
545 if policy.can_get_item(value, slice_):
546 if policy.can_get_item(value, slice_):
546 return value[slice_]
547 return value[slice_]
547 raise GuardRejection(
548 raise GuardRejection(
548 "Subscript access (`__getitem__`) for",
549 "Subscript access (`__getitem__`) for",
549 type(value), # not joined to avoid calling `repr`
550 type(value), # not joined to avoid calling `repr`
550 f" not allowed in {context.evaluation} mode",
551 f" not allowed in {context.evaluation} mode",
551 )
552 )
552 if isinstance(node, ast.Name):
553 if isinstance(node, ast.Name):
553 if policy.allow_locals_access and node.id in context.locals:
554 if policy.allow_locals_access and node.id in context.locals:
554 return context.locals[node.id]
555 return context.locals[node.id]
555 if policy.allow_globals_access and node.id in context.globals:
556 if policy.allow_globals_access and node.id in context.globals:
556 return context.globals[node.id]
557 return context.globals[node.id]
557 if policy.allow_builtins_access and hasattr(builtins, node.id):
558 if policy.allow_builtins_access and hasattr(builtins, node.id):
558 # note: do not use __builtins__, it is implementation detail of cPython
559 # note: do not use __builtins__, it is implementation detail of cPython
559 return getattr(builtins, node.id)
560 return getattr(builtins, node.id)
560 if not policy.allow_globals_access and not policy.allow_locals_access:
561 if not policy.allow_globals_access and not policy.allow_locals_access:
561 raise GuardRejection(
562 raise GuardRejection(
562 f"Namespace access not allowed in {context.evaluation} mode"
563 f"Namespace access not allowed in {context.evaluation} mode"
563 )
564 )
564 else:
565 else:
565 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
566 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
566 if isinstance(node, ast.Attribute):
567 if isinstance(node, ast.Attribute):
567 value = eval_node(node.value, context)
568 value = eval_node(node.value, context)
568 if policy.can_get_attr(value, node.attr):
569 if policy.can_get_attr(value, node.attr):
569 return getattr(value, node.attr)
570 return getattr(value, node.attr)
570 raise GuardRejection(
571 raise GuardRejection(
571 "Attribute access (`__getattr__`) for",
572 "Attribute access (`__getattr__`) for",
572 type(value), # not joined to avoid calling `repr`
573 type(value), # not joined to avoid calling `repr`
573 f"not allowed in {context.evaluation} mode",
574 f"not allowed in {context.evaluation} mode",
574 )
575 )
575 if isinstance(node, ast.IfExp):
576 if isinstance(node, ast.IfExp):
576 test = eval_node(node.test, context)
577 test = eval_node(node.test, context)
577 if test:
578 if test:
578 return eval_node(node.body, context)
579 return eval_node(node.body, context)
579 else:
580 else:
580 return eval_node(node.orelse, context)
581 return eval_node(node.orelse, context)
581 if isinstance(node, ast.Call):
582 if isinstance(node, ast.Call):
582 func = eval_node(node.func, context)
583 func = eval_node(node.func, context)
583 if policy.can_call(func) and not node.keywords:
584 if policy.can_call(func) and not node.keywords:
584 args = [eval_node(arg, context) for arg in node.args]
585 args = [eval_node(arg, context) for arg in node.args]
585 return func(*args)
586 return func(*args)
586 raise GuardRejection(
587 raise GuardRejection(
587 "Call for",
588 "Call for",
588 func, # not joined to avoid calling `repr`
589 func, # not joined to avoid calling `repr`
589 f"not allowed in {context.evaluation} mode",
590 f"not allowed in {context.evaluation} mode",
590 )
591 )
591 raise ValueError("Unhandled node", ast.dump(node))
592 raise ValueError("Unhandled node", ast.dump(node))
592
593
593
594
594 SUPPORTED_EXTERNAL_GETITEM = {
595 SUPPORTED_EXTERNAL_GETITEM = {
595 ("pandas", "core", "indexing", "_iLocIndexer"),
596 ("pandas", "core", "indexing", "_iLocIndexer"),
596 ("pandas", "core", "indexing", "_LocIndexer"),
597 ("pandas", "core", "indexing", "_LocIndexer"),
597 ("pandas", "DataFrame"),
598 ("pandas", "DataFrame"),
598 ("pandas", "Series"),
599 ("pandas", "Series"),
599 ("numpy", "ndarray"),
600 ("numpy", "ndarray"),
600 ("numpy", "void"),
601 ("numpy", "void"),
601 }
602 }
602
603
603
604
604 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
605 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
605 dict,
606 dict,
606 str, # type: ignore[arg-type]
607 str, # type: ignore[arg-type]
607 bytes, # type: ignore[arg-type]
608 bytes, # type: ignore[arg-type]
608 list,
609 list,
609 tuple,
610 tuple,
610 collections.defaultdict,
611 collections.defaultdict,
611 collections.deque,
612 collections.deque,
612 collections.OrderedDict,
613 collections.OrderedDict,
613 collections.ChainMap,
614 collections.ChainMap,
614 collections.UserDict,
615 collections.UserDict,
615 collections.UserList,
616 collections.UserList,
616 collections.UserString, # type: ignore[arg-type]
617 collections.UserString, # type: ignore[arg-type]
617 _DummyNamedTuple,
618 _DummyNamedTuple,
618 _IdentitySubscript,
619 _IdentitySubscript,
619 }
620 }
620
621
621
622
622 def _list_methods(cls, source=None):
623 def _list_methods(cls, source=None):
623 """For use on immutable objects or with methods returning a copy"""
624 """For use on immutable objects or with methods returning a copy"""
624 return [getattr(cls, k) for k in (source if source else dir(cls))]
625 return [getattr(cls, k) for k in (source if source else dir(cls))]
625
626
626
627
627 dict_non_mutating_methods = ("copy", "keys", "values", "items")
628 dict_non_mutating_methods = ("copy", "keys", "values", "items")
628 list_non_mutating_methods = ("copy", "index", "count")
629 list_non_mutating_methods = ("copy", "index", "count")
629 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
630 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
630
631
631
632
632 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
633 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
633 method_descriptor: Any = type(list.copy)
634 module = type(builtins)
635
634
636 NUMERICS = {int, float, complex}
635 NUMERICS = {int, float, complex}
637
636
638 ALLOWED_CALLS = {
637 ALLOWED_CALLS = {
639 bytes,
638 bytes,
640 *_list_methods(bytes),
639 *_list_methods(bytes),
641 dict,
640 dict,
642 *_list_methods(dict, dict_non_mutating_methods),
641 *_list_methods(dict, dict_non_mutating_methods),
643 dict_keys.isdisjoint,
642 dict_keys.isdisjoint,
644 list,
643 list,
645 *_list_methods(list, list_non_mutating_methods),
644 *_list_methods(list, list_non_mutating_methods),
646 set,
645 set,
647 *_list_methods(set, set_non_mutating_methods),
646 *_list_methods(set, set_non_mutating_methods),
648 frozenset,
647 frozenset,
649 *_list_methods(frozenset),
648 *_list_methods(frozenset),
650 range,
649 range,
651 str,
650 str,
652 *_list_methods(str),
651 *_list_methods(str),
653 tuple,
652 tuple,
654 *_list_methods(tuple),
653 *_list_methods(tuple),
655 *NUMERICS,
654 *NUMERICS,
656 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
655 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
657 collections.deque,
656 collections.deque,
658 *_list_methods(collections.deque, list_non_mutating_methods),
657 *_list_methods(collections.deque, list_non_mutating_methods),
659 collections.defaultdict,
658 collections.defaultdict,
660 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
659 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
661 collections.OrderedDict,
660 collections.OrderedDict,
662 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
661 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
663 collections.UserDict,
662 collections.UserDict,
664 *_list_methods(collections.UserDict, dict_non_mutating_methods),
663 *_list_methods(collections.UserDict, dict_non_mutating_methods),
665 collections.UserList,
664 collections.UserList,
666 *_list_methods(collections.UserList, list_non_mutating_methods),
665 *_list_methods(collections.UserList, list_non_mutating_methods),
667 collections.UserString,
666 collections.UserString,
668 *_list_methods(collections.UserString, dir(str)),
667 *_list_methods(collections.UserString, dir(str)),
669 collections.Counter,
668 collections.Counter,
670 *_list_methods(collections.Counter, dict_non_mutating_methods),
669 *_list_methods(collections.Counter, dict_non_mutating_methods),
671 collections.Counter.elements,
670 collections.Counter.elements,
672 collections.Counter.most_common,
671 collections.Counter.most_common,
673 }
672 }
674
673
675 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
674 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
676 *BUILTIN_GETITEM,
675 *BUILTIN_GETITEM,
677 set,
676 set,
678 frozenset,
677 frozenset,
679 object,
678 object,
680 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
679 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
681 *NUMERICS,
680 *NUMERICS,
682 dict_keys,
681 dict_keys,
683 method_descriptor,
682 MethodDescriptorType,
684 module,
683 ModuleType,
685 }
684 }
686
685
687
686
688 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
687 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
689
688
690 EVALUATION_POLICIES = {
689 EVALUATION_POLICIES = {
691 "minimal": EvaluationPolicy(
690 "minimal": EvaluationPolicy(
692 allow_builtins_access=True,
691 allow_builtins_access=True,
693 allow_locals_access=False,
692 allow_locals_access=False,
694 allow_globals_access=False,
693 allow_globals_access=False,
695 allow_item_access=False,
694 allow_item_access=False,
696 allow_attr_access=False,
695 allow_attr_access=False,
697 allowed_calls=set(),
696 allowed_calls=set(),
698 allow_any_calls=False,
697 allow_any_calls=False,
699 allow_all_operations=False,
698 allow_all_operations=False,
700 ),
699 ),
701 "limited": SelectivePolicy(
700 "limited": SelectivePolicy(
702 allowed_getitem=BUILTIN_GETITEM,
701 allowed_getitem=BUILTIN_GETITEM,
703 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
702 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
704 allowed_getattr=BUILTIN_GETATTR,
703 allowed_getattr=BUILTIN_GETATTR,
705 allowed_getattr_external={
704 allowed_getattr_external={
706 # pandas Series/Frame implements custom `__getattr__`
705 # pandas Series/Frame implements custom `__getattr__`
707 ("pandas", "DataFrame"),
706 ("pandas", "DataFrame"),
708 ("pandas", "Series"),
707 ("pandas", "Series"),
709 },
708 },
710 allowed_operations=BUILTIN_OPERATIONS,
709 allowed_operations=BUILTIN_OPERATIONS,
711 allow_builtins_access=True,
710 allow_builtins_access=True,
712 allow_locals_access=True,
711 allow_locals_access=True,
713 allow_globals_access=True,
712 allow_globals_access=True,
714 allowed_calls=ALLOWED_CALLS,
713 allowed_calls=ALLOWED_CALLS,
715 ),
714 ),
716 "unsafe": EvaluationPolicy(
715 "unsafe": EvaluationPolicy(
717 allow_builtins_access=True,
716 allow_builtins_access=True,
718 allow_locals_access=True,
717 allow_locals_access=True,
719 allow_globals_access=True,
718 allow_globals_access=True,
720 allow_attr_access=True,
719 allow_attr_access=True,
721 allow_item_access=True,
720 allow_item_access=True,
722 allow_any_calls=True,
721 allow_any_calls=True,
723 allow_all_operations=True,
722 allow_all_operations=True,
724 ),
723 ),
725 }
724 }
726
725
727
726
728 __all__ = [
727 __all__ = [
729 "guarded_eval",
728 "guarded_eval",
730 "eval_node",
729 "eval_node",
731 "GuardRejection",
730 "GuardRejection",
732 "EvaluationContext",
731 "EvaluationContext",
733 "_unbind_method",
732 "_unbind_method",
734 ]
733 ]
General Comments 0
You need to be logged in to leave comments. Login now