##// END OF EJS Templates
Backport PR #14029: Allow safe access to the `__getattribute__` method of modules
Matthias Bussonnier -
Show More
@@ -1,738 +1,740 b''
1 from typing import (
1 from typing import (
2 Any,
2 Any,
3 Callable,
3 Callable,
4 Dict,
4 Dict,
5 Set,
5 Set,
6 Sequence,
6 Sequence,
7 Tuple,
7 Tuple,
8 NamedTuple,
8 NamedTuple,
9 Type,
9 Type,
10 Literal,
10 Literal,
11 Union,
11 Union,
12 TYPE_CHECKING,
12 TYPE_CHECKING,
13 )
13 )
14 import ast
14 import ast
15 import builtins
15 import builtins
16 import collections
16 import collections
17 import operator
17 import operator
18 import sys
18 import sys
19 from functools import cached_property
19 from functools import cached_property
20 from dataclasses import dataclass, field
20 from dataclasses import dataclass, field
21
21
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.decorators import undoc
23 from IPython.utils.decorators import undoc
24
24
25
25
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 from typing_extensions import Protocol
27 from typing_extensions import Protocol
28 else:
28 else:
29 # do not require on runtime
29 # do not require on runtime
30 Protocol = object # requires Python >=3.8
30 Protocol = object # requires Python >=3.8
31
31
32
32
33 @undoc
33 @undoc
34 class HasGetItem(Protocol):
34 class HasGetItem(Protocol):
35 def __getitem__(self, key) -> None:
35 def __getitem__(self, key) -> None:
36 ...
36 ...
37
37
38
38
39 @undoc
39 @undoc
40 class InstancesHaveGetItem(Protocol):
40 class InstancesHaveGetItem(Protocol):
41 def __call__(self, *args, **kwargs) -> HasGetItem:
41 def __call__(self, *args, **kwargs) -> HasGetItem:
42 ...
42 ...
43
43
44
44
45 @undoc
45 @undoc
46 class HasGetAttr(Protocol):
46 class HasGetAttr(Protocol):
47 def __getattr__(self, key) -> None:
47 def __getattr__(self, key) -> None:
48 ...
48 ...
49
49
50
50
51 @undoc
51 @undoc
52 class DoesNotHaveGetAttr(Protocol):
52 class DoesNotHaveGetAttr(Protocol):
53 pass
53 pass
54
54
55
55
56 # By default `__getattr__` is not explicitly implemented on most objects
56 # By default `__getattr__` is not explicitly implemented on most objects
57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58
58
59
59
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 """Get unbound method for given bound method.
61 """Get unbound method for given bound method.
62
62
63 Returns None if cannot get unbound method, or method is already unbound.
63 Returns None if cannot get unbound method, or method is already unbound.
64 """
64 """
65 owner = getattr(func, "__self__", None)
65 owner = getattr(func, "__self__", None)
66 owner_class = type(owner)
66 owner_class = type(owner)
67 name = getattr(func, "__name__", None)
67 name = getattr(func, "__name__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
69 if (
69 if (
70 owner is not None
70 owner is not None
71 and name
71 and name
72 and (
72 and (
73 not instance_dict_overrides
73 not instance_dict_overrides
74 or (instance_dict_overrides and name not in instance_dict_overrides)
74 or (instance_dict_overrides and name not in instance_dict_overrides)
75 )
75 )
76 ):
76 ):
77 return getattr(owner_class, name)
77 return getattr(owner_class, name)
78 return None
78 return None
79
79
80
80
81 @undoc
81 @undoc
82 @dataclass
82 @dataclass
83 class EvaluationPolicy:
83 class EvaluationPolicy:
84 """Definition of evaluation policy."""
84 """Definition of evaluation policy."""
85
85
86 allow_locals_access: bool = False
86 allow_locals_access: bool = False
87 allow_globals_access: bool = False
87 allow_globals_access: bool = False
88 allow_item_access: bool = False
88 allow_item_access: bool = False
89 allow_attr_access: bool = False
89 allow_attr_access: bool = False
90 allow_builtins_access: bool = False
90 allow_builtins_access: bool = False
91 allow_all_operations: bool = False
91 allow_all_operations: bool = False
92 allow_any_calls: bool = False
92 allow_any_calls: bool = False
93 allowed_calls: Set[Callable] = field(default_factory=set)
93 allowed_calls: Set[Callable] = field(default_factory=set)
94
94
95 def can_get_item(self, value, item):
95 def can_get_item(self, value, item):
96 return self.allow_item_access
96 return self.allow_item_access
97
97
98 def can_get_attr(self, value, attr):
98 def can_get_attr(self, value, attr):
99 return self.allow_attr_access
99 return self.allow_attr_access
100
100
101 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
101 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
102 if self.allow_all_operations:
102 if self.allow_all_operations:
103 return True
103 return True
104
104
105 def can_call(self, func):
105 def can_call(self, func):
106 if self.allow_any_calls:
106 if self.allow_any_calls:
107 return True
107 return True
108
108
109 if func in self.allowed_calls:
109 if func in self.allowed_calls:
110 return True
110 return True
111
111
112 owner_method = _unbind_method(func)
112 owner_method = _unbind_method(func)
113
113
114 if owner_method and owner_method in self.allowed_calls:
114 if owner_method and owner_method in self.allowed_calls:
115 return True
115 return True
116
116
117
117
118 def _get_external(module_name: str, access_path: Sequence[str]):
118 def _get_external(module_name: str, access_path: Sequence[str]):
119 """Get value from external module given a dotted access path.
119 """Get value from external module given a dotted access path.
120
120
121 Raises:
121 Raises:
122 * `KeyError` if module is removed not found, and
122 * `KeyError` if module is removed not found, and
123 * `AttributeError` if acess path does not match an exported object
123 * `AttributeError` if acess path does not match an exported object
124 """
124 """
125 member_type = sys.modules[module_name]
125 member_type = sys.modules[module_name]
126 for attr in access_path:
126 for attr in access_path:
127 member_type = getattr(member_type, attr)
127 member_type = getattr(member_type, attr)
128 return member_type
128 return member_type
129
129
130
130
131 def _has_original_dunder_external(
131 def _has_original_dunder_external(
132 value,
132 value,
133 module_name: str,
133 module_name: str,
134 access_path: Sequence[str],
134 access_path: Sequence[str],
135 method_name: str,
135 method_name: str,
136 ):
136 ):
137 if module_name not in sys.modules:
137 if module_name not in sys.modules:
138 # LBYLB as it is faster
138 # LBYLB as it is faster
139 return False
139 return False
140 try:
140 try:
141 member_type = _get_external(module_name, access_path)
141 member_type = _get_external(module_name, access_path)
142 value_type = type(value)
142 value_type = type(value)
143 if type(value) == member_type:
143 if type(value) == member_type:
144 return True
144 return True
145 if method_name == "__getattribute__":
145 if method_name == "__getattribute__":
146 # we have to short-circuit here due to an unresolved issue in
146 # we have to short-circuit here due to an unresolved issue in
147 # `isinstance` implementation: https://bugs.python.org/issue32683
147 # `isinstance` implementation: https://bugs.python.org/issue32683
148 return False
148 return False
149 if isinstance(value, member_type):
149 if isinstance(value, member_type):
150 method = getattr(value_type, method_name, None)
150 method = getattr(value_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
152 if member_method == method:
152 if member_method == method:
153 return True
153 return True
154 except (AttributeError, KeyError):
154 except (AttributeError, KeyError):
155 return False
155 return False
156
156
157
157
158 def _has_original_dunder(
158 def _has_original_dunder(
159 value, allowed_types, allowed_methods, allowed_external, method_name
159 value, allowed_types, allowed_methods, allowed_external, method_name
160 ):
160 ):
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
162 # we only need to check at class level
162 # we only need to check at class level
163 value_type = type(value)
163 value_type = type(value)
164
164
165 # strict type check passes β†’ no need to check method
165 # strict type check passes β†’ no need to check method
166 if value_type in allowed_types:
166 if value_type in allowed_types:
167 return True
167 return True
168
168
169 method = getattr(value_type, method_name, None)
169 method = getattr(value_type, method_name, None)
170
170
171 if method is None:
171 if method is None:
172 return None
172 return None
173
173
174 if method in allowed_methods:
174 if method in allowed_methods:
175 return True
175 return True
176
176
177 for module_name, *access_path in allowed_external:
177 for module_name, *access_path in allowed_external:
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
179 return True
179 return True
180
180
181 return False
181 return False
182
182
183
183
184 @undoc
184 @undoc
185 @dataclass
185 @dataclass
186 class SelectivePolicy(EvaluationPolicy):
186 class SelectivePolicy(EvaluationPolicy):
187 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
187 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
188 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
188 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
189
189
190 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
190 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
191 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
191 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
192
192
193 allowed_operations: Set = field(default_factory=set)
193 allowed_operations: Set = field(default_factory=set)
194 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
194 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
195
195
196 _operation_methods_cache: Dict[str, Set[Callable]] = field(
196 _operation_methods_cache: Dict[str, Set[Callable]] = field(
197 default_factory=dict, init=False
197 default_factory=dict, init=False
198 )
198 )
199
199
200 def can_get_attr(self, value, attr):
200 def can_get_attr(self, value, attr):
201 has_original_attribute = _has_original_dunder(
201 has_original_attribute = _has_original_dunder(
202 value,
202 value,
203 allowed_types=self.allowed_getattr,
203 allowed_types=self.allowed_getattr,
204 allowed_methods=self._getattribute_methods,
204 allowed_methods=self._getattribute_methods,
205 allowed_external=self.allowed_getattr_external,
205 allowed_external=self.allowed_getattr_external,
206 method_name="__getattribute__",
206 method_name="__getattribute__",
207 )
207 )
208 has_original_attr = _has_original_dunder(
208 has_original_attr = _has_original_dunder(
209 value,
209 value,
210 allowed_types=self.allowed_getattr,
210 allowed_types=self.allowed_getattr,
211 allowed_methods=self._getattr_methods,
211 allowed_methods=self._getattr_methods,
212 allowed_external=self.allowed_getattr_external,
212 allowed_external=self.allowed_getattr_external,
213 method_name="__getattr__",
213 method_name="__getattr__",
214 )
214 )
215
215
216 accept = False
216 accept = False
217
217
218 # Many objects do not have `__getattr__`, this is fine.
218 # Many objects do not have `__getattr__`, this is fine.
219 if has_original_attr is None and has_original_attribute:
219 if has_original_attr is None and has_original_attribute:
220 accept = True
220 accept = True
221 else:
221 else:
222 # Accept objects without modifications to `__getattr__` and `__getattribute__`
222 # Accept objects without modifications to `__getattr__` and `__getattribute__`
223 accept = has_original_attr and has_original_attribute
223 accept = has_original_attr and has_original_attribute
224
224
225 if accept:
225 if accept:
226 # We still need to check for overriden properties.
226 # We still need to check for overriden properties.
227
227
228 value_class = type(value)
228 value_class = type(value)
229 if not hasattr(value_class, attr):
229 if not hasattr(value_class, attr):
230 return True
230 return True
231
231
232 class_attr_val = getattr(value_class, attr)
232 class_attr_val = getattr(value_class, attr)
233 is_property = isinstance(class_attr_val, property)
233 is_property = isinstance(class_attr_val, property)
234
234
235 if not is_property:
235 if not is_property:
236 return True
236 return True
237
237
238 # Properties in allowed types are ok (although we do not include any
238 # Properties in allowed types are ok (although we do not include any
239 # properties in our default allow list currently).
239 # properties in our default allow list currently).
240 if type(value) in self.allowed_getattr:
240 if type(value) in self.allowed_getattr:
241 return True # pragma: no cover
241 return True # pragma: no cover
242
242
243 # Properties in subclasses of allowed types may be ok if not changed
243 # Properties in subclasses of allowed types may be ok if not changed
244 for module_name, *access_path in self.allowed_getattr_external:
244 for module_name, *access_path in self.allowed_getattr_external:
245 try:
245 try:
246 external_class = _get_external(module_name, access_path)
246 external_class = _get_external(module_name, access_path)
247 external_class_attr_val = getattr(external_class, attr)
247 external_class_attr_val = getattr(external_class, attr)
248 except (KeyError, AttributeError):
248 except (KeyError, AttributeError):
249 return False # pragma: no cover
249 return False # pragma: no cover
250 return class_attr_val == external_class_attr_val
250 return class_attr_val == external_class_attr_val
251
251
252 return False
252 return False
253
253
254 def can_get_item(self, value, item):
254 def can_get_item(self, value, item):
255 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
255 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
256 return _has_original_dunder(
256 return _has_original_dunder(
257 value,
257 value,
258 allowed_types=self.allowed_getitem,
258 allowed_types=self.allowed_getitem,
259 allowed_methods=self._getitem_methods,
259 allowed_methods=self._getitem_methods,
260 allowed_external=self.allowed_getitem_external,
260 allowed_external=self.allowed_getitem_external,
261 method_name="__getitem__",
261 method_name="__getitem__",
262 )
262 )
263
263
264 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
264 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
265 objects = [a]
265 objects = [a]
266 if b is not None:
266 if b is not None:
267 objects.append(b)
267 objects.append(b)
268 return all(
268 return all(
269 [
269 [
270 _has_original_dunder(
270 _has_original_dunder(
271 obj,
271 obj,
272 allowed_types=self.allowed_operations,
272 allowed_types=self.allowed_operations,
273 allowed_methods=self._operator_dunder_methods(dunder),
273 allowed_methods=self._operator_dunder_methods(dunder),
274 allowed_external=self.allowed_operations_external,
274 allowed_external=self.allowed_operations_external,
275 method_name=dunder,
275 method_name=dunder,
276 )
276 )
277 for dunder in dunders
277 for dunder in dunders
278 for obj in objects
278 for obj in objects
279 ]
279 ]
280 )
280 )
281
281
282 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
282 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
283 if dunder not in self._operation_methods_cache:
283 if dunder not in self._operation_methods_cache:
284 self._operation_methods_cache[dunder] = self._safe_get_methods(
284 self._operation_methods_cache[dunder] = self._safe_get_methods(
285 self.allowed_operations, dunder
285 self.allowed_operations, dunder
286 )
286 )
287 return self._operation_methods_cache[dunder]
287 return self._operation_methods_cache[dunder]
288
288
289 @cached_property
289 @cached_property
290 def _getitem_methods(self) -> Set[Callable]:
290 def _getitem_methods(self) -> Set[Callable]:
291 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
291 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
292
292
293 @cached_property
293 @cached_property
294 def _getattr_methods(self) -> Set[Callable]:
294 def _getattr_methods(self) -> Set[Callable]:
295 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
295 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
296
296
297 @cached_property
297 @cached_property
298 def _getattribute_methods(self) -> Set[Callable]:
298 def _getattribute_methods(self) -> Set[Callable]:
299 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
299 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
300
300
301 def _safe_get_methods(self, classes, name) -> Set[Callable]:
301 def _safe_get_methods(self, classes, name) -> Set[Callable]:
302 return {
302 return {
303 method
303 method
304 for class_ in classes
304 for class_ in classes
305 for method in [getattr(class_, name, None)]
305 for method in [getattr(class_, name, None)]
306 if method
306 if method
307 }
307 }
308
308
309
309
310 class _DummyNamedTuple(NamedTuple):
310 class _DummyNamedTuple(NamedTuple):
311 """Used internally to retrieve methods of named tuple instance."""
311 """Used internally to retrieve methods of named tuple instance."""
312
312
313
313
314 class EvaluationContext(NamedTuple):
314 class EvaluationContext(NamedTuple):
315 #: Local namespace
315 #: Local namespace
316 locals: dict
316 locals: dict
317 #: Global namespace
317 #: Global namespace
318 globals: dict
318 globals: dict
319 #: Evaluation policy identifier
319 #: Evaluation policy identifier
320 evaluation: Literal[
320 evaluation: Literal[
321 "forbidden", "minimal", "limited", "unsafe", "dangerous"
321 "forbidden", "minimal", "limited", "unsafe", "dangerous"
322 ] = "forbidden"
322 ] = "forbidden"
323 #: Whether the evalution of code takes place inside of a subscript.
323 #: Whether the evalution of code takes place inside of a subscript.
324 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
324 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
325 in_subscript: bool = False
325 in_subscript: bool = False
326
326
327
327
328 class _IdentitySubscript:
328 class _IdentitySubscript:
329 """Returns the key itself when item is requested via subscript."""
329 """Returns the key itself when item is requested via subscript."""
330
330
331 def __getitem__(self, key):
331 def __getitem__(self, key):
332 return key
332 return key
333
333
334
334
335 IDENTITY_SUBSCRIPT = _IdentitySubscript()
335 IDENTITY_SUBSCRIPT = _IdentitySubscript()
336 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
336 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
337
337
338
338
339 class GuardRejection(Exception):
339 class GuardRejection(Exception):
340 """Exception raised when guard rejects evaluation attempt."""
340 """Exception raised when guard rejects evaluation attempt."""
341
341
342 pass
342 pass
343
343
344
344
345 def guarded_eval(code: str, context: EvaluationContext):
345 def guarded_eval(code: str, context: EvaluationContext):
346 """Evaluate provided code in the evaluation context.
346 """Evaluate provided code in the evaluation context.
347
347
348 If evaluation policy given by context is set to ``forbidden``
348 If evaluation policy given by context is set to ``forbidden``
349 no evaluation will be performed; if it is set to ``dangerous``
349 no evaluation will be performed; if it is set to ``dangerous``
350 standard :func:`eval` will be used; finally, for any other,
350 standard :func:`eval` will be used; finally, for any other,
351 policy :func:`eval_node` will be called on parsed AST.
351 policy :func:`eval_node` will be called on parsed AST.
352 """
352 """
353 locals_ = context.locals
353 locals_ = context.locals
354
354
355 if context.evaluation == "forbidden":
355 if context.evaluation == "forbidden":
356 raise GuardRejection("Forbidden mode")
356 raise GuardRejection("Forbidden mode")
357
357
358 # note: not using `ast.literal_eval` as it does not implement
358 # note: not using `ast.literal_eval` as it does not implement
359 # getitem at all, for example it fails on simple `[0][1]`
359 # getitem at all, for example it fails on simple `[0][1]`
360
360
361 if context.in_subscript:
361 if context.in_subscript:
362 # syntatic sugar for ellipsis (:) is only available in susbcripts
362 # syntatic sugar for ellipsis (:) is only available in susbcripts
363 # so we need to trick the ast parser into thinking that we have
363 # so we need to trick the ast parser into thinking that we have
364 # a subscript, but we need to be able to later recognise that we did
364 # a subscript, but we need to be able to later recognise that we did
365 # it so we can ignore the actual __getitem__ operation
365 # it so we can ignore the actual __getitem__ operation
366 if not code:
366 if not code:
367 return tuple()
367 return tuple()
368 locals_ = locals_.copy()
368 locals_ = locals_.copy()
369 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
369 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
370 code = SUBSCRIPT_MARKER + "[" + code + "]"
370 code = SUBSCRIPT_MARKER + "[" + code + "]"
371 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
371 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
372
372
373 if context.evaluation == "dangerous":
373 if context.evaluation == "dangerous":
374 return eval(code, context.globals, context.locals)
374 return eval(code, context.globals, context.locals)
375
375
376 expression = ast.parse(code, mode="eval")
376 expression = ast.parse(code, mode="eval")
377
377
378 return eval_node(expression, context)
378 return eval_node(expression, context)
379
379
380
380
381 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
381 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
382 ast.Add: ("__add__",),
382 ast.Add: ("__add__",),
383 ast.Sub: ("__sub__",),
383 ast.Sub: ("__sub__",),
384 ast.Mult: ("__mul__",),
384 ast.Mult: ("__mul__",),
385 ast.Div: ("__truediv__",),
385 ast.Div: ("__truediv__",),
386 ast.FloorDiv: ("__floordiv__",),
386 ast.FloorDiv: ("__floordiv__",),
387 ast.Mod: ("__mod__",),
387 ast.Mod: ("__mod__",),
388 ast.Pow: ("__pow__",),
388 ast.Pow: ("__pow__",),
389 ast.LShift: ("__lshift__",),
389 ast.LShift: ("__lshift__",),
390 ast.RShift: ("__rshift__",),
390 ast.RShift: ("__rshift__",),
391 ast.BitOr: ("__or__",),
391 ast.BitOr: ("__or__",),
392 ast.BitXor: ("__xor__",),
392 ast.BitXor: ("__xor__",),
393 ast.BitAnd: ("__and__",),
393 ast.BitAnd: ("__and__",),
394 ast.MatMult: ("__matmul__",),
394 ast.MatMult: ("__matmul__",),
395 }
395 }
396
396
397 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
397 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
398 ast.Eq: ("__eq__",),
398 ast.Eq: ("__eq__",),
399 ast.NotEq: ("__ne__", "__eq__"),
399 ast.NotEq: ("__ne__", "__eq__"),
400 ast.Lt: ("__lt__", "__gt__"),
400 ast.Lt: ("__lt__", "__gt__"),
401 ast.LtE: ("__le__", "__ge__"),
401 ast.LtE: ("__le__", "__ge__"),
402 ast.Gt: ("__gt__", "__lt__"),
402 ast.Gt: ("__gt__", "__lt__"),
403 ast.GtE: ("__ge__", "__le__"),
403 ast.GtE: ("__ge__", "__le__"),
404 ast.In: ("__contains__",),
404 ast.In: ("__contains__",),
405 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
405 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
406 }
406 }
407
407
408 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
408 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
409 ast.USub: ("__neg__",),
409 ast.USub: ("__neg__",),
410 ast.UAdd: ("__pos__",),
410 ast.UAdd: ("__pos__",),
411 # we have to check both __inv__ and __invert__!
411 # we have to check both __inv__ and __invert__!
412 ast.Invert: ("__invert__", "__inv__"),
412 ast.Invert: ("__invert__", "__inv__"),
413 ast.Not: ("__not__",),
413 ast.Not: ("__not__",),
414 }
414 }
415
415
416
416
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
418 dunder = None
418 dunder = None
419 for op, candidate_dunder in dunders.items():
419 for op, candidate_dunder in dunders.items():
420 if isinstance(node_op, op):
420 if isinstance(node_op, op):
421 dunder = candidate_dunder
421 dunder = candidate_dunder
422 return dunder
422 return dunder
423
423
424
424
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
426 """Evaluate AST node in provided context.
426 """Evaluate AST node in provided context.
427
427
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
429
429
430 Does not evaluate actions that always have side effects:
430 Does not evaluate actions that always have side effects:
431
431
432 - class definitions (``class sth: ...``)
432 - class definitions (``class sth: ...``)
433 - function definitions (``def sth: ...``)
433 - function definitions (``def sth: ...``)
434 - variable assignments (``x = 1``)
434 - variable assignments (``x = 1``)
435 - augmented assignments (``x += 1``)
435 - augmented assignments (``x += 1``)
436 - deletions (``del x``)
436 - deletions (``del x``)
437
437
438 Does not evaluate operations which do not return values:
438 Does not evaluate operations which do not return values:
439
439
440 - assertions (``assert x``)
440 - assertions (``assert x``)
441 - pass (``pass``)
441 - pass (``pass``)
442 - imports (``import x``)
442 - imports (``import x``)
443 - control flow:
443 - control flow:
444
444
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
446 - loops (``for`` and `while``)
446 - loops (``for`` and `while``)
447 - exception handling
447 - exception handling
448
448
449 The purpose of this function is to guard against unwanted side-effects;
449 The purpose of this function is to guard against unwanted side-effects;
450 it does not give guarantees on protection from malicious code execution.
450 it does not give guarantees on protection from malicious code execution.
451 """
451 """
452 policy = EVALUATION_POLICIES[context.evaluation]
452 policy = EVALUATION_POLICIES[context.evaluation]
453 if node is None:
453 if node is None:
454 return None
454 return None
455 if isinstance(node, ast.Expression):
455 if isinstance(node, ast.Expression):
456 return eval_node(node.body, context)
456 return eval_node(node.body, context)
457 if isinstance(node, ast.BinOp):
457 if isinstance(node, ast.BinOp):
458 left = eval_node(node.left, context)
458 left = eval_node(node.left, context)
459 right = eval_node(node.right, context)
459 right = eval_node(node.right, context)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
461 if dunders:
461 if dunders:
462 if policy.can_operate(dunders, left, right):
462 if policy.can_operate(dunders, left, right):
463 return getattr(left, dunders[0])(right)
463 return getattr(left, dunders[0])(right)
464 else:
464 else:
465 raise GuardRejection(
465 raise GuardRejection(
466 f"Operation (`{dunders}`) for",
466 f"Operation (`{dunders}`) for",
467 type(left),
467 type(left),
468 f"not allowed in {context.evaluation} mode",
468 f"not allowed in {context.evaluation} mode",
469 )
469 )
470 if isinstance(node, ast.Compare):
470 if isinstance(node, ast.Compare):
471 left = eval_node(node.left, context)
471 left = eval_node(node.left, context)
472 all_true = True
472 all_true = True
473 negate = False
473 negate = False
474 for op, right in zip(node.ops, node.comparators):
474 for op, right in zip(node.ops, node.comparators):
475 right = eval_node(right, context)
475 right = eval_node(right, context)
476 dunder = None
476 dunder = None
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
478 if not dunders:
478 if not dunders:
479 if isinstance(op, ast.NotIn):
479 if isinstance(op, ast.NotIn):
480 dunders = COMP_OP_DUNDERS[ast.In]
480 dunders = COMP_OP_DUNDERS[ast.In]
481 negate = True
481 negate = True
482 if isinstance(op, ast.Is):
482 if isinstance(op, ast.Is):
483 dunder = "is_"
483 dunder = "is_"
484 if isinstance(op, ast.IsNot):
484 if isinstance(op, ast.IsNot):
485 dunder = "is_"
485 dunder = "is_"
486 negate = True
486 negate = True
487 if not dunder and dunders:
487 if not dunder and dunders:
488 dunder = dunders[0]
488 dunder = dunders[0]
489 if dunder:
489 if dunder:
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
492 result = getattr(operator, dunder)(a, b)
492 result = getattr(operator, dunder)(a, b)
493 if negate:
493 if negate:
494 result = not result
494 result = not result
495 if not result:
495 if not result:
496 all_true = False
496 all_true = False
497 left = right
497 left = right
498 else:
498 else:
499 raise GuardRejection(
499 raise GuardRejection(
500 f"Comparison (`{dunder}`) for",
500 f"Comparison (`{dunder}`) for",
501 type(left),
501 type(left),
502 f"not allowed in {context.evaluation} mode",
502 f"not allowed in {context.evaluation} mode",
503 )
503 )
504 else:
504 else:
505 raise ValueError(
505 raise ValueError(
506 f"Comparison `{dunder}` not supported"
506 f"Comparison `{dunder}` not supported"
507 ) # pragma: no cover
507 ) # pragma: no cover
508 return all_true
508 return all_true
509 if isinstance(node, ast.Constant):
509 if isinstance(node, ast.Constant):
510 return node.value
510 return node.value
511 if isinstance(node, ast.Index):
511 if isinstance(node, ast.Index):
512 # deprecated since Python 3.9
512 # deprecated since Python 3.9
513 return eval_node(node.value, context) # pragma: no cover
513 return eval_node(node.value, context) # pragma: no cover
514 if isinstance(node, ast.Tuple):
514 if isinstance(node, ast.Tuple):
515 return tuple(eval_node(e, context) for e in node.elts)
515 return tuple(eval_node(e, context) for e in node.elts)
516 if isinstance(node, ast.List):
516 if isinstance(node, ast.List):
517 return [eval_node(e, context) for e in node.elts]
517 return [eval_node(e, context) for e in node.elts]
518 if isinstance(node, ast.Set):
518 if isinstance(node, ast.Set):
519 return {eval_node(e, context) for e in node.elts}
519 return {eval_node(e, context) for e in node.elts}
520 if isinstance(node, ast.Dict):
520 if isinstance(node, ast.Dict):
521 return dict(
521 return dict(
522 zip(
522 zip(
523 [eval_node(k, context) for k in node.keys],
523 [eval_node(k, context) for k in node.keys],
524 [eval_node(v, context) for v in node.values],
524 [eval_node(v, context) for v in node.values],
525 )
525 )
526 )
526 )
527 if isinstance(node, ast.Slice):
527 if isinstance(node, ast.Slice):
528 return slice(
528 return slice(
529 eval_node(node.lower, context),
529 eval_node(node.lower, context),
530 eval_node(node.upper, context),
530 eval_node(node.upper, context),
531 eval_node(node.step, context),
531 eval_node(node.step, context),
532 )
532 )
533 if isinstance(node, ast.ExtSlice):
533 if isinstance(node, ast.ExtSlice):
534 # deprecated since Python 3.9
534 # deprecated since Python 3.9
535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
536 if isinstance(node, ast.UnaryOp):
536 if isinstance(node, ast.UnaryOp):
537 value = eval_node(node.operand, context)
537 value = eval_node(node.operand, context)
538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
539 if dunders:
539 if dunders:
540 if policy.can_operate(dunders, value):
540 if policy.can_operate(dunders, value):
541 return getattr(value, dunders[0])()
541 return getattr(value, dunders[0])()
542 else:
542 else:
543 raise GuardRejection(
543 raise GuardRejection(
544 f"Operation (`{dunders}`) for",
544 f"Operation (`{dunders}`) for",
545 type(value),
545 type(value),
546 f"not allowed in {context.evaluation} mode",
546 f"not allowed in {context.evaluation} mode",
547 )
547 )
548 if isinstance(node, ast.Subscript):
548 if isinstance(node, ast.Subscript):
549 value = eval_node(node.value, context)
549 value = eval_node(node.value, context)
550 slice_ = eval_node(node.slice, context)
550 slice_ = eval_node(node.slice, context)
551 if policy.can_get_item(value, slice_):
551 if policy.can_get_item(value, slice_):
552 return value[slice_]
552 return value[slice_]
553 raise GuardRejection(
553 raise GuardRejection(
554 "Subscript access (`__getitem__`) for",
554 "Subscript access (`__getitem__`) for",
555 type(value), # not joined to avoid calling `repr`
555 type(value), # not joined to avoid calling `repr`
556 f" not allowed in {context.evaluation} mode",
556 f" not allowed in {context.evaluation} mode",
557 )
557 )
558 if isinstance(node, ast.Name):
558 if isinstance(node, ast.Name):
559 if policy.allow_locals_access and node.id in context.locals:
559 if policy.allow_locals_access and node.id in context.locals:
560 return context.locals[node.id]
560 return context.locals[node.id]
561 if policy.allow_globals_access and node.id in context.globals:
561 if policy.allow_globals_access and node.id in context.globals:
562 return context.globals[node.id]
562 return context.globals[node.id]
563 if policy.allow_builtins_access and hasattr(builtins, node.id):
563 if policy.allow_builtins_access and hasattr(builtins, node.id):
564 # note: do not use __builtins__, it is implementation detail of cPython
564 # note: do not use __builtins__, it is implementation detail of cPython
565 return getattr(builtins, node.id)
565 return getattr(builtins, node.id)
566 if not policy.allow_globals_access and not policy.allow_locals_access:
566 if not policy.allow_globals_access and not policy.allow_locals_access:
567 raise GuardRejection(
567 raise GuardRejection(
568 f"Namespace access not allowed in {context.evaluation} mode"
568 f"Namespace access not allowed in {context.evaluation} mode"
569 )
569 )
570 else:
570 else:
571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
572 if isinstance(node, ast.Attribute):
572 if isinstance(node, ast.Attribute):
573 value = eval_node(node.value, context)
573 value = eval_node(node.value, context)
574 if policy.can_get_attr(value, node.attr):
574 if policy.can_get_attr(value, node.attr):
575 return getattr(value, node.attr)
575 return getattr(value, node.attr)
576 raise GuardRejection(
576 raise GuardRejection(
577 "Attribute access (`__getattr__`) for",
577 "Attribute access (`__getattr__`) for",
578 type(value), # not joined to avoid calling `repr`
578 type(value), # not joined to avoid calling `repr`
579 f"not allowed in {context.evaluation} mode",
579 f"not allowed in {context.evaluation} mode",
580 )
580 )
581 if isinstance(node, ast.IfExp):
581 if isinstance(node, ast.IfExp):
582 test = eval_node(node.test, context)
582 test = eval_node(node.test, context)
583 if test:
583 if test:
584 return eval_node(node.body, context)
584 return eval_node(node.body, context)
585 else:
585 else:
586 return eval_node(node.orelse, context)
586 return eval_node(node.orelse, context)
587 if isinstance(node, ast.Call):
587 if isinstance(node, ast.Call):
588 func = eval_node(node.func, context)
588 func = eval_node(node.func, context)
589 if policy.can_call(func) and not node.keywords:
589 if policy.can_call(func) and not node.keywords:
590 args = [eval_node(arg, context) for arg in node.args]
590 args = [eval_node(arg, context) for arg in node.args]
591 return func(*args)
591 return func(*args)
592 raise GuardRejection(
592 raise GuardRejection(
593 "Call for",
593 "Call for",
594 func, # not joined to avoid calling `repr`
594 func, # not joined to avoid calling `repr`
595 f"not allowed in {context.evaluation} mode",
595 f"not allowed in {context.evaluation} mode",
596 )
596 )
597 raise ValueError("Unhandled node", ast.dump(node))
597 raise ValueError("Unhandled node", ast.dump(node))
598
598
599
599
600 SUPPORTED_EXTERNAL_GETITEM = {
600 SUPPORTED_EXTERNAL_GETITEM = {
601 ("pandas", "core", "indexing", "_iLocIndexer"),
601 ("pandas", "core", "indexing", "_iLocIndexer"),
602 ("pandas", "core", "indexing", "_LocIndexer"),
602 ("pandas", "core", "indexing", "_LocIndexer"),
603 ("pandas", "DataFrame"),
603 ("pandas", "DataFrame"),
604 ("pandas", "Series"),
604 ("pandas", "Series"),
605 ("numpy", "ndarray"),
605 ("numpy", "ndarray"),
606 ("numpy", "void"),
606 ("numpy", "void"),
607 }
607 }
608
608
609
609
610 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
610 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
611 dict,
611 dict,
612 str, # type: ignore[arg-type]
612 str, # type: ignore[arg-type]
613 bytes, # type: ignore[arg-type]
613 bytes, # type: ignore[arg-type]
614 list,
614 list,
615 tuple,
615 tuple,
616 collections.defaultdict,
616 collections.defaultdict,
617 collections.deque,
617 collections.deque,
618 collections.OrderedDict,
618 collections.OrderedDict,
619 collections.ChainMap,
619 collections.ChainMap,
620 collections.UserDict,
620 collections.UserDict,
621 collections.UserList,
621 collections.UserList,
622 collections.UserString, # type: ignore[arg-type]
622 collections.UserString, # type: ignore[arg-type]
623 _DummyNamedTuple,
623 _DummyNamedTuple,
624 _IdentitySubscript,
624 _IdentitySubscript,
625 }
625 }
626
626
627
627
628 def _list_methods(cls, source=None):
628 def _list_methods(cls, source=None):
629 """For use on immutable objects or with methods returning a copy"""
629 """For use on immutable objects or with methods returning a copy"""
630 return [getattr(cls, k) for k in (source if source else dir(cls))]
630 return [getattr(cls, k) for k in (source if source else dir(cls))]
631
631
632
632
633 dict_non_mutating_methods = ("copy", "keys", "values", "items")
633 dict_non_mutating_methods = ("copy", "keys", "values", "items")
634 list_non_mutating_methods = ("copy", "index", "count")
634 list_non_mutating_methods = ("copy", "index", "count")
635 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
635 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
636
636
637
637
638 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
638 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
639 method_descriptor: Any = type(list.copy)
639 method_descriptor: Any = type(list.copy)
640 module = type(builtins)
640
641
641 NUMERICS = {int, float, complex}
642 NUMERICS = {int, float, complex}
642
643
643 ALLOWED_CALLS = {
644 ALLOWED_CALLS = {
644 bytes,
645 bytes,
645 *_list_methods(bytes),
646 *_list_methods(bytes),
646 dict,
647 dict,
647 *_list_methods(dict, dict_non_mutating_methods),
648 *_list_methods(dict, dict_non_mutating_methods),
648 dict_keys.isdisjoint,
649 dict_keys.isdisjoint,
649 list,
650 list,
650 *_list_methods(list, list_non_mutating_methods),
651 *_list_methods(list, list_non_mutating_methods),
651 set,
652 set,
652 *_list_methods(set, set_non_mutating_methods),
653 *_list_methods(set, set_non_mutating_methods),
653 frozenset,
654 frozenset,
654 *_list_methods(frozenset),
655 *_list_methods(frozenset),
655 range,
656 range,
656 str,
657 str,
657 *_list_methods(str),
658 *_list_methods(str),
658 tuple,
659 tuple,
659 *_list_methods(tuple),
660 *_list_methods(tuple),
660 *NUMERICS,
661 *NUMERICS,
661 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
662 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
662 collections.deque,
663 collections.deque,
663 *_list_methods(collections.deque, list_non_mutating_methods),
664 *_list_methods(collections.deque, list_non_mutating_methods),
664 collections.defaultdict,
665 collections.defaultdict,
665 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
666 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
666 collections.OrderedDict,
667 collections.OrderedDict,
667 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
668 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
668 collections.UserDict,
669 collections.UserDict,
669 *_list_methods(collections.UserDict, dict_non_mutating_methods),
670 *_list_methods(collections.UserDict, dict_non_mutating_methods),
670 collections.UserList,
671 collections.UserList,
671 *_list_methods(collections.UserList, list_non_mutating_methods),
672 *_list_methods(collections.UserList, list_non_mutating_methods),
672 collections.UserString,
673 collections.UserString,
673 *_list_methods(collections.UserString, dir(str)),
674 *_list_methods(collections.UserString, dir(str)),
674 collections.Counter,
675 collections.Counter,
675 *_list_methods(collections.Counter, dict_non_mutating_methods),
676 *_list_methods(collections.Counter, dict_non_mutating_methods),
676 collections.Counter.elements,
677 collections.Counter.elements,
677 collections.Counter.most_common,
678 collections.Counter.most_common,
678 }
679 }
679
680
680 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
681 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
681 *BUILTIN_GETITEM,
682 *BUILTIN_GETITEM,
682 set,
683 set,
683 frozenset,
684 frozenset,
684 object,
685 object,
685 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
686 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
686 *NUMERICS,
687 *NUMERICS,
687 dict_keys,
688 dict_keys,
688 method_descriptor,
689 method_descriptor,
690 module,
689 }
691 }
690
692
691
693
692 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
694 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
693
695
694 EVALUATION_POLICIES = {
696 EVALUATION_POLICIES = {
695 "minimal": EvaluationPolicy(
697 "minimal": EvaluationPolicy(
696 allow_builtins_access=True,
698 allow_builtins_access=True,
697 allow_locals_access=False,
699 allow_locals_access=False,
698 allow_globals_access=False,
700 allow_globals_access=False,
699 allow_item_access=False,
701 allow_item_access=False,
700 allow_attr_access=False,
702 allow_attr_access=False,
701 allowed_calls=set(),
703 allowed_calls=set(),
702 allow_any_calls=False,
704 allow_any_calls=False,
703 allow_all_operations=False,
705 allow_all_operations=False,
704 ),
706 ),
705 "limited": SelectivePolicy(
707 "limited": SelectivePolicy(
706 allowed_getitem=BUILTIN_GETITEM,
708 allowed_getitem=BUILTIN_GETITEM,
707 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
709 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
708 allowed_getattr=BUILTIN_GETATTR,
710 allowed_getattr=BUILTIN_GETATTR,
709 allowed_getattr_external={
711 allowed_getattr_external={
710 # pandas Series/Frame implements custom `__getattr__`
712 # pandas Series/Frame implements custom `__getattr__`
711 ("pandas", "DataFrame"),
713 ("pandas", "DataFrame"),
712 ("pandas", "Series"),
714 ("pandas", "Series"),
713 },
715 },
714 allowed_operations=BUILTIN_OPERATIONS,
716 allowed_operations=BUILTIN_OPERATIONS,
715 allow_builtins_access=True,
717 allow_builtins_access=True,
716 allow_locals_access=True,
718 allow_locals_access=True,
717 allow_globals_access=True,
719 allow_globals_access=True,
718 allowed_calls=ALLOWED_CALLS,
720 allowed_calls=ALLOWED_CALLS,
719 ),
721 ),
720 "unsafe": EvaluationPolicy(
722 "unsafe": EvaluationPolicy(
721 allow_builtins_access=True,
723 allow_builtins_access=True,
722 allow_locals_access=True,
724 allow_locals_access=True,
723 allow_globals_access=True,
725 allow_globals_access=True,
724 allow_attr_access=True,
726 allow_attr_access=True,
725 allow_item_access=True,
727 allow_item_access=True,
726 allow_any_calls=True,
728 allow_any_calls=True,
727 allow_all_operations=True,
729 allow_all_operations=True,
728 ),
730 ),
729 }
731 }
730
732
731
733
732 __all__ = [
734 __all__ = [
733 "guarded_eval",
735 "guarded_eval",
734 "eval_node",
736 "eval_node",
735 "GuardRejection",
737 "GuardRejection",
736 "EvaluationContext",
738 "EvaluationContext",
737 "_unbind_method",
739 "_unbind_method",
738 ]
740 ]
@@ -1,570 +1,582 b''
1 from contextlib import contextmanager
1 from contextlib import contextmanager
2 from typing import NamedTuple
2 from typing import NamedTuple
3 from functools import partial
3 from functools import partial
4 from IPython.core.guarded_eval import (
4 from IPython.core.guarded_eval import (
5 EvaluationContext,
5 EvaluationContext,
6 GuardRejection,
6 GuardRejection,
7 guarded_eval,
7 guarded_eval,
8 _unbind_method,
8 _unbind_method,
9 )
9 )
10 from IPython.testing import decorators as dec
10 from IPython.testing import decorators as dec
11 import pytest
11 import pytest
12
12
13
13
14 def create_context(evaluation: str, **kwargs):
14 def create_context(evaluation: str, **kwargs):
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16
16
17
17
18 forbidden = partial(create_context, "forbidden")
18 forbidden = partial(create_context, "forbidden")
19 minimal = partial(create_context, "minimal")
19 minimal = partial(create_context, "minimal")
20 limited = partial(create_context, "limited")
20 limited = partial(create_context, "limited")
21 unsafe = partial(create_context, "unsafe")
21 unsafe = partial(create_context, "unsafe")
22 dangerous = partial(create_context, "dangerous")
22 dangerous = partial(create_context, "dangerous")
23
23
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26
26
27
27
28 @contextmanager
28 @contextmanager
29 def module_not_installed(module: str):
29 def module_not_installed(module: str):
30 import sys
30 import sys
31
31
32 try:
32 try:
33 to_restore = sys.modules[module]
33 to_restore = sys.modules[module]
34 del sys.modules[module]
34 del sys.modules[module]
35 except KeyError:
35 except KeyError:
36 to_restore = None
36 to_restore = None
37 try:
37 try:
38 yield
38 yield
39 finally:
39 finally:
40 sys.modules[module] = to_restore
40 sys.modules[module] = to_restore
41
41
42
42
43 def test_external_not_installed():
43 def test_external_not_installed():
44 """
44 """
45 Because attribute check requires checking if object is not of allowed
45 Because attribute check requires checking if object is not of allowed
46 external type, this tests logic for absence of external module.
46 external type, this tests logic for absence of external module.
47 """
47 """
48
48
49 class Custom:
49 class Custom:
50 def __init__(self):
50 def __init__(self):
51 self.test = 1
51 self.test = 1
52
52
53 def __getattr__(self, key):
53 def __getattr__(self, key):
54 return key
54 return key
55
55
56 with module_not_installed("pandas"):
56 with module_not_installed("pandas"):
57 context = limited(x=Custom())
57 context = limited(x=Custom())
58 with pytest.raises(GuardRejection):
58 with pytest.raises(GuardRejection):
59 guarded_eval("x.test", context)
59 guarded_eval("x.test", context)
60
60
61
61
62 @dec.skip_without("pandas")
62 @dec.skip_without("pandas")
63 def test_external_changed_api(monkeypatch):
63 def test_external_changed_api(monkeypatch):
64 """Check that the execution rejects if external API changed paths"""
64 """Check that the execution rejects if external API changed paths"""
65 import pandas as pd
65 import pandas as pd
66
66
67 series = pd.Series([1], index=["a"])
67 series = pd.Series([1], index=["a"])
68
68
69 with monkeypatch.context() as m:
69 with monkeypatch.context() as m:
70 m.delattr(pd, "Series")
70 m.delattr(pd, "Series")
71 context = limited(data=series)
71 context = limited(data=series)
72 with pytest.raises(GuardRejection):
72 with pytest.raises(GuardRejection):
73 guarded_eval("data.iloc[0]", context)
73 guarded_eval("data.iloc[0]", context)
74
74
75
75
76 @dec.skip_without("pandas")
76 @dec.skip_without("pandas")
77 def test_pandas_series_iloc():
77 def test_pandas_series_iloc():
78 import pandas as pd
78 import pandas as pd
79
79
80 series = pd.Series([1], index=["a"])
80 series = pd.Series([1], index=["a"])
81 context = limited(data=series)
81 context = limited(data=series)
82 assert guarded_eval("data.iloc[0]", context) == 1
82 assert guarded_eval("data.iloc[0]", context) == 1
83
83
84
84
85 def test_rejects_custom_properties():
85 def test_rejects_custom_properties():
86 class BadProperty:
86 class BadProperty:
87 @property
87 @property
88 def iloc(self):
88 def iloc(self):
89 return [None]
89 return [None]
90
90
91 series = BadProperty()
91 series = BadProperty()
92 context = limited(data=series)
92 context = limited(data=series)
93
93
94 with pytest.raises(GuardRejection):
94 with pytest.raises(GuardRejection):
95 guarded_eval("data.iloc[0]", context)
95 guarded_eval("data.iloc[0]", context)
96
96
97
97
98 @dec.skip_without("pandas")
98 @dec.skip_without("pandas")
99 def test_accepts_non_overriden_properties():
99 def test_accepts_non_overriden_properties():
100 import pandas as pd
100 import pandas as pd
101
101
102 class GoodProperty(pd.Series):
102 class GoodProperty(pd.Series):
103 pass
103 pass
104
104
105 series = GoodProperty([1], index=["a"])
105 series = GoodProperty([1], index=["a"])
106 context = limited(data=series)
106 context = limited(data=series)
107
107
108 assert guarded_eval("data.iloc[0]", context) == 1
108 assert guarded_eval("data.iloc[0]", context) == 1
109
109
110
110
111 @dec.skip_without("pandas")
111 @dec.skip_without("pandas")
112 def test_pandas_series():
112 def test_pandas_series():
113 import pandas as pd
113 import pandas as pd
114
114
115 context = limited(data=pd.Series([1], index=["a"]))
115 context = limited(data=pd.Series([1], index=["a"]))
116 assert guarded_eval('data["a"]', context) == 1
116 assert guarded_eval('data["a"]', context) == 1
117 with pytest.raises(KeyError):
117 with pytest.raises(KeyError):
118 guarded_eval('data["c"]', context)
118 guarded_eval('data["c"]', context)
119
119
120
120
121 @dec.skip_without("pandas")
121 @dec.skip_without("pandas")
122 def test_pandas_bad_series():
122 def test_pandas_bad_series():
123 import pandas as pd
123 import pandas as pd
124
124
125 class BadItemSeries(pd.Series):
125 class BadItemSeries(pd.Series):
126 def __getitem__(self, key):
126 def __getitem__(self, key):
127 return "CUSTOM_ITEM"
127 return "CUSTOM_ITEM"
128
128
129 class BadAttrSeries(pd.Series):
129 class BadAttrSeries(pd.Series):
130 def __getattr__(self, key):
130 def __getattr__(self, key):
131 return "CUSTOM_ATTR"
131 return "CUSTOM_ATTR"
132
132
133 bad_series = BadItemSeries([1], index=["a"])
133 bad_series = BadItemSeries([1], index=["a"])
134 context = limited(data=bad_series)
134 context = limited(data=bad_series)
135
135
136 with pytest.raises(GuardRejection):
136 with pytest.raises(GuardRejection):
137 guarded_eval('data["a"]', context)
137 guarded_eval('data["a"]', context)
138 with pytest.raises(GuardRejection):
138 with pytest.raises(GuardRejection):
139 guarded_eval('data["c"]', context)
139 guarded_eval('data["c"]', context)
140
140
141 # note: here result is a bit unexpected because
141 # note: here result is a bit unexpected because
142 # pandas `__getattr__` calls `__getitem__`;
142 # pandas `__getattr__` calls `__getitem__`;
143 # FIXME - special case to handle it?
143 # FIXME - special case to handle it?
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145
145
146 context = unsafe(data=bad_series)
146 context = unsafe(data=bad_series)
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148
148
149 bad_attr_series = BadAttrSeries([1], index=["a"])
149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 context = limited(data=bad_attr_series)
150 context = limited(data=bad_attr_series)
151 assert guarded_eval('data["a"]', context) == 1
151 assert guarded_eval('data["a"]', context) == 1
152 with pytest.raises(GuardRejection):
152 with pytest.raises(GuardRejection):
153 guarded_eval("data.a", context)
153 guarded_eval("data.a", context)
154
154
155
155
156 @dec.skip_without("pandas")
156 @dec.skip_without("pandas")
157 def test_pandas_dataframe_loc():
157 def test_pandas_dataframe_loc():
158 import pandas as pd
158 import pandas as pd
159 from pandas.testing import assert_series_equal
159 from pandas.testing import assert_series_equal
160
160
161 data = pd.DataFrame([{"a": 1}])
161 data = pd.DataFrame([{"a": 1}])
162 context = limited(data=data)
162 context = limited(data=data)
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164
164
165
165
166 def test_named_tuple():
166 def test_named_tuple():
167 class GoodNamedTuple(NamedTuple):
167 class GoodNamedTuple(NamedTuple):
168 a: str
168 a: str
169 pass
169 pass
170
170
171 class BadNamedTuple(NamedTuple):
171 class BadNamedTuple(NamedTuple):
172 a: str
172 a: str
173
173
174 def __getitem__(self, key):
174 def __getitem__(self, key):
175 return None
175 return None
176
176
177 good = GoodNamedTuple(a="x")
177 good = GoodNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
179
179
180 context = limited(data=good)
180 context = limited(data=good)
181 assert guarded_eval("data[0]", context) == "x"
181 assert guarded_eval("data[0]", context) == "x"
182
182
183 context = limited(data=bad)
183 context = limited(data=bad)
184 with pytest.raises(GuardRejection):
184 with pytest.raises(GuardRejection):
185 guarded_eval("data[0]", context)
185 guarded_eval("data[0]", context)
186
186
187
187
188 def test_dict():
188 def test_dict():
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 assert guarded_eval('data["a"]', context) == 1
190 assert guarded_eval('data["a"]', context) == 1
191 assert guarded_eval('data["b"]', context) == {"x": 2}
191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 assert guarded_eval('data["b"]["x"]', context) == 2
192 assert guarded_eval('data["b"]["x"]', context) == 2
193 assert guarded_eval('data["x", "y"]', context) == 3
193 assert guarded_eval('data["x", "y"]', context) == 3
194
194
195 assert guarded_eval("data.keys", context)
195 assert guarded_eval("data.keys", context)
196
196
197
197
198 def test_set():
198 def test_set():
199 context = limited(data={"a", "b"})
199 context = limited(data={"a", "b"})
200 assert guarded_eval("data.difference", context)
200 assert guarded_eval("data.difference", context)
201
201
202
202
203 def test_list():
203 def test_list():
204 context = limited(data=[1, 2, 3])
204 context = limited(data=[1, 2, 3])
205 assert guarded_eval("data[1]", context) == 2
205 assert guarded_eval("data[1]", context) == 2
206 assert guarded_eval("data.copy", context)
206 assert guarded_eval("data.copy", context)
207
207
208
208
209 def test_dict_literal():
209 def test_dict_literal():
210 context = limited()
210 context = limited()
211 assert guarded_eval("{}", context) == {}
211 assert guarded_eval("{}", context) == {}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213
213
214
214
215 def test_list_literal():
215 def test_list_literal():
216 context = limited()
216 context = limited()
217 assert guarded_eval("[]", context) == []
217 assert guarded_eval("[]", context) == []
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219
219
220
220
221 def test_set_literal():
221 def test_set_literal():
222 context = limited()
222 context = limited()
223 assert guarded_eval("set()", context) == set()
223 assert guarded_eval("set()", context) == set()
224 assert guarded_eval('{"a"}', context) == {"a"}
224 assert guarded_eval('{"a"}', context) == {"a"}
225
225
226
226
227 def test_evaluates_if_expression():
227 def test_evaluates_if_expression():
228 context = limited()
228 context = limited()
229 assert guarded_eval("2 if True else 3", context) == 2
229 assert guarded_eval("2 if True else 3", context) == 2
230 assert guarded_eval("4 if False else 5", context) == 5
230 assert guarded_eval("4 if False else 5", context) == 5
231
231
232
232
233 def test_object():
233 def test_object():
234 obj = object()
234 obj = object()
235 context = limited(obj=obj)
235 context = limited(obj=obj)
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237
237
238
238
239 @pytest.mark.parametrize(
239 @pytest.mark.parametrize(
240 "code,expected",
240 "code,expected",
241 [
241 [
242 ["int.numerator", int.numerator],
242 ["int.numerator", int.numerator],
243 ["float.is_integer", float.is_integer],
243 ["float.is_integer", float.is_integer],
244 ["complex.real", complex.real],
244 ["complex.real", complex.real],
245 ],
245 ],
246 )
246 )
247 def test_number_attributes(code, expected):
247 def test_number_attributes(code, expected):
248 assert guarded_eval(code, limited()) == expected
248 assert guarded_eval(code, limited()) == expected
249
249
250
250
251 def test_method_descriptor():
251 def test_method_descriptor():
252 context = limited()
252 context = limited()
253 assert guarded_eval("list.copy.__name__", context) == "copy"
253 assert guarded_eval("list.copy.__name__", context) == "copy"
254
254
255
255
256 @pytest.mark.parametrize(
256 @pytest.mark.parametrize(
257 "data,good,bad,expected",
257 "data,good,bad,expected",
258 [
258 [
259 [[1, 2, 3], "data.index(2)", "data.append(4)", 1],
259 [[1, 2, 3], "data.index(2)", "data.append(4)", 1],
260 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True],
260 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True],
261 ],
261 ],
262 )
262 )
263 def test_evaluates_calls(data, good, bad, expected):
263 def test_evaluates_calls(data, good, bad, expected):
264 context = limited(data=data)
264 context = limited(data=data)
265 assert guarded_eval(good, context) == expected
265 assert guarded_eval(good, context) == expected
266
266
267 with pytest.raises(GuardRejection):
267 with pytest.raises(GuardRejection):
268 guarded_eval(bad, context)
268 guarded_eval(bad, context)
269
269
270
270
271 @pytest.mark.parametrize(
271 @pytest.mark.parametrize(
272 "code,expected",
272 "code,expected",
273 [
273 [
274 ["(1\n+\n1)", 2],
274 ["(1\n+\n1)", 2],
275 ["list(range(10))[-1:]", [9]],
275 ["list(range(10))[-1:]", [9]],
276 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
276 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
277 ],
277 ],
278 )
278 )
279 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
279 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
280 def test_evaluates_complex_cases(code, expected, context):
280 def test_evaluates_complex_cases(code, expected, context):
281 assert guarded_eval(code, context()) == expected
281 assert guarded_eval(code, context()) == expected
282
282
283
283
284 @pytest.mark.parametrize(
284 @pytest.mark.parametrize(
285 "code,expected",
285 "code,expected",
286 [
286 [
287 ["1", 1],
287 ["1", 1],
288 ["1.0", 1.0],
288 ["1.0", 1.0],
289 ["0xdeedbeef", 0xDEEDBEEF],
289 ["0xdeedbeef", 0xDEEDBEEF],
290 ["True", True],
290 ["True", True],
291 ["None", None],
291 ["None", None],
292 ["{}", {}],
292 ["{}", {}],
293 ["[]", []],
293 ["[]", []],
294 ],
294 ],
295 )
295 )
296 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
296 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
297 def test_evaluates_literals(code, expected, context):
297 def test_evaluates_literals(code, expected, context):
298 assert guarded_eval(code, context()) == expected
298 assert guarded_eval(code, context()) == expected
299
299
300
300
301 @pytest.mark.parametrize(
301 @pytest.mark.parametrize(
302 "code,expected",
302 "code,expected",
303 [
303 [
304 ["-5", -5],
304 ["-5", -5],
305 ["+5", +5],
305 ["+5", +5],
306 ["~5", -6],
306 ["~5", -6],
307 ],
307 ],
308 )
308 )
309 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
309 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
310 def test_evaluates_unary_operations(code, expected, context):
310 def test_evaluates_unary_operations(code, expected, context):
311 assert guarded_eval(code, context()) == expected
311 assert guarded_eval(code, context()) == expected
312
312
313
313
314 @pytest.mark.parametrize(
314 @pytest.mark.parametrize(
315 "code,expected",
315 "code,expected",
316 [
316 [
317 ["1 + 1", 2],
317 ["1 + 1", 2],
318 ["3 - 1", 2],
318 ["3 - 1", 2],
319 ["2 * 3", 6],
319 ["2 * 3", 6],
320 ["5 // 2", 2],
320 ["5 // 2", 2],
321 ["5 / 2", 2.5],
321 ["5 / 2", 2.5],
322 ["5**2", 25],
322 ["5**2", 25],
323 ["2 >> 1", 1],
323 ["2 >> 1", 1],
324 ["2 << 1", 4],
324 ["2 << 1", 4],
325 ["1 | 2", 3],
325 ["1 | 2", 3],
326 ["1 & 1", 1],
326 ["1 & 1", 1],
327 ["1 & 2", 0],
327 ["1 & 2", 0],
328 ],
328 ],
329 )
329 )
330 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
330 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
331 def test_evaluates_binary_operations(code, expected, context):
331 def test_evaluates_binary_operations(code, expected, context):
332 assert guarded_eval(code, context()) == expected
332 assert guarded_eval(code, context()) == expected
333
333
334
334
335 @pytest.mark.parametrize(
335 @pytest.mark.parametrize(
336 "code,expected",
336 "code,expected",
337 [
337 [
338 ["2 > 1", True],
338 ["2 > 1", True],
339 ["2 < 1", False],
339 ["2 < 1", False],
340 ["2 <= 1", False],
340 ["2 <= 1", False],
341 ["2 <= 2", True],
341 ["2 <= 2", True],
342 ["1 >= 2", False],
342 ["1 >= 2", False],
343 ["2 >= 2", True],
343 ["2 >= 2", True],
344 ["2 == 2", True],
344 ["2 == 2", True],
345 ["1 == 2", False],
345 ["1 == 2", False],
346 ["1 != 2", True],
346 ["1 != 2", True],
347 ["1 != 1", False],
347 ["1 != 1", False],
348 ["1 < 4 < 3", False],
348 ["1 < 4 < 3", False],
349 ["(1 < 4) < 3", True],
349 ["(1 < 4) < 3", True],
350 ["4 > 3 > 2 > 1", True],
350 ["4 > 3 > 2 > 1", True],
351 ["4 > 3 > 2 > 9", False],
351 ["4 > 3 > 2 > 9", False],
352 ["1 < 2 < 3 < 4", True],
352 ["1 < 2 < 3 < 4", True],
353 ["9 < 2 < 3 < 4", False],
353 ["9 < 2 < 3 < 4", False],
354 ["1 < 2 > 1 > 0 > -1 < 1", True],
354 ["1 < 2 > 1 > 0 > -1 < 1", True],
355 ["1 in [1] in [[1]]", True],
355 ["1 in [1] in [[1]]", True],
356 ["1 in [1] in [[2]]", False],
356 ["1 in [1] in [[2]]", False],
357 ["1 in [1]", True],
357 ["1 in [1]", True],
358 ["0 in [1]", False],
358 ["0 in [1]", False],
359 ["1 not in [1]", False],
359 ["1 not in [1]", False],
360 ["0 not in [1]", True],
360 ["0 not in [1]", True],
361 ["True is True", True],
361 ["True is True", True],
362 ["False is False", True],
362 ["False is False", True],
363 ["True is False", False],
363 ["True is False", False],
364 ["True is not True", False],
364 ["True is not True", False],
365 ["False is not True", True],
365 ["False is not True", True],
366 ],
366 ],
367 )
367 )
368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
369 def test_evaluates_comparisons(code, expected, context):
369 def test_evaluates_comparisons(code, expected, context):
370 assert guarded_eval(code, context()) == expected
370 assert guarded_eval(code, context()) == expected
371
371
372
372
373 def test_guards_comparisons():
373 def test_guards_comparisons():
374 class GoodEq(int):
374 class GoodEq(int):
375 pass
375 pass
376
376
377 class BadEq(int):
377 class BadEq(int):
378 def __eq__(self, other):
378 def __eq__(self, other):
379 assert False
379 assert False
380
380
381 context = limited(bad=BadEq(1), good=GoodEq(1))
381 context = limited(bad=BadEq(1), good=GoodEq(1))
382
382
383 with pytest.raises(GuardRejection):
383 with pytest.raises(GuardRejection):
384 guarded_eval("bad == 1", context)
384 guarded_eval("bad == 1", context)
385
385
386 with pytest.raises(GuardRejection):
386 with pytest.raises(GuardRejection):
387 guarded_eval("bad != 1", context)
387 guarded_eval("bad != 1", context)
388
388
389 with pytest.raises(GuardRejection):
389 with pytest.raises(GuardRejection):
390 guarded_eval("1 == bad", context)
390 guarded_eval("1 == bad", context)
391
391
392 with pytest.raises(GuardRejection):
392 with pytest.raises(GuardRejection):
393 guarded_eval("1 != bad", context)
393 guarded_eval("1 != bad", context)
394
394
395 assert guarded_eval("good == 1", context) is True
395 assert guarded_eval("good == 1", context) is True
396 assert guarded_eval("good != 1", context) is False
396 assert guarded_eval("good != 1", context) is False
397 assert guarded_eval("1 == good", context) is True
397 assert guarded_eval("1 == good", context) is True
398 assert guarded_eval("1 != good", context) is False
398 assert guarded_eval("1 != good", context) is False
399
399
400
400
401 def test_guards_unary_operations():
401 def test_guards_unary_operations():
402 class GoodOp(int):
402 class GoodOp(int):
403 pass
403 pass
404
404
405 class BadOpInv(int):
405 class BadOpInv(int):
406 def __inv__(self, other):
406 def __inv__(self, other):
407 assert False
407 assert False
408
408
409 class BadOpInverse(int):
409 class BadOpInverse(int):
410 def __inv__(self, other):
410 def __inv__(self, other):
411 assert False
411 assert False
412
412
413 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
413 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
414
414
415 with pytest.raises(GuardRejection):
415 with pytest.raises(GuardRejection):
416 guarded_eval("~bad1", context)
416 guarded_eval("~bad1", context)
417
417
418 with pytest.raises(GuardRejection):
418 with pytest.raises(GuardRejection):
419 guarded_eval("~bad2", context)
419 guarded_eval("~bad2", context)
420
420
421
421
422 def test_guards_binary_operations():
422 def test_guards_binary_operations():
423 class GoodOp(int):
423 class GoodOp(int):
424 pass
424 pass
425
425
426 class BadOp(int):
426 class BadOp(int):
427 def __add__(self, other):
427 def __add__(self, other):
428 assert False
428 assert False
429
429
430 context = limited(good=GoodOp(1), bad=BadOp(1))
430 context = limited(good=GoodOp(1), bad=BadOp(1))
431
431
432 with pytest.raises(GuardRejection):
432 with pytest.raises(GuardRejection):
433 guarded_eval("1 + bad", context)
433 guarded_eval("1 + bad", context)
434
434
435 with pytest.raises(GuardRejection):
435 with pytest.raises(GuardRejection):
436 guarded_eval("bad + 1", context)
436 guarded_eval("bad + 1", context)
437
437
438 assert guarded_eval("good + 1", context) == 2
438 assert guarded_eval("good + 1", context) == 2
439 assert guarded_eval("1 + good", context) == 2
439 assert guarded_eval("1 + good", context) == 2
440
440
441
441
442 def test_guards_attributes():
442 def test_guards_attributes():
443 class GoodAttr(float):
443 class GoodAttr(float):
444 pass
444 pass
445
445
446 class BadAttr1(float):
446 class BadAttr1(float):
447 def __getattr__(self, key):
447 def __getattr__(self, key):
448 assert False
448 assert False
449
449
450 class BadAttr2(float):
450 class BadAttr2(float):
451 def __getattribute__(self, key):
451 def __getattribute__(self, key):
452 assert False
452 assert False
453
453
454 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
454 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
455
455
456 with pytest.raises(GuardRejection):
456 with pytest.raises(GuardRejection):
457 guarded_eval("bad1.as_integer_ratio", context)
457 guarded_eval("bad1.as_integer_ratio", context)
458
458
459 with pytest.raises(GuardRejection):
459 with pytest.raises(GuardRejection):
460 guarded_eval("bad2.as_integer_ratio", context)
460 guarded_eval("bad2.as_integer_ratio", context)
461
461
462 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
462 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
463
463
464
464
465 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
465 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
466 def test_access_builtins(context):
466 def test_access_builtins(context):
467 assert guarded_eval("round", context()) == round
467 assert guarded_eval("round", context()) == round
468
468
469
469
470 def test_access_builtins_fails():
470 def test_access_builtins_fails():
471 context = limited()
471 context = limited()
472 with pytest.raises(NameError):
472 with pytest.raises(NameError):
473 guarded_eval("this_is_not_builtin", context)
473 guarded_eval("this_is_not_builtin", context)
474
474
475
475
476 def test_rejects_forbidden():
476 def test_rejects_forbidden():
477 context = forbidden()
477 context = forbidden()
478 with pytest.raises(GuardRejection):
478 with pytest.raises(GuardRejection):
479 guarded_eval("1", context)
479 guarded_eval("1", context)
480
480
481
481
482 def test_guards_locals_and_globals():
482 def test_guards_locals_and_globals():
483 context = EvaluationContext(
483 context = EvaluationContext(
484 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
484 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
485 )
485 )
486
486
487 with pytest.raises(GuardRejection):
487 with pytest.raises(GuardRejection):
488 guarded_eval("local_a", context)
488 guarded_eval("local_a", context)
489
489
490 with pytest.raises(GuardRejection):
490 with pytest.raises(GuardRejection):
491 guarded_eval("global_b", context)
491 guarded_eval("global_b", context)
492
492
493
493
494 def test_access_locals_and_globals():
494 def test_access_locals_and_globals():
495 context = EvaluationContext(
495 context = EvaluationContext(
496 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
496 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
497 )
497 )
498 assert guarded_eval("local_a", context) == "a"
498 assert guarded_eval("local_a", context) == "a"
499 assert guarded_eval("global_b", context) == "b"
499 assert guarded_eval("global_b", context) == "b"
500
500
501
501
502 @pytest.mark.parametrize(
502 @pytest.mark.parametrize(
503 "code",
503 "code",
504 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
504 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
505 )
505 )
506 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
506 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
507 def test_rejects_side_effect_syntax(code, context):
507 def test_rejects_side_effect_syntax(code, context):
508 with pytest.raises(SyntaxError):
508 with pytest.raises(SyntaxError):
509 guarded_eval(code, context)
509 guarded_eval(code, context)
510
510
511
511
512 def test_subscript():
512 def test_subscript():
513 context = EvaluationContext(
513 context = EvaluationContext(
514 locals={}, globals={}, evaluation="limited", in_subscript=True
514 locals={}, globals={}, evaluation="limited", in_subscript=True
515 )
515 )
516 empty_slice = slice(None, None, None)
516 empty_slice = slice(None, None, None)
517 assert guarded_eval("", context) == tuple()
517 assert guarded_eval("", context) == tuple()
518 assert guarded_eval(":", context) == empty_slice
518 assert guarded_eval(":", context) == empty_slice
519 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
519 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
520 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
520 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
521
521
522
522
523 def test_unbind_method():
523 def test_unbind_method():
524 class X(list):
524 class X(list):
525 def index(self, k):
525 def index(self, k):
526 return "CUSTOM"
526 return "CUSTOM"
527
527
528 x = X()
528 x = X()
529 assert _unbind_method(x.index) is X.index
529 assert _unbind_method(x.index) is X.index
530 assert _unbind_method([].index) is list.index
530 assert _unbind_method([].index) is list.index
531 assert _unbind_method(list.index) is None
531 assert _unbind_method(list.index) is None
532
532
533
533
534 def test_assumption_instance_attr_do_not_matter():
534 def test_assumption_instance_attr_do_not_matter():
535 """This is semi-specified in Python documentation.
535 """This is semi-specified in Python documentation.
536
536
537 However, since the specification says 'not guaranted
537 However, since the specification says 'not guaranted
538 to work' rather than 'is forbidden to work', future
538 to work' rather than 'is forbidden to work', future
539 versions could invalidate this assumptions. This test
539 versions could invalidate this assumptions. This test
540 is meant to catch such a change if it ever comes true.
540 is meant to catch such a change if it ever comes true.
541 """
541 """
542
542
543 class T:
543 class T:
544 def __getitem__(self, k):
544 def __getitem__(self, k):
545 return "a"
545 return "a"
546
546
547 def __getattr__(self, k):
547 def __getattr__(self, k):
548 return "a"
548 return "a"
549
549
550 def f(self):
550 def f(self):
551 return "b"
551 return "b"
552
552
553 t = T()
553 t = T()
554 t.__getitem__ = f
554 t.__getitem__ = f
555 t.__getattr__ = f
555 t.__getattr__ = f
556 assert t[1] == "a"
556 assert t[1] == "a"
557 assert t[1] == "a"
557 assert t[1] == "a"
558
558
559
559
560 def test_assumption_named_tuples_share_getitem():
560 def test_assumption_named_tuples_share_getitem():
561 """Check assumption on named tuples sharing __getitem__"""
561 """Check assumption on named tuples sharing __getitem__"""
562 from typing import NamedTuple
562 from typing import NamedTuple
563
563
564 class A(NamedTuple):
564 class A(NamedTuple):
565 pass
565 pass
566
566
567 class B(NamedTuple):
567 class B(NamedTuple):
568 pass
568 pass
569
569
570 assert A.__getitem__ == B.__getitem__
570 assert A.__getitem__ == B.__getitem__
571
572
573 @dec.skip_without("numpy")
574 def test_module_access():
575 import numpy
576
577 context = limited(numpy=numpy)
578 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
579
580 context = minimal(numpy=numpy)
581 with pytest.raises(GuardRejection):
582 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now