##// END OF EJS Templates
MAINT: fix typing mypy 1.0
Matthias Bussonnier -
Show More
@@ -1,738 +1,738 b''
1 from typing import (
1 from typing import (
2 Any,
2 Any,
3 Callable,
3 Callable,
4 Dict,
4 Dict,
5 Set,
5 Set,
6 Sequence,
6 Sequence,
7 Tuple,
7 Tuple,
8 NamedTuple,
8 NamedTuple,
9 Type,
9 Type,
10 Literal,
10 Literal,
11 Union,
11 Union,
12 TYPE_CHECKING,
12 TYPE_CHECKING,
13 )
13 )
14 import ast
14 import ast
15 import builtins
15 import builtins
16 import collections
16 import collections
17 import operator
17 import operator
18 import sys
18 import sys
19 from functools import cached_property
19 from functools import cached_property
20 from dataclasses import dataclass, field
20 from dataclasses import dataclass, field
21
21
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.decorators import undoc
23 from IPython.utils.decorators import undoc
24
24
25
25
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 from typing_extensions import Protocol
27 from typing_extensions import Protocol
28 else:
28 else:
29 # do not require on runtime
29 # do not require on runtime
30 Protocol = object # requires Python >=3.8
30 Protocol = object # requires Python >=3.8
31
31
32
32
@undoc
class HasGetItem(Protocol):
    """Structural type for objects that define ``__getitem__``."""

    def __getitem__(self, key) -> None:
        """Signature placeholder only; there is no implementation."""
        ...
37
37
38
38
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables whose result is a `HasGetItem`."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Signature placeholder only; there is no implementation."""
        ...
43
43
44
44
@undoc
class HasGetAttr(Protocol):
    """Structural type for objects that define ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Signature placeholder only; there is no implementation."""
        ...
49
49
50
50
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Structural type with no members, for objects lacking ``__getattr__``."""
54
54
55
55
# Most objects do not implement `__getattr__` explicitly, so both the
# "has it" and the "does not have it" shapes are accepted here.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58
58
59
59
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 """Get unbound method for given bound method.
61 """Get unbound method for given bound method.
62
62
63 Returns None if cannot get unbound method, or method is already unbound.
63 Returns None if cannot get unbound method, or method is already unbound.
64 """
64 """
65 owner = getattr(func, "__self__", None)
65 owner = getattr(func, "__self__", None)
66 owner_class = type(owner)
66 owner_class = type(owner)
67 name = getattr(func, "__name__", None)
67 name = getattr(func, "__name__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
69 if (
69 if (
70 owner is not None
70 owner is not None
71 and name
71 and name
72 and (
72 and (
73 not instance_dict_overrides
73 not instance_dict_overrides
74 or (instance_dict_overrides and name not in instance_dict_overrides)
74 or (instance_dict_overrides and name not in instance_dict_overrides)
75 )
75 )
76 ):
76 ):
77 return getattr(owner_class, name)
77 return getattr(owner_class, name)
78 return None
78 return None
79
79
80
80
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    The boolean flags switch whole categories of access on or off, and
    ``allowed_calls`` lists individual callables that may be invoked even
    when ``allow_any_calls`` is off. The ``can_*`` predicates always return
    a real ``bool``.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item) -> bool:
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr) -> bool:
        """Return whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None) -> bool:
        """Return whether the operator backed by ``dunders`` may be applied.

        This base policy ignores the operands and the dunder names and only
        consults ``allow_all_operations``.
        """
        return bool(self.allow_all_operations)

    def can_call(self, func) -> bool:
        """Return whether ``func`` may be called.

        A call is permitted when any call is allowed, when ``func`` itself is
        in ``allowed_calls``, or when ``func`` is a bound method whose
        class-level counterpart is in ``allowed_calls``.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)
        if owner_method and owner_method in self.allowed_calls:
            return True

        return False
116
116
117
117
118 def _get_external(module_name: str, access_path: Sequence[str]):
118 def _get_external(module_name: str, access_path: Sequence[str]):
119 """Get value from external module given a dotted access path.
119 """Get value from external module given a dotted access path.
120
120
121 Raises:
121 Raises:
122 * `KeyError` if module is removed not found, and
122 * `KeyError` if module is removed not found, and
123 * `AttributeError` if acess path does not match an exported object
123 * `AttributeError` if acess path does not match an exported object
124 """
124 """
125 member_type = sys.modules[module_name]
125 member_type = sys.modules[module_name]
126 for attr in access_path:
126 for attr in access_path:
127 member_type = getattr(member_type, attr)
127 member_type = getattr(member_type, attr)
128 return member_type
128 return member_type
129
129
130
130
131 def _has_original_dunder_external(
131 def _has_original_dunder_external(
132 value,
132 value,
133 module_name: str,
133 module_name: str,
134 access_path: Sequence[str],
134 access_path: Sequence[str],
135 method_name: str,
135 method_name: str,
136 ):
136 ):
137 if module_name not in sys.modules:
137 if module_name not in sys.modules:
138 # LBYLB as it is faster
138 # LBYLB as it is faster
139 return False
139 return False
140 try:
140 try:
141 member_type = _get_external(module_name, access_path)
141 member_type = _get_external(module_name, access_path)
142 value_type = type(value)
142 value_type = type(value)
143 if type(value) == member_type:
143 if type(value) == member_type:
144 return True
144 return True
145 if method_name == "__getattribute__":
145 if method_name == "__getattribute__":
146 # we have to short-circuit here due to an unresolved issue in
146 # we have to short-circuit here due to an unresolved issue in
147 # `isinstance` implementation: https://bugs.python.org/issue32683
147 # `isinstance` implementation: https://bugs.python.org/issue32683
148 return False
148 return False
149 if isinstance(value, member_type):
149 if isinstance(value, member_type):
150 method = getattr(value_type, method_name, None)
150 method = getattr(value_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
152 if member_method == method:
152 if member_method == method:
153 return True
153 return True
154 except (AttributeError, KeyError):
154 except (AttributeError, KeyError):
155 return False
155 return False
156
156
157
157
158 def _has_original_dunder(
158 def _has_original_dunder(
159 value, allowed_types, allowed_methods, allowed_external, method_name
159 value, allowed_types, allowed_methods, allowed_external, method_name
160 ):
160 ):
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
162 # we only need to check at class level
162 # we only need to check at class level
163 value_type = type(value)
163 value_type = type(value)
164
164
165 # strict type check passes β†’ no need to check method
165 # strict type check passes β†’ no need to check method
166 if value_type in allowed_types:
166 if value_type in allowed_types:
167 return True
167 return True
168
168
169 method = getattr(value_type, method_name, None)
169 method = getattr(value_type, method_name, None)
170
170
171 if method is None:
171 if method is None:
172 return None
172 return None
173
173
174 if method in allowed_methods:
174 if method in allowed_methods:
175 return True
175 return True
176
176
177 for module_name, *access_path in allowed_external:
177 for module_name, *access_path in allowed_external:
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
179 return True
179 return True
180
180
181 return False
181 return False
182
182
183
183
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Evaluation policy driven by allow-lists of types.

    For item access, attribute access and operators there are two lists each:
    one of type objects, and one ``*_external`` list of
    ``(module_name, *access_path)`` tuples naming types in modules that may
    not be imported. An action is permitted when the dunder method backing
    it on the value's class is the one provided by an allow-listed type.
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # Per-dunder cache filled lazily by `_operator_dunder_methods`.
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Return whether ``value.attr`` may be evaluated.

        Both ``__getattribute__`` and ``__getattr__`` on the class of
        ``value`` must be allow-listed; a class without ``__getattr__`` is
        fine. If ``attr`` is a property on the class, it must additionally be
        a property of an allow-listed type.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external entry has to be consulted: the set is unordered, so
        # deciding on whichever entry comes first gives an arbitrary answer.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                # Module not loaded, or this type has no such attribute.
                continue
            if class_attr_val == external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return whether every dunder in ``dunders`` is allow-listed for each operand.

        ``b`` is only checked when it is not ``None``.
        """
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=self.allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return the ``dunder`` implementations of the allowed operation types (cached)."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        """``__getitem__`` implementations of the types in ``allowed_getitem``."""
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        """``__getattr__`` implementations of the types in ``allowed_getattr``."""
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        """``__getattribute__`` implementations of the types in ``allowed_getattr``."""
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect the truthy ``name`` attributes of ``classes``, skipping missing ones."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
308
308
309
309
310 class _DummyNamedTuple(NamedTuple):
310 class _DummyNamedTuple(NamedTuple):
311 """Used internally to retrieve methods of named tuple instance."""
311 """Used internally to retrieve methods of named tuple instance."""
312
312
313
313
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which a piece of code is evaluated."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: Literal[
        "forbidden", "minimal", "limited", "unsafe", "dangerous"
    ] = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
326
326
327
327
328 class _IdentitySubscript:
328 class _IdentitySubscript:
329 """Returns the key itself when item is requested via subscript."""
329 """Returns the key itself when item is requested via subscript."""
330
330
331 def __getitem__(self, key):
331 def __getitem__(self, key):
332 return key
332 return key
333
333
334
334
# Shared instance that echoes back whatever subscript it is given.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which that instance is placed in the locals for subscript parsing.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
337
337
338
338
class GuardRejection(Exception):
    """Exception raised when guard rejects evaluation attempt."""
343
343
344
344
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    Behaviour depends on ``context.evaluation``:

    - ``forbidden``: nothing is evaluated and `GuardRejection` is raised;
    - ``dangerous``: the code is passed to the standard :func:`eval`;
    - anything else: the code is parsed and the AST is handed to
      :func:`eval_node`.

    When ``context.in_subscript`` is set, ``code`` is treated as the inside
    of a subscript; empty code then evaluates to an empty tuple.
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside subscripts, so we need to
        # trick the ast parser into thinking that we have a subscript, but we
        # need to be able to later recognise that we did it so we can ignore
        # the actual __getitem__ operation
        if not code:
            return ()
        subscript_locals = context.locals.copy()
        subscript_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = f"{SUBSCRIPT_MARKER}[{code}]"
        context = context._replace(locals=subscript_locals)

    if context.evaluation == "dangerous":
        # Unrestricted by design: this mode hands the code to plain `eval`.
        return eval(code, context.globals, context.locals)

    return eval_node(ast.parse(code, mode="eval"), context)
379
379
380
380
# Each table maps an `ast` operator node class to the dunder method name(s)
# associated with it; the first name is the primary one.

BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str, ...]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    # NOTE(review): confirm that the types of interest actually define
    # `__inv__` and `__not__`; presumably these mirror `operator` names.
    ast.Invert: ("__invert__", "__inv__"),
    ast.Not: ("__not__",),
}
415
415
416
416
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
418 dunder = None
418 dunder = None
419 for op, candidate_dunder in dunders.items():
419 for op, candidate_dunder in dunders.items():
420 if isinstance(node_op, op):
420 if isinstance(node_op, op):
421 dunder = candidate_dunder
421 dunder = candidate_dunder
422 return dunder
422 return dunder
423
423
424
424
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
426 """Evaluate AST node in provided context.
426 """Evaluate AST node in provided context.
427
427
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
429
429
430 Does not evaluate actions that always have side effects:
430 Does not evaluate actions that always have side effects:
431
431
432 - class definitions (``class sth: ...``)
432 - class definitions (``class sth: ...``)
433 - function definitions (``def sth: ...``)
433 - function definitions (``def sth: ...``)
434 - variable assignments (``x = 1``)
434 - variable assignments (``x = 1``)
435 - augmented assignments (``x += 1``)
435 - augmented assignments (``x += 1``)
436 - deletions (``del x``)
436 - deletions (``del x``)
437
437
438 Does not evaluate operations which do not return values:
438 Does not evaluate operations which do not return values:
439
439
440 - assertions (``assert x``)
440 - assertions (``assert x``)
441 - pass (``pass``)
441 - pass (``pass``)
442 - imports (``import x``)
442 - imports (``import x``)
443 - control flow:
443 - control flow:
444
444
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
446 - loops (``for`` and `while``)
446 - loops (``for`` and `while``)
447 - exception handling
447 - exception handling
448
448
449 The purpose of this function is to guard against unwanted side-effects;
449 The purpose of this function is to guard against unwanted side-effects;
450 it does not give guarantees on protection from malicious code execution.
450 it does not give guarantees on protection from malicious code execution.
451 """
451 """
452 policy = EVALUATION_POLICIES[context.evaluation]
452 policy = EVALUATION_POLICIES[context.evaluation]
453 if node is None:
453 if node is None:
454 return None
454 return None
455 if isinstance(node, ast.Expression):
455 if isinstance(node, ast.Expression):
456 return eval_node(node.body, context)
456 return eval_node(node.body, context)
457 if isinstance(node, ast.BinOp):
457 if isinstance(node, ast.BinOp):
458 left = eval_node(node.left, context)
458 left = eval_node(node.left, context)
459 right = eval_node(node.right, context)
459 right = eval_node(node.right, context)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
461 if dunders:
461 if dunders:
462 if policy.can_operate(dunders, left, right):
462 if policy.can_operate(dunders, left, right):
463 return getattr(left, dunders[0])(right)
463 return getattr(left, dunders[0])(right)
464 else:
464 else:
465 raise GuardRejection(
465 raise GuardRejection(
466 f"Operation (`{dunders}`) for",
466 f"Operation (`{dunders}`) for",
467 type(left),
467 type(left),
468 f"not allowed in {context.evaluation} mode",
468 f"not allowed in {context.evaluation} mode",
469 )
469 )
470 if isinstance(node, ast.Compare):
470 if isinstance(node, ast.Compare):
471 left = eval_node(node.left, context)
471 left = eval_node(node.left, context)
472 all_true = True
472 all_true = True
473 negate = False
473 negate = False
474 for op, right in zip(node.ops, node.comparators):
474 for op, right in zip(node.ops, node.comparators):
475 right = eval_node(right, context)
475 right = eval_node(right, context)
476 dunder = None
476 dunder = None
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
478 if not dunders:
478 if not dunders:
479 if isinstance(op, ast.NotIn):
479 if isinstance(op, ast.NotIn):
480 dunders = COMP_OP_DUNDERS[ast.In]
480 dunders = COMP_OP_DUNDERS[ast.In]
481 negate = True
481 negate = True
482 if isinstance(op, ast.Is):
482 if isinstance(op, ast.Is):
483 dunder = "is_"
483 dunder = "is_"
484 if isinstance(op, ast.IsNot):
484 if isinstance(op, ast.IsNot):
485 dunder = "is_"
485 dunder = "is_"
486 negate = True
486 negate = True
487 if not dunder and dunders:
487 if not dunder and dunders:
488 dunder = dunders[0]
488 dunder = dunders[0]
489 if dunder:
489 if dunder:
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
492 result = getattr(operator, dunder)(a, b)
492 result = getattr(operator, dunder)(a, b)
493 if negate:
493 if negate:
494 result = not result
494 result = not result
495 if not result:
495 if not result:
496 all_true = False
496 all_true = False
497 left = right
497 left = right
498 else:
498 else:
499 raise GuardRejection(
499 raise GuardRejection(
500 f"Comparison (`{dunder}`) for",
500 f"Comparison (`{dunder}`) for",
501 type(left),
501 type(left),
502 f"not allowed in {context.evaluation} mode",
502 f"not allowed in {context.evaluation} mode",
503 )
503 )
504 else:
504 else:
505 raise ValueError(
505 raise ValueError(
506 f"Comparison `{dunder}` not supported"
506 f"Comparison `{dunder}` not supported"
507 ) # pragma: no cover
507 ) # pragma: no cover
508 return all_true
508 return all_true
509 if isinstance(node, ast.Constant):
509 if isinstance(node, ast.Constant):
510 return node.value
510 return node.value
511 if isinstance(node, ast.Index):
511 if isinstance(node, ast.Index):
512 # deprecated since Python 3.9
512 # deprecated since Python 3.9
513 return eval_node(node.value, context) # pragma: no cover
513 return eval_node(node.value, context) # pragma: no cover
514 if isinstance(node, ast.Tuple):
514 if isinstance(node, ast.Tuple):
515 return tuple(eval_node(e, context) for e in node.elts)
515 return tuple(eval_node(e, context) for e in node.elts)
516 if isinstance(node, ast.List):
516 if isinstance(node, ast.List):
517 return [eval_node(e, context) for e in node.elts]
517 return [eval_node(e, context) for e in node.elts]
518 if isinstance(node, ast.Set):
518 if isinstance(node, ast.Set):
519 return {eval_node(e, context) for e in node.elts}
519 return {eval_node(e, context) for e in node.elts}
520 if isinstance(node, ast.Dict):
520 if isinstance(node, ast.Dict):
521 return dict(
521 return dict(
522 zip(
522 zip(
523 [eval_node(k, context) for k in node.keys],
523 [eval_node(k, context) for k in node.keys],
524 [eval_node(v, context) for v in node.values],
524 [eval_node(v, context) for v in node.values],
525 )
525 )
526 )
526 )
527 if isinstance(node, ast.Slice):
527 if isinstance(node, ast.Slice):
528 return slice(
528 return slice(
529 eval_node(node.lower, context),
529 eval_node(node.lower, context),
530 eval_node(node.upper, context),
530 eval_node(node.upper, context),
531 eval_node(node.step, context),
531 eval_node(node.step, context),
532 )
532 )
533 if isinstance(node, ast.ExtSlice):
533 if isinstance(node, ast.ExtSlice):
534 # deprecated since Python 3.9
534 # deprecated since Python 3.9
535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
536 if isinstance(node, ast.UnaryOp):
536 if isinstance(node, ast.UnaryOp):
537 value = eval_node(node.operand, context)
537 value = eval_node(node.operand, context)
538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
539 if dunders:
539 if dunders:
540 if policy.can_operate(dunders, value):
540 if policy.can_operate(dunders, value):
541 return getattr(value, dunders[0])()
541 return getattr(value, dunders[0])()
542 else:
542 else:
543 raise GuardRejection(
543 raise GuardRejection(
544 f"Operation (`{dunders}`) for",
544 f"Operation (`{dunders}`) for",
545 type(value),
545 type(value),
546 f"not allowed in {context.evaluation} mode",
546 f"not allowed in {context.evaluation} mode",
547 )
547 )
548 if isinstance(node, ast.Subscript):
548 if isinstance(node, ast.Subscript):
549 value = eval_node(node.value, context)
549 value = eval_node(node.value, context)
550 slice_ = eval_node(node.slice, context)
550 slice_ = eval_node(node.slice, context)
551 if policy.can_get_item(value, slice_):
551 if policy.can_get_item(value, slice_):
552 return value[slice_]
552 return value[slice_]
553 raise GuardRejection(
553 raise GuardRejection(
554 "Subscript access (`__getitem__`) for",
554 "Subscript access (`__getitem__`) for",
555 type(value), # not joined to avoid calling `repr`
555 type(value), # not joined to avoid calling `repr`
556 f" not allowed in {context.evaluation} mode",
556 f" not allowed in {context.evaluation} mode",
557 )
557 )
558 if isinstance(node, ast.Name):
558 if isinstance(node, ast.Name):
559 if policy.allow_locals_access and node.id in context.locals:
559 if policy.allow_locals_access and node.id in context.locals:
560 return context.locals[node.id]
560 return context.locals[node.id]
561 if policy.allow_globals_access and node.id in context.globals:
561 if policy.allow_globals_access and node.id in context.globals:
562 return context.globals[node.id]
562 return context.globals[node.id]
563 if policy.allow_builtins_access and hasattr(builtins, node.id):
563 if policy.allow_builtins_access and hasattr(builtins, node.id):
564 # note: do not use __builtins__, it is implementation detail of cPython
564 # note: do not use __builtins__, it is implementation detail of cPython
565 return getattr(builtins, node.id)
565 return getattr(builtins, node.id)
566 if not policy.allow_globals_access and not policy.allow_locals_access:
566 if not policy.allow_globals_access and not policy.allow_locals_access:
567 raise GuardRejection(
567 raise GuardRejection(
568 f"Namespace access not allowed in {context.evaluation} mode"
568 f"Namespace access not allowed in {context.evaluation} mode"
569 )
569 )
570 else:
570 else:
571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
572 if isinstance(node, ast.Attribute):
572 if isinstance(node, ast.Attribute):
573 value = eval_node(node.value, context)
573 value = eval_node(node.value, context)
574 if policy.can_get_attr(value, node.attr):
574 if policy.can_get_attr(value, node.attr):
575 return getattr(value, node.attr)
575 return getattr(value, node.attr)
576 raise GuardRejection(
576 raise GuardRejection(
577 "Attribute access (`__getattr__`) for",
577 "Attribute access (`__getattr__`) for",
578 type(value), # not joined to avoid calling `repr`
578 type(value), # not joined to avoid calling `repr`
579 f"not allowed in {context.evaluation} mode",
579 f"not allowed in {context.evaluation} mode",
580 )
580 )
581 if isinstance(node, ast.IfExp):
581 if isinstance(node, ast.IfExp):
582 test = eval_node(node.test, context)
582 test = eval_node(node.test, context)
583 if test:
583 if test:
584 return eval_node(node.body, context)
584 return eval_node(node.body, context)
585 else:
585 else:
586 return eval_node(node.orelse, context)
586 return eval_node(node.orelse, context)
587 if isinstance(node, ast.Call):
587 if isinstance(node, ast.Call):
588 func = eval_node(node.func, context)
588 func = eval_node(node.func, context)
589 if policy.can_call(func) and not node.keywords:
589 if policy.can_call(func) and not node.keywords:
590 args = [eval_node(arg, context) for arg in node.args]
590 args = [eval_node(arg, context) for arg in node.args]
591 return func(*args)
591 return func(*args)
592 raise GuardRejection(
592 raise GuardRejection(
593 "Call for",
593 "Call for",
594 func, # not joined to avoid calling `repr`
594 func, # not joined to avoid calling `repr`
595 f"not allowed in {context.evaluation} mode",
595 f"not allowed in {context.evaluation} mode",
596 )
596 )
597 raise ValueError("Unhandled node", ast.dump(node))
597 raise ValueError("Unhandled node", ast.dump(node))
598
598
599
599
# Third-party types for which subscript access is allow-listed.  Each entry is
# a tuple of name components (presumably the module path followed by the type
# name -- confirm against the policy code that consumes it).  Handed to the
# "limited" policy as `allowed_getitem_external`.
SUPPORTED_EXTERNAL_GETITEM = {
    ("pandas", "core", "indexing", "_iLocIndexer"),
    ("pandas", "core", "indexing", "_LocIndexer"),
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    ("numpy", "ndarray"),
    ("numpy", "void"),
}
608
608
609
609
# Built-in and standard-library types for which subscript access (`obj[key]`)
# is allow-listed; handed to the "limited" policy as `allowed_getitem`.
# The `type: ignore[arg-type]` markers silence mypy for entries it does not
# accept as `InstancesHaveGetItem`.
BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
    dict,
    str,  # type: ignore[arg-type]
    bytes,  # type: ignore[arg-type]
    list,
    tuple,
    collections.defaultdict,
    collections.deque,
    collections.OrderedDict,
    collections.ChainMap,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    _DummyNamedTuple,
    _IdentitySubscript,
}
626
626
627
627
def _list_methods(cls, source=None):
    """Return attributes of ``cls`` looked up by name.

    For use on immutable objects or with methods returning a copy.

    Parameters
    ----------
    cls
        The class whose attributes are collected.
    source
        Iterable of attribute names to look up on ``cls``.  When ``None``
        (the default) every name reported by ``dir(cls)`` is used.

    Returns
    -------
    list
        ``getattr(cls, name)`` for each selected name, in iteration order.

    Raises
    ------
    AttributeError
        If a name in ``source`` is not an attribute of ``cls``.
    """
    # Only a missing `source` expands to "everything": an explicitly empty
    # selection must yield an empty list rather than silently allow-listing
    # every attribute of the class.
    names = dir(cls) if source is None else source
    return [getattr(cls, name) for name in names]
631
631
632
632
# Names of container methods treated as non-mutating; they are used below to
# select which methods of mutable containers go into `ALLOWED_CALLS`.
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# For sets, take every attribute name that `set` shares with `frozenset`.
set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
636
636
637
637
# Runtime types that have no importable public name: the class of the view
# returned by `dict.keys()` and the class of unbound built-in methods such as
# `list.copy`.
dict_keys: Type[collections.abc.KeysView] = type({}.keys())
method_descriptor: Any = type(list.copy)

# Built-in numeric types.
NUMERICS = {int, float, complex}
642
642
# Callables handed to the "limited" policy as `allowed_calls`.
# Immutable types (bytes, frozenset, str, tuple, numerics) contribute their
# constructor plus every attribute listed by `dir()`; mutable containers
# contribute their constructor plus only the explicitly named non-mutating
# methods.
ALLOWED_CALLS = {
    bytes,
    *_list_methods(bytes),
    dict,
    *_list_methods(dict, dict_non_mutating_methods),
    dict_keys.isdisjoint,
    list,
    *_list_methods(list, list_non_mutating_methods),
    set,
    *_list_methods(set, set_non_mutating_methods),
    frozenset,
    *_list_methods(frozenset),
    range,
    str,
    *_list_methods(str),
    tuple,
    *_list_methods(tuple),
    *NUMERICS,
    *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
    collections.deque,
    *_list_methods(collections.deque, list_non_mutating_methods),
    collections.defaultdict,
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.OrderedDict,
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    collections.UserDict,
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    collections.UserList,
    *_list_methods(collections.UserList, list_non_mutating_methods),
    collections.UserString,
    # `UserString` is queried with the attribute names of `str`.
    *_list_methods(collections.UserString, dir(str)),
    collections.Counter,
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    collections.Counter.elements,
    collections.Counter.most_common,
}
679
679
# Types for which attribute access is allow-listed; handed to the "limited"
# policy as `allowed_getattr`.  Includes every type in `BUILTIN_GETITEM`.
BUILTIN_GETATTR: Set[MayHaveGetattr] = {
    *BUILTIN_GETITEM,
    set,
    frozenset,
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    *NUMERICS,
    dict_keys,
    method_descriptor,
}
690
690
691
691
# Types on which operators are allow-listed; handed to the "limited" policy as
# `allowed_operations`.  Built as a separate set holding the same members as
# `BUILTIN_GETATTR`.
BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
693
693
# Named evaluation policies, keyed by mode name.
EVALUATION_POLICIES = {
    # Only builtins can be looked up; item access, attribute access, calls and
    # operations are all switched off.
    "minimal": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=False,
        allow_globals_access=False,
        allow_item_access=False,
        allow_attr_access=False,
        allowed_calls=set(),
        allow_any_calls=False,
        allow_all_operations=False,
    ),
    # Namespace lookups are enabled; item access, attribute access, operations
    # and calls are restricted to the allow-lists defined in this module.
    "limited": SelectivePolicy(
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_operations=BUILTIN_OPERATIONS,
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allowed_calls=ALLOWED_CALLS,
    ),
    # Every switch is turned on.
    "unsafe": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_attr_access=True,
        allow_item_access=True,
        allow_any_calls=True,
        allow_all_operations=True,
    ),
}
730
730
731
731
# Names exported by `from <module> import *`.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    "_unbind_method",
]
General Comments 0
You need to be logged in to leave comments. Login now