##// END OF EJS Templates
Add more tests
krassowski -
Show More
@@ -1,736 +1,738 b''
1 from typing import (
1 from typing import (
2 Any,
2 Any,
3 Callable,
3 Callable,
4 Dict,
4 Dict,
5 Set,
5 Set,
6 Sequence,
6 Sequence,
7 Tuple,
7 Tuple,
8 NamedTuple,
8 NamedTuple,
9 Type,
9 Type,
10 Literal,
10 Literal,
11 Union,
11 Union,
12 TYPE_CHECKING,
12 TYPE_CHECKING,
13 )
13 )
14 import ast
14 import ast
15 import builtins
15 import builtins
16 import collections
16 import collections
17 import operator
17 import operator
18 import sys
18 import sys
19 from functools import cached_property
19 from functools import cached_property
20 from dataclasses import dataclass, field
20 from dataclasses import dataclass, field
21
21
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
22 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.decorators import undoc
23 from IPython.utils.decorators import undoc
24
24
25
25
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
26 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 from typing_extensions import Protocol
27 from typing_extensions import Protocol
28 else:
28 else:
29 # do not require on runtime
29 # do not require on runtime
30 Protocol = object # requires Python >=3.8
30 Protocol = object # requires Python >=3.8
31
31
32
32
33 @undoc
33 @undoc
34 class HasGetItem(Protocol):
34 class HasGetItem(Protocol):
35 def __getitem__(self, key) -> None:
35 def __getitem__(self, key) -> None:
36 ...
36 ...
37
37
38
38
39 @undoc
39 @undoc
40 class InstancesHaveGetItem(Protocol):
40 class InstancesHaveGetItem(Protocol):
41 def __call__(self, *args, **kwargs) -> HasGetItem:
41 def __call__(self, *args, **kwargs) -> HasGetItem:
42 ...
42 ...
43
43
44
44
45 @undoc
45 @undoc
46 class HasGetAttr(Protocol):
46 class HasGetAttr(Protocol):
47 def __getattr__(self, key) -> None:
47 def __getattr__(self, key) -> None:
48 ...
48 ...
49
49
50
50
51 @undoc
51 @undoc
52 class DoesNotHaveGetAttr(Protocol):
52 class DoesNotHaveGetAttr(Protocol):
53 pass
53 pass
54
54
55
55
56 # By default `__getattr__` is not explicitly implemented on most objects
56 # By default `__getattr__` is not explicitly implemented on most objects
57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
57 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
58
58
59
59
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
60 def _unbind_method(func: Callable) -> Union[Callable, None]:
61 """Get unbound method for given bound method.
61 """Get unbound method for given bound method.
62
62
63 Returns None if cannot get unbound method."""
63 Returns None if cannot get unbound method, or method is already unbound.
64 """
64 owner = getattr(func, "__self__", None)
65 owner = getattr(func, "__self__", None)
65 owner_class = type(owner)
66 owner_class = type(owner)
66 name = getattr(func, "__name__", None)
67 name = getattr(func, "__name__", None)
67 instance_dict_overrides = getattr(owner, "__dict__", None)
68 instance_dict_overrides = getattr(owner, "__dict__", None)
68 if (
69 if (
69 owner is not None
70 owner is not None
70 and name
71 and name
71 and (
72 and (
72 not instance_dict_overrides
73 not instance_dict_overrides
73 or (instance_dict_overrides and name not in instance_dict_overrides)
74 or (instance_dict_overrides and name not in instance_dict_overrides)
74 )
75 )
75 ):
76 ):
76 return getattr(owner_class, name)
77 return getattr(owner_class, name)
77 return None
78 return None
78
79
79
80
80 @undoc
81 @undoc
81 @dataclass
82 @dataclass
82 class EvaluationPolicy:
83 class EvaluationPolicy:
83 """Definition of evaluation policy."""
84 """Definition of evaluation policy."""
84
85
85 allow_locals_access: bool = False
86 allow_locals_access: bool = False
86 allow_globals_access: bool = False
87 allow_globals_access: bool = False
87 allow_item_access: bool = False
88 allow_item_access: bool = False
88 allow_attr_access: bool = False
89 allow_attr_access: bool = False
89 allow_builtins_access: bool = False
90 allow_builtins_access: bool = False
90 allow_all_operations: bool = False
91 allow_all_operations: bool = False
91 allow_any_calls: bool = False
92 allow_any_calls: bool = False
92 allowed_calls: Set[Callable] = field(default_factory=set)
93 allowed_calls: Set[Callable] = field(default_factory=set)
93
94
94 def can_get_item(self, value, item):
95 def can_get_item(self, value, item):
95 return self.allow_item_access
96 return self.allow_item_access
96
97
97 def can_get_attr(self, value, attr):
98 def can_get_attr(self, value, attr):
98 return self.allow_attr_access
99 return self.allow_attr_access
99
100
100 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
101 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
101 if self.allow_all_operations:
102 if self.allow_all_operations:
102 return True
103 return True
103
104
104 def can_call(self, func):
105 def can_call(self, func):
105 if self.allow_any_calls:
106 if self.allow_any_calls:
106 return True
107 return True
107
108
108 if func in self.allowed_calls:
109 if func in self.allowed_calls:
109 return True
110 return True
110
111
111 owner_method = _unbind_method(func)
112 owner_method = _unbind_method(func)
112
113
113 if owner_method and owner_method in self.allowed_calls:
114 if owner_method and owner_method in self.allowed_calls:
114 return True
115 return True
115
116
116
117
117 def _get_external(module_name: str, access_path: Sequence[str]):
118 def _get_external(module_name: str, access_path: Sequence[str]):
118 """Get value from external module given a dotted access path.
119 """Get value from external module given a dotted access path.
119
120
120 Raises:
121 Raises:
121 * `KeyError` if module is removed not found, and
122 * `KeyError` if module is removed not found, and
122 * `AttributeError` if acess path does not match an exported object
123 * `AttributeError` if acess path does not match an exported object
123 """
124 """
124 member_type = sys.modules[module_name]
125 member_type = sys.modules[module_name]
125 for attr in access_path:
126 for attr in access_path:
126 member_type = getattr(member_type, attr)
127 member_type = getattr(member_type, attr)
127 return member_type
128 return member_type
128
129
129
130
130 def _has_original_dunder_external(
131 def _has_original_dunder_external(
131 value,
132 value,
132 module_name: str,
133 module_name: str,
133 access_path: Sequence[str],
134 access_path: Sequence[str],
134 method_name: str,
135 method_name: str,
135 ):
136 ):
136 if module_name not in sys.modules:
137 if module_name not in sys.modules:
137 # LBYLB as it is faster
138 # LBYLB as it is faster
138 return False
139 return False
139 try:
140 try:
140 member_type = _get_external(module_name, access_path)
141 member_type = _get_external(module_name, access_path)
141 value_type = type(value)
142 value_type = type(value)
142 if type(value) == member_type:
143 if type(value) == member_type:
143 return True
144 return True
144 if method_name == "__getattribute__":
145 if method_name == "__getattribute__":
145 # we have to short-circuit here due to an unresolved issue in
146 # we have to short-circuit here due to an unresolved issue in
146 # `isinstance` implementation: https://bugs.python.org/issue32683
147 # `isinstance` implementation: https://bugs.python.org/issue32683
147 return False
148 return False
148 if isinstance(value, member_type):
149 if isinstance(value, member_type):
149 method = getattr(value_type, method_name, None)
150 method = getattr(value_type, method_name, None)
150 member_method = getattr(member_type, method_name, None)
151 member_method = getattr(member_type, method_name, None)
151 if member_method == method:
152 if member_method == method:
152 return True
153 return True
153 except (AttributeError, KeyError):
154 except (AttributeError, KeyError):
154 return False
155 return False
155
156
156
157
157 def _has_original_dunder(
158 def _has_original_dunder(
158 value, allowed_types, allowed_methods, allowed_external, method_name
159 value, allowed_types, allowed_methods, allowed_external, method_name
159 ):
160 ):
160 # note: Python ignores `__getattr__`/`__getitem__` on instances,
161 # note: Python ignores `__getattr__`/`__getitem__` on instances,
161 # we only need to check at class level
162 # we only need to check at class level
162 value_type = type(value)
163 value_type = type(value)
163
164
164 # strict type check passes β†’ no need to check method
165 # strict type check passes β†’ no need to check method
165 if value_type in allowed_types:
166 if value_type in allowed_types:
166 return True
167 return True
167
168
168 method = getattr(value_type, method_name, None)
169 method = getattr(value_type, method_name, None)
169
170
170 if method is None:
171 if method is None:
171 return None
172 return None
172
173
173 if method in allowed_methods:
174 if method in allowed_methods:
174 return True
175 return True
175
176
176 for module_name, *access_path in allowed_external:
177 for module_name, *access_path in allowed_external:
177 if _has_original_dunder_external(value, module_name, access_path, method_name):
178 if _has_original_dunder_external(value, module_name, access_path, method_name):
178 return True
179 return True
179
180
180 return False
181 return False
181
182
182
183
183 @undoc
184 @undoc
184 @dataclass
185 @dataclass
185 class SelectivePolicy(EvaluationPolicy):
186 class SelectivePolicy(EvaluationPolicy):
186 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
187 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
187 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
188 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
188
189
189 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
190 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
190 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
191 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
191
192
192 allowed_operations: Set = field(default_factory=set)
193 allowed_operations: Set = field(default_factory=set)
193 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
194 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
194
195
195 _operation_methods_cache: Dict[str, Set[Callable]] = field(
196 _operation_methods_cache: Dict[str, Set[Callable]] = field(
196 default_factory=dict, init=False
197 default_factory=dict, init=False
197 )
198 )
198
199
199 def can_get_attr(self, value, attr):
200 def can_get_attr(self, value, attr):
200 has_original_attribute = _has_original_dunder(
201 has_original_attribute = _has_original_dunder(
201 value,
202 value,
202 allowed_types=self.allowed_getattr,
203 allowed_types=self.allowed_getattr,
203 allowed_methods=self._getattribute_methods,
204 allowed_methods=self._getattribute_methods,
204 allowed_external=self.allowed_getattr_external,
205 allowed_external=self.allowed_getattr_external,
205 method_name="__getattribute__",
206 method_name="__getattribute__",
206 )
207 )
207 has_original_attr = _has_original_dunder(
208 has_original_attr = _has_original_dunder(
208 value,
209 value,
209 allowed_types=self.allowed_getattr,
210 allowed_types=self.allowed_getattr,
210 allowed_methods=self._getattr_methods,
211 allowed_methods=self._getattr_methods,
211 allowed_external=self.allowed_getattr_external,
212 allowed_external=self.allowed_getattr_external,
212 method_name="__getattr__",
213 method_name="__getattr__",
213 )
214 )
214
215
215 accept = False
216 accept = False
216
217
217 # Many objects do not have `__getattr__`, this is fine
218 # Many objects do not have `__getattr__`, this is fine.
218 if has_original_attr is None and has_original_attribute:
219 if has_original_attr is None and has_original_attribute:
219 accept = True
220 accept = True
220 else:
221 else:
221 # Accept objects without modifications to `__getattr__` and `__getattribute__`
222 # Accept objects without modifications to `__getattr__` and `__getattribute__`
222 accept = has_original_attr and has_original_attribute
223 accept = has_original_attr and has_original_attribute
223
224
224 if accept:
225 if accept:
225 # We still need to check for overriden properties.
226 # We still need to check for overriden properties.
226
227
227 value_class = type(value)
228 value_class = type(value)
228 if not hasattr(value_class, attr):
229 if not hasattr(value_class, attr):
229 return True
230 return True
230
231
231 class_attr_val = getattr(value_class, attr)
232 class_attr_val = getattr(value_class, attr)
232 is_property = isinstance(class_attr_val, property)
233 is_property = isinstance(class_attr_val, property)
233
234
234 if not is_property:
235 if not is_property:
235 return True
236 return True
236
237
237 # Properties in allowed types are ok
238 # Properties in allowed types are ok (although we do not include any
239 # properties in our default allow list currently).
238 if type(value) in self.allowed_getattr:
240 if type(value) in self.allowed_getattr:
239 return True
241 return True # pragma: no cover
240
242
241 # Properties in subclasses of allowed types may be ok if not changed
243 # Properties in subclasses of allowed types may be ok if not changed
242 for module_name, *access_path in self.allowed_getattr_external:
244 for module_name, *access_path in self.allowed_getattr_external:
243 try:
245 try:
244 external_class = _get_external(module_name, access_path)
246 external_class = _get_external(module_name, access_path)
245 external_class_attr_val = getattr(external_class, attr)
247 external_class_attr_val = getattr(external_class, attr)
246 except (KeyError, AttributeError):
248 except (KeyError, AttributeError):
247 return False # pragma: no cover
249 return False # pragma: no cover
248 return class_attr_val == external_class_attr_val
250 return class_attr_val == external_class_attr_val
249
251
250 return False
252 return False
251
253
252 def can_get_item(self, value, item):
254 def can_get_item(self, value, item):
253 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
255 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
254 return _has_original_dunder(
256 return _has_original_dunder(
255 value,
257 value,
256 allowed_types=self.allowed_getitem,
258 allowed_types=self.allowed_getitem,
257 allowed_methods=self._getitem_methods,
259 allowed_methods=self._getitem_methods,
258 allowed_external=self.allowed_getitem_external,
260 allowed_external=self.allowed_getitem_external,
259 method_name="__getitem__",
261 method_name="__getitem__",
260 )
262 )
261
263
262 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
264 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
263 objects = [a]
265 objects = [a]
264 if b is not None:
266 if b is not None:
265 objects.append(b)
267 objects.append(b)
266 return all(
268 return all(
267 [
269 [
268 _has_original_dunder(
270 _has_original_dunder(
269 obj,
271 obj,
270 allowed_types=self.allowed_operations,
272 allowed_types=self.allowed_operations,
271 allowed_methods=self._operator_dunder_methods(dunder),
273 allowed_methods=self._operator_dunder_methods(dunder),
272 allowed_external=self.allowed_operations_external,
274 allowed_external=self.allowed_operations_external,
273 method_name=dunder,
275 method_name=dunder,
274 )
276 )
275 for dunder in dunders
277 for dunder in dunders
276 for obj in objects
278 for obj in objects
277 ]
279 ]
278 )
280 )
279
281
280 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
282 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
281 if dunder not in self._operation_methods_cache:
283 if dunder not in self._operation_methods_cache:
282 self._operation_methods_cache[dunder] = self._safe_get_methods(
284 self._operation_methods_cache[dunder] = self._safe_get_methods(
283 self.allowed_operations, dunder
285 self.allowed_operations, dunder
284 )
286 )
285 return self._operation_methods_cache[dunder]
287 return self._operation_methods_cache[dunder]
286
288
287 @cached_property
289 @cached_property
288 def _getitem_methods(self) -> Set[Callable]:
290 def _getitem_methods(self) -> Set[Callable]:
289 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
291 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
290
292
291 @cached_property
293 @cached_property
292 def _getattr_methods(self) -> Set[Callable]:
294 def _getattr_methods(self) -> Set[Callable]:
293 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
295 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
294
296
295 @cached_property
297 @cached_property
296 def _getattribute_methods(self) -> Set[Callable]:
298 def _getattribute_methods(self) -> Set[Callable]:
297 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
299 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
298
300
299 def _safe_get_methods(self, classes, name) -> Set[Callable]:
301 def _safe_get_methods(self, classes, name) -> Set[Callable]:
300 return {
302 return {
301 method
303 method
302 for class_ in classes
304 for class_ in classes
303 for method in [getattr(class_, name, None)]
305 for method in [getattr(class_, name, None)]
304 if method
306 if method
305 }
307 }
306
308
307
309
308 class _DummyNamedTuple(NamedTuple):
310 class _DummyNamedTuple(NamedTuple):
309 """Used internally to retrieve methods of named tuple instance."""
311 """Used internally to retrieve methods of named tuple instance."""
310
312
311
313
312 class EvaluationContext(NamedTuple):
314 class EvaluationContext(NamedTuple):
313 #: Local namespace
315 #: Local namespace
314 locals: dict
316 locals: dict
315 #: Global namespace
317 #: Global namespace
316 globals: dict
318 globals: dict
317 #: Evaluation policy identifier
319 #: Evaluation policy identifier
318 evaluation: Literal[
320 evaluation: Literal[
319 "forbidden", "minimal", "limited", "unsafe", "dangerous"
321 "forbidden", "minimal", "limited", "unsafe", "dangerous"
320 ] = "forbidden"
322 ] = "forbidden"
321 #: Whether the evalution of code takes place inside of a subscript.
323 #: Whether the evalution of code takes place inside of a subscript.
322 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
324 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
323 in_subscript: bool = False
325 in_subscript: bool = False
324
326
325
327
326 class _IdentitySubscript:
328 class _IdentitySubscript:
327 """Returns the key itself when item is requested via subscript."""
329 """Returns the key itself when item is requested via subscript."""
328
330
329 def __getitem__(self, key):
331 def __getitem__(self, key):
330 return key
332 return key
331
333
332
334
333 IDENTITY_SUBSCRIPT = _IdentitySubscript()
335 IDENTITY_SUBSCRIPT = _IdentitySubscript()
334 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
336 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
335
337
336
338
337 class GuardRejection(Exception):
339 class GuardRejection(Exception):
338 """Exception raised when guard rejects evaluation attempt."""
340 """Exception raised when guard rejects evaluation attempt."""
339
341
340 pass
342 pass
341
343
342
344
343 def guarded_eval(code: str, context: EvaluationContext):
345 def guarded_eval(code: str, context: EvaluationContext):
344 """Evaluate provided code in the evaluation context.
346 """Evaluate provided code in the evaluation context.
345
347
346 If evaluation policy given by context is set to ``forbidden``
348 If evaluation policy given by context is set to ``forbidden``
347 no evaluation will be performed; if it is set to ``dangerous``
349 no evaluation will be performed; if it is set to ``dangerous``
348 standard :func:`eval` will be used; finally, for any other,
350 standard :func:`eval` will be used; finally, for any other,
349 policy :func:`eval_node` will be called on parsed AST.
351 policy :func:`eval_node` will be called on parsed AST.
350 """
352 """
351 locals_ = context.locals
353 locals_ = context.locals
352
354
353 if context.evaluation == "forbidden":
355 if context.evaluation == "forbidden":
354 raise GuardRejection("Forbidden mode")
356 raise GuardRejection("Forbidden mode")
355
357
356 # note: not using `ast.literal_eval` as it does not implement
358 # note: not using `ast.literal_eval` as it does not implement
357 # getitem at all, for example it fails on simple `[0][1]`
359 # getitem at all, for example it fails on simple `[0][1]`
358
360
359 if context.in_subscript:
361 if context.in_subscript:
360 # syntatic sugar for ellipsis (:) is only available in susbcripts
362 # syntatic sugar for ellipsis (:) is only available in susbcripts
361 # so we need to trick the ast parser into thinking that we have
363 # so we need to trick the ast parser into thinking that we have
362 # a subscript, but we need to be able to later recognise that we did
364 # a subscript, but we need to be able to later recognise that we did
363 # it so we can ignore the actual __getitem__ operation
365 # it so we can ignore the actual __getitem__ operation
364 if not code:
366 if not code:
365 return tuple()
367 return tuple()
366 locals_ = locals_.copy()
368 locals_ = locals_.copy()
367 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
369 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
368 code = SUBSCRIPT_MARKER + "[" + code + "]"
370 code = SUBSCRIPT_MARKER + "[" + code + "]"
369 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
371 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
370
372
371 if context.evaluation == "dangerous":
373 if context.evaluation == "dangerous":
372 return eval(code, context.globals, context.locals)
374 return eval(code, context.globals, context.locals)
373
375
374 expression = ast.parse(code, mode="eval")
376 expression = ast.parse(code, mode="eval")
375
377
376 return eval_node(expression, context)
378 return eval_node(expression, context)
377
379
378
380
379 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
381 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
380 ast.Add: ("__add__",),
382 ast.Add: ("__add__",),
381 ast.Sub: ("__sub__",),
383 ast.Sub: ("__sub__",),
382 ast.Mult: ("__mul__",),
384 ast.Mult: ("__mul__",),
383 ast.Div: ("__truediv__",),
385 ast.Div: ("__truediv__",),
384 ast.FloorDiv: ("__floordiv__",),
386 ast.FloorDiv: ("__floordiv__",),
385 ast.Mod: ("__mod__",),
387 ast.Mod: ("__mod__",),
386 ast.Pow: ("__pow__",),
388 ast.Pow: ("__pow__",),
387 ast.LShift: ("__lshift__",),
389 ast.LShift: ("__lshift__",),
388 ast.RShift: ("__rshift__",),
390 ast.RShift: ("__rshift__",),
389 ast.BitOr: ("__or__",),
391 ast.BitOr: ("__or__",),
390 ast.BitXor: ("__xor__",),
392 ast.BitXor: ("__xor__",),
391 ast.BitAnd: ("__and__",),
393 ast.BitAnd: ("__and__",),
392 ast.MatMult: ("__matmul__",),
394 ast.MatMult: ("__matmul__",),
393 }
395 }
394
396
395 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
397 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
396 ast.Eq: ("__eq__",),
398 ast.Eq: ("__eq__",),
397 ast.NotEq: ("__ne__", "__eq__"),
399 ast.NotEq: ("__ne__", "__eq__"),
398 ast.Lt: ("__lt__", "__gt__"),
400 ast.Lt: ("__lt__", "__gt__"),
399 ast.LtE: ("__le__", "__ge__"),
401 ast.LtE: ("__le__", "__ge__"),
400 ast.Gt: ("__gt__", "__lt__"),
402 ast.Gt: ("__gt__", "__lt__"),
401 ast.GtE: ("__ge__", "__le__"),
403 ast.GtE: ("__ge__", "__le__"),
402 ast.In: ("__contains__",),
404 ast.In: ("__contains__",),
403 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
405 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
404 }
406 }
405
407
406 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
408 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
407 ast.USub: ("__neg__",),
409 ast.USub: ("__neg__",),
408 ast.UAdd: ("__pos__",),
410 ast.UAdd: ("__pos__",),
409 # we have to check both __inv__ and __invert__!
411 # we have to check both __inv__ and __invert__!
410 ast.Invert: ("__invert__", "__inv__"),
412 ast.Invert: ("__invert__", "__inv__"),
411 ast.Not: ("__not__",),
413 ast.Not: ("__not__",),
412 }
414 }
413
415
414
416
415 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
417 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
416 dunder = None
418 dunder = None
417 for op, candidate_dunder in dunders.items():
419 for op, candidate_dunder in dunders.items():
418 if isinstance(node_op, op):
420 if isinstance(node_op, op):
419 dunder = candidate_dunder
421 dunder = candidate_dunder
420 return dunder
422 return dunder
421
423
422
424
423 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
425 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
424 """Evaluate AST node in provided context.
426 """Evaluate AST node in provided context.
425
427
426 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
428 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
427
429
428 Does not evaluate actions that always have side effects:
430 Does not evaluate actions that always have side effects:
429
431
430 - class definitions (``class sth: ...``)
432 - class definitions (``class sth: ...``)
431 - function definitions (``def sth: ...``)
433 - function definitions (``def sth: ...``)
432 - variable assignments (``x = 1``)
434 - variable assignments (``x = 1``)
433 - augmented assignments (``x += 1``)
435 - augmented assignments (``x += 1``)
434 - deletions (``del x``)
436 - deletions (``del x``)
435
437
436 Does not evaluate operations which do not return values:
438 Does not evaluate operations which do not return values:
437
439
438 - assertions (``assert x``)
440 - assertions (``assert x``)
439 - pass (``pass``)
441 - pass (``pass``)
440 - imports (``import x``)
442 - imports (``import x``)
441 - control flow:
443 - control flow:
442
444
443 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
445 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
444 - loops (``for`` and `while``)
446 - loops (``for`` and `while``)
445 - exception handling
447 - exception handling
446
448
447 The purpose of this function is to guard against unwanted side-effects;
449 The purpose of this function is to guard against unwanted side-effects;
448 it does not give guarantees on protection from malicious code execution.
450 it does not give guarantees on protection from malicious code execution.
449 """
451 """
450 policy = EVALUATION_POLICIES[context.evaluation]
452 policy = EVALUATION_POLICIES[context.evaluation]
451 if node is None:
453 if node is None:
452 return None
454 return None
453 if isinstance(node, ast.Expression):
455 if isinstance(node, ast.Expression):
454 return eval_node(node.body, context)
456 return eval_node(node.body, context)
455 if isinstance(node, ast.BinOp):
457 if isinstance(node, ast.BinOp):
456 left = eval_node(node.left, context)
458 left = eval_node(node.left, context)
457 right = eval_node(node.right, context)
459 right = eval_node(node.right, context)
458 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
460 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
459 if dunders:
461 if dunders:
460 if policy.can_operate(dunders, left, right):
462 if policy.can_operate(dunders, left, right):
461 return getattr(left, dunders[0])(right)
463 return getattr(left, dunders[0])(right)
462 else:
464 else:
463 raise GuardRejection(
465 raise GuardRejection(
464 f"Operation (`{dunders}`) for",
466 f"Operation (`{dunders}`) for",
465 type(left),
467 type(left),
466 f"not allowed in {context.evaluation} mode",
468 f"not allowed in {context.evaluation} mode",
467 )
469 )
468 if isinstance(node, ast.Compare):
470 if isinstance(node, ast.Compare):
469 left = eval_node(node.left, context)
471 left = eval_node(node.left, context)
470 all_true = True
472 all_true = True
471 negate = False
473 negate = False
472 for op, right in zip(node.ops, node.comparators):
474 for op, right in zip(node.ops, node.comparators):
473 right = eval_node(right, context)
475 right = eval_node(right, context)
474 dunder = None
476 dunder = None
475 dunders = _find_dunder(op, COMP_OP_DUNDERS)
477 dunders = _find_dunder(op, COMP_OP_DUNDERS)
476 if not dunders:
478 if not dunders:
477 if isinstance(op, ast.NotIn):
479 if isinstance(op, ast.NotIn):
478 dunders = COMP_OP_DUNDERS[ast.In]
480 dunders = COMP_OP_DUNDERS[ast.In]
479 negate = True
481 negate = True
480 if isinstance(op, ast.Is):
482 if isinstance(op, ast.Is):
481 dunder = "is_"
483 dunder = "is_"
482 if isinstance(op, ast.IsNot):
484 if isinstance(op, ast.IsNot):
483 dunder = "is_"
485 dunder = "is_"
484 negate = True
486 negate = True
485 if not dunder and dunders:
487 if not dunder and dunders:
486 dunder = dunders[0]
488 dunder = dunders[0]
487 if dunder:
489 if dunder:
488 a, b = (right, left) if dunder == "__contains__" else (left, right)
490 a, b = (right, left) if dunder == "__contains__" else (left, right)
489 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
491 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
490 result = getattr(operator, dunder)(a, b)
492 result = getattr(operator, dunder)(a, b)
491 if negate:
493 if negate:
492 result = not result
494 result = not result
493 if not result:
495 if not result:
494 all_true = False
496 all_true = False
495 left = right
497 left = right
496 else:
498 else:
497 raise GuardRejection(
499 raise GuardRejection(
498 f"Comparison (`{dunder}`) for",
500 f"Comparison (`{dunder}`) for",
499 type(left),
501 type(left),
500 f"not allowed in {context.evaluation} mode",
502 f"not allowed in {context.evaluation} mode",
501 )
503 )
502 else:
504 else:
503 raise ValueError(
505 raise ValueError(
504 f"Comparison `{dunder}` not supported"
506 f"Comparison `{dunder}` not supported"
505 ) # pragma: no cover
507 ) # pragma: no cover
506 return all_true
508 return all_true
507 if isinstance(node, ast.Constant):
509 if isinstance(node, ast.Constant):
508 return node.value
510 return node.value
509 if isinstance(node, ast.Index):
511 if isinstance(node, ast.Index):
510 # deprecated since Python 3.9
512 # deprecated since Python 3.9
511 return eval_node(node.value, context) # pragma: no cover
513 return eval_node(node.value, context) # pragma: no cover
512 if isinstance(node, ast.Tuple):
514 if isinstance(node, ast.Tuple):
513 return tuple(eval_node(e, context) for e in node.elts)
515 return tuple(eval_node(e, context) for e in node.elts)
514 if isinstance(node, ast.List):
516 if isinstance(node, ast.List):
515 return [eval_node(e, context) for e in node.elts]
517 return [eval_node(e, context) for e in node.elts]
516 if isinstance(node, ast.Set):
518 if isinstance(node, ast.Set):
517 return {eval_node(e, context) for e in node.elts}
519 return {eval_node(e, context) for e in node.elts}
518 if isinstance(node, ast.Dict):
520 if isinstance(node, ast.Dict):
519 return dict(
521 return dict(
520 zip(
522 zip(
521 [eval_node(k, context) for k in node.keys],
523 [eval_node(k, context) for k in node.keys],
522 [eval_node(v, context) for v in node.values],
524 [eval_node(v, context) for v in node.values],
523 )
525 )
524 )
526 )
525 if isinstance(node, ast.Slice):
527 if isinstance(node, ast.Slice):
526 return slice(
528 return slice(
527 eval_node(node.lower, context),
529 eval_node(node.lower, context),
528 eval_node(node.upper, context),
530 eval_node(node.upper, context),
529 eval_node(node.step, context),
531 eval_node(node.step, context),
530 )
532 )
531 if isinstance(node, ast.ExtSlice):
533 if isinstance(node, ast.ExtSlice):
532 # deprecated since Python 3.9
534 # deprecated since Python 3.9
533 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
535 return tuple([eval_node(dim, context) for dim in node.dims]) # pragma: no cover
534 if isinstance(node, ast.UnaryOp):
536 if isinstance(node, ast.UnaryOp):
535 value = eval_node(node.operand, context)
537 value = eval_node(node.operand, context)
536 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
538 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
537 if dunders:
539 if dunders:
538 if policy.can_operate(dunders, value):
540 if policy.can_operate(dunders, value):
539 return getattr(value, dunders[0])()
541 return getattr(value, dunders[0])()
540 else:
542 else:
541 raise GuardRejection(
543 raise GuardRejection(
542 f"Operation (`{dunders}`) for",
544 f"Operation (`{dunders}`) for",
543 type(value),
545 type(value),
544 f"not allowed in {context.evaluation} mode",
546 f"not allowed in {context.evaluation} mode",
545 )
547 )
546 if isinstance(node, ast.Subscript):
548 if isinstance(node, ast.Subscript):
547 value = eval_node(node.value, context)
549 value = eval_node(node.value, context)
548 slice_ = eval_node(node.slice, context)
550 slice_ = eval_node(node.slice, context)
549 if policy.can_get_item(value, slice_):
551 if policy.can_get_item(value, slice_):
550 return value[slice_]
552 return value[slice_]
551 raise GuardRejection(
553 raise GuardRejection(
552 "Subscript access (`__getitem__`) for",
554 "Subscript access (`__getitem__`) for",
553 type(value), # not joined to avoid calling `repr`
555 type(value), # not joined to avoid calling `repr`
554 f" not allowed in {context.evaluation} mode",
556 f" not allowed in {context.evaluation} mode",
555 )
557 )
556 if isinstance(node, ast.Name):
558 if isinstance(node, ast.Name):
557 if policy.allow_locals_access and node.id in context.locals:
559 if policy.allow_locals_access and node.id in context.locals:
558 return context.locals[node.id]
560 return context.locals[node.id]
559 if policy.allow_globals_access and node.id in context.globals:
561 if policy.allow_globals_access and node.id in context.globals:
560 return context.globals[node.id]
562 return context.globals[node.id]
561 if policy.allow_builtins_access and hasattr(builtins, node.id):
563 if policy.allow_builtins_access and hasattr(builtins, node.id):
562 # note: do not use __builtins__, it is implementation detail of cPython
564 # note: do not use __builtins__, it is implementation detail of cPython
563 return getattr(builtins, node.id)
565 return getattr(builtins, node.id)
564 if not policy.allow_globals_access and not policy.allow_locals_access:
566 if not policy.allow_globals_access and not policy.allow_locals_access:
565 raise GuardRejection(
567 raise GuardRejection(
566 f"Namespace access not allowed in {context.evaluation} mode"
568 f"Namespace access not allowed in {context.evaluation} mode"
567 )
569 )
568 else:
570 else:
569 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
571 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
570 if isinstance(node, ast.Attribute):
572 if isinstance(node, ast.Attribute):
571 value = eval_node(node.value, context)
573 value = eval_node(node.value, context)
572 if policy.can_get_attr(value, node.attr):
574 if policy.can_get_attr(value, node.attr):
573 return getattr(value, node.attr)
575 return getattr(value, node.attr)
574 raise GuardRejection(
576 raise GuardRejection(
575 "Attribute access (`__getattr__`) for",
577 "Attribute access (`__getattr__`) for",
576 type(value), # not joined to avoid calling `repr`
578 type(value), # not joined to avoid calling `repr`
577 f"not allowed in {context.evaluation} mode",
579 f"not allowed in {context.evaluation} mode",
578 )
580 )
579 if isinstance(node, ast.IfExp):
581 if isinstance(node, ast.IfExp):
580 test = eval_node(node.test, context)
582 test = eval_node(node.test, context)
581 if test:
583 if test:
582 return eval_node(node.body, context)
584 return eval_node(node.body, context)
583 else:
585 else:
584 return eval_node(node.orelse, context)
586 return eval_node(node.orelse, context)
585 if isinstance(node, ast.Call):
587 if isinstance(node, ast.Call):
586 func = eval_node(node.func, context)
588 func = eval_node(node.func, context)
587 if policy.can_call(func) and not node.keywords:
589 if policy.can_call(func) and not node.keywords:
588 args = [eval_node(arg, context) for arg in node.args]
590 args = [eval_node(arg, context) for arg in node.args]
589 return func(*args)
591 return func(*args)
590 raise GuardRejection(
592 raise GuardRejection(
591 "Call for",
593 "Call for",
592 func, # not joined to avoid calling `repr`
594 func, # not joined to avoid calling `repr`
593 f"not allowed in {context.evaluation} mode",
595 f"not allowed in {context.evaluation} mode",
594 )
596 )
595 raise ValueError("Unhandled node", ast.dump(node))
597 raise ValueError("Unhandled node", ast.dump(node))
596
598
597
599
598 SUPPORTED_EXTERNAL_GETITEM = {
600 SUPPORTED_EXTERNAL_GETITEM = {
599 ("pandas", "core", "indexing", "_iLocIndexer"),
601 ("pandas", "core", "indexing", "_iLocIndexer"),
600 ("pandas", "core", "indexing", "_LocIndexer"),
602 ("pandas", "core", "indexing", "_LocIndexer"),
601 ("pandas", "DataFrame"),
603 ("pandas", "DataFrame"),
602 ("pandas", "Series"),
604 ("pandas", "Series"),
603 ("numpy", "ndarray"),
605 ("numpy", "ndarray"),
604 ("numpy", "void"),
606 ("numpy", "void"),
605 }
607 }
606
608
607
609
608 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
610 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
609 dict,
611 dict,
610 str,
612 str,
611 bytes,
613 bytes,
612 list,
614 list,
613 tuple,
615 tuple,
614 collections.defaultdict,
616 collections.defaultdict,
615 collections.deque,
617 collections.deque,
616 collections.OrderedDict,
618 collections.OrderedDict,
617 collections.ChainMap,
619 collections.ChainMap,
618 collections.UserDict,
620 collections.UserDict,
619 collections.UserList,
621 collections.UserList,
620 collections.UserString,
622 collections.UserString,
621 _DummyNamedTuple,
623 _DummyNamedTuple,
622 _IdentitySubscript,
624 _IdentitySubscript,
623 }
625 }
624
626
625
627
626 def _list_methods(cls, source=None):
628 def _list_methods(cls, source=None):
627 """For use on immutable objects or with methods returning a copy"""
629 """For use on immutable objects or with methods returning a copy"""
628 return [getattr(cls, k) for k in (source if source else dir(cls))]
630 return [getattr(cls, k) for k in (source if source else dir(cls))]
629
631
630
632
631 dict_non_mutating_methods = ("copy", "keys", "values", "items")
633 dict_non_mutating_methods = ("copy", "keys", "values", "items")
632 list_non_mutating_methods = ("copy", "index", "count")
634 list_non_mutating_methods = ("copy", "index", "count")
633 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
635 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
634
636
635
637
636 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
638 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
637 method_descriptor: Any = type(list.copy)
639 method_descriptor: Any = type(list.copy)
638
640
639 NUMERICS = {int, float, complex}
641 NUMERICS = {int, float, complex}
640
642
641 ALLOWED_CALLS = {
643 ALLOWED_CALLS = {
642 bytes,
644 bytes,
643 *_list_methods(bytes),
645 *_list_methods(bytes),
644 dict,
646 dict,
645 *_list_methods(dict, dict_non_mutating_methods),
647 *_list_methods(dict, dict_non_mutating_methods),
646 dict_keys.isdisjoint,
648 dict_keys.isdisjoint,
647 list,
649 list,
648 *_list_methods(list, list_non_mutating_methods),
650 *_list_methods(list, list_non_mutating_methods),
649 set,
651 set,
650 *_list_methods(set, set_non_mutating_methods),
652 *_list_methods(set, set_non_mutating_methods),
651 frozenset,
653 frozenset,
652 *_list_methods(frozenset),
654 *_list_methods(frozenset),
653 range,
655 range,
654 str,
656 str,
655 *_list_methods(str),
657 *_list_methods(str),
656 tuple,
658 tuple,
657 *_list_methods(tuple),
659 *_list_methods(tuple),
658 *NUMERICS,
660 *NUMERICS,
659 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
661 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
660 collections.deque,
662 collections.deque,
661 *_list_methods(collections.deque, list_non_mutating_methods),
663 *_list_methods(collections.deque, list_non_mutating_methods),
662 collections.defaultdict,
664 collections.defaultdict,
663 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
665 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
664 collections.OrderedDict,
666 collections.OrderedDict,
665 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
667 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
666 collections.UserDict,
668 collections.UserDict,
667 *_list_methods(collections.UserDict, dict_non_mutating_methods),
669 *_list_methods(collections.UserDict, dict_non_mutating_methods),
668 collections.UserList,
670 collections.UserList,
669 *_list_methods(collections.UserList, list_non_mutating_methods),
671 *_list_methods(collections.UserList, list_non_mutating_methods),
670 collections.UserString,
672 collections.UserString,
671 *_list_methods(collections.UserString, dir(str)),
673 *_list_methods(collections.UserString, dir(str)),
672 collections.Counter,
674 collections.Counter,
673 *_list_methods(collections.Counter, dict_non_mutating_methods),
675 *_list_methods(collections.Counter, dict_non_mutating_methods),
674 collections.Counter.elements,
676 collections.Counter.elements,
675 collections.Counter.most_common,
677 collections.Counter.most_common,
676 }
678 }
677
679
678 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
680 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
679 *BUILTIN_GETITEM,
681 *BUILTIN_GETITEM,
680 set,
682 set,
681 frozenset,
683 frozenset,
682 object,
684 object,
683 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
685 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
684 *NUMERICS,
686 *NUMERICS,
685 dict_keys,
687 dict_keys,
686 method_descriptor,
688 method_descriptor,
687 }
689 }
688
690
689
691
690 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
692 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
691
693
692 EVALUATION_POLICIES = {
694 EVALUATION_POLICIES = {
693 "minimal": EvaluationPolicy(
695 "minimal": EvaluationPolicy(
694 allow_builtins_access=True,
696 allow_builtins_access=True,
695 allow_locals_access=False,
697 allow_locals_access=False,
696 allow_globals_access=False,
698 allow_globals_access=False,
697 allow_item_access=False,
699 allow_item_access=False,
698 allow_attr_access=False,
700 allow_attr_access=False,
699 allowed_calls=set(),
701 allowed_calls=set(),
700 allow_any_calls=False,
702 allow_any_calls=False,
701 allow_all_operations=False,
703 allow_all_operations=False,
702 ),
704 ),
703 "limited": SelectivePolicy(
705 "limited": SelectivePolicy(
704 allowed_getitem=BUILTIN_GETITEM,
706 allowed_getitem=BUILTIN_GETITEM,
705 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
707 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
706 allowed_getattr=BUILTIN_GETATTR,
708 allowed_getattr=BUILTIN_GETATTR,
707 allowed_getattr_external={
709 allowed_getattr_external={
708 # pandas Series/Frame implements custom `__getattr__`
710 # pandas Series/Frame implements custom `__getattr__`
709 ("pandas", "DataFrame"),
711 ("pandas", "DataFrame"),
710 ("pandas", "Series"),
712 ("pandas", "Series"),
711 },
713 },
712 allowed_operations=BUILTIN_OPERATIONS,
714 allowed_operations=BUILTIN_OPERATIONS,
713 allow_builtins_access=True,
715 allow_builtins_access=True,
714 allow_locals_access=True,
716 allow_locals_access=True,
715 allow_globals_access=True,
717 allow_globals_access=True,
716 allowed_calls=ALLOWED_CALLS,
718 allowed_calls=ALLOWED_CALLS,
717 ),
719 ),
718 "unsafe": EvaluationPolicy(
720 "unsafe": EvaluationPolicy(
719 allow_builtins_access=True,
721 allow_builtins_access=True,
720 allow_locals_access=True,
722 allow_locals_access=True,
721 allow_globals_access=True,
723 allow_globals_access=True,
722 allow_attr_access=True,
724 allow_attr_access=True,
723 allow_item_access=True,
725 allow_item_access=True,
724 allow_any_calls=True,
726 allow_any_calls=True,
725 allow_all_operations=True,
727 allow_all_operations=True,
726 ),
728 ),
727 }
729 }
728
730
729
731
730 __all__ = [
732 __all__ = [
731 "guarded_eval",
733 "guarded_eval",
732 "eval_node",
734 "eval_node",
733 "GuardRejection",
735 "GuardRejection",
734 "EvaluationContext",
736 "EvaluationContext",
735 "_unbind_method",
737 "_unbind_method",
736 ]
738 ]
@@ -1,537 +1,570 b''
1 from contextlib import contextmanager
1 from contextlib import contextmanager
2 from typing import NamedTuple
2 from typing import NamedTuple
3 from functools import partial
3 from functools import partial
4 from IPython.core.guarded_eval import (
4 from IPython.core.guarded_eval import (
5 EvaluationContext,
5 EvaluationContext,
6 GuardRejection,
6 GuardRejection,
7 guarded_eval,
7 guarded_eval,
8 _unbind_method,
8 _unbind_method,
9 )
9 )
10 from IPython.testing import decorators as dec
10 from IPython.testing import decorators as dec
11 import pytest
11 import pytest
12
12
13
13
14 def create_context(evaluation: str, **kwargs):
14 def create_context(evaluation: str, **kwargs):
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16
16
17
17
18 forbidden = partial(create_context, "forbidden")
18 forbidden = partial(create_context, "forbidden")
19 minimal = partial(create_context, "minimal")
19 minimal = partial(create_context, "minimal")
20 limited = partial(create_context, "limited")
20 limited = partial(create_context, "limited")
21 unsafe = partial(create_context, "unsafe")
21 unsafe = partial(create_context, "unsafe")
22 dangerous = partial(create_context, "dangerous")
22 dangerous = partial(create_context, "dangerous")
23
23
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25
26 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
27
26
28
27
29 @contextmanager
28 @contextmanager
30 def module_not_installed(module: str):
29 def module_not_installed(module: str):
31 import sys
30 import sys
32
31
33 try:
32 try:
34 to_restore = sys.modules[module]
33 to_restore = sys.modules[module]
35 del sys.modules[module]
34 del sys.modules[module]
36 except KeyError:
35 except KeyError:
37 to_restore = None
36 to_restore = None
38 try:
37 try:
39 yield
38 yield
40 finally:
39 finally:
41 sys.modules[module] = to_restore
40 sys.modules[module] = to_restore
42
41
43
42
43 def test_external_not_installed():
44 """
45 Because attribute check requires checking if object is not of allowed
46 external type, this tests logic for absence of external module.
47 """
48
49 class Custom:
50 def __init__(self):
51 self.test = 1
52
53 def __getattr__(self, key):
54 return key
55
56 with module_not_installed("pandas"):
57 context = limited(x=Custom())
58 with pytest.raises(GuardRejection):
59 guarded_eval("x.test", context)
60
61
62 @dec.skip_without("pandas")
63 def test_external_changed_api(monkeypatch):
64 """Check that the execution rejects if external API changed paths"""
65 import pandas as pd
66
67 series = pd.Series([1], index=["a"])
68
69 with monkeypatch.context() as m:
70 m.delattr(pd, "Series")
71 context = limited(data=series)
72 with pytest.raises(GuardRejection):
73 guarded_eval("data.iloc[0]", context)
74
75
44 @dec.skip_without("pandas")
76 @dec.skip_without("pandas")
45 def test_pandas_series_iloc():
77 def test_pandas_series_iloc():
46 import pandas as pd
78 import pandas as pd
47
79
48 series = pd.Series([1], index=["a"])
80 series = pd.Series([1], index=["a"])
49 context = limited(data=series)
81 context = limited(data=series)
50 assert guarded_eval("data.iloc[0]", context) == 1
82 assert guarded_eval("data.iloc[0]", context) == 1
51
83
52
84
53 def test_rejects_custom_properties():
85 def test_rejects_custom_properties():
54 class BadProperty:
86 class BadProperty:
55 @property
87 @property
56 def iloc(self):
88 def iloc(self):
57 return [None]
89 return [None]
58
90
59 series = BadProperty()
91 series = BadProperty()
60 context = limited(data=series)
92 context = limited(data=series)
61
93
62 with pytest.raises(GuardRejection):
94 with pytest.raises(GuardRejection):
63 guarded_eval("data.iloc[0]", context)
95 guarded_eval("data.iloc[0]", context)
64
96
65
97
66 @dec.skip_without("pandas")
98 @dec.skip_without("pandas")
67 def test_accepts_non_overriden_properties():
99 def test_accepts_non_overriden_properties():
68 import pandas as pd
100 import pandas as pd
69
101
70 class GoodProperty(pd.Series):
102 class GoodProperty(pd.Series):
71 pass
103 pass
72
104
73 series = GoodProperty([1], index=["a"])
105 series = GoodProperty([1], index=["a"])
74 context = limited(data=series)
106 context = limited(data=series)
75
107
76 assert guarded_eval("data.iloc[0]", context) == 1
108 assert guarded_eval("data.iloc[0]", context) == 1
77
109
78
110
79 @dec.skip_without("pandas")
111 @dec.skip_without("pandas")
80 def test_pandas_series():
112 def test_pandas_series():
81 import pandas as pd
113 import pandas as pd
82
114
83 context = limited(data=pd.Series([1], index=["a"]))
115 context = limited(data=pd.Series([1], index=["a"]))
84 assert guarded_eval('data["a"]', context) == 1
116 assert guarded_eval('data["a"]', context) == 1
85 with pytest.raises(KeyError):
117 with pytest.raises(KeyError):
86 guarded_eval('data["c"]', context)
118 guarded_eval('data["c"]', context)
87
119
88
120
89 @dec.skip_without("pandas")
121 @dec.skip_without("pandas")
90 def test_pandas_bad_series():
122 def test_pandas_bad_series():
91 import pandas as pd
123 import pandas as pd
92
124
93 class BadItemSeries(pd.Series):
125 class BadItemSeries(pd.Series):
94 def __getitem__(self, key):
126 def __getitem__(self, key):
95 return "CUSTOM_ITEM"
127 return "CUSTOM_ITEM"
96
128
97 class BadAttrSeries(pd.Series):
129 class BadAttrSeries(pd.Series):
98 def __getattr__(self, key):
130 def __getattr__(self, key):
99 return "CUSTOM_ATTR"
131 return "CUSTOM_ATTR"
100
132
101 bad_series = BadItemSeries([1], index=["a"])
133 bad_series = BadItemSeries([1], index=["a"])
102 context = limited(data=bad_series)
134 context = limited(data=bad_series)
103
135
104 with pytest.raises(GuardRejection):
136 with pytest.raises(GuardRejection):
105 guarded_eval('data["a"]', context)
137 guarded_eval('data["a"]', context)
106 with pytest.raises(GuardRejection):
138 with pytest.raises(GuardRejection):
107 guarded_eval('data["c"]', context)
139 guarded_eval('data["c"]', context)
108
140
109 # note: here result is a bit unexpected because
141 # note: here result is a bit unexpected because
110 # pandas `__getattr__` calls `__getitem__`;
142 # pandas `__getattr__` calls `__getitem__`;
111 # FIXME - special case to handle it?
143 # FIXME - special case to handle it?
112 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
113
145
114 context = unsafe(data=bad_series)
146 context = unsafe(data=bad_series)
115 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
116
148
117 bad_attr_series = BadAttrSeries([1], index=["a"])
149 bad_attr_series = BadAttrSeries([1], index=["a"])
118 context = limited(data=bad_attr_series)
150 context = limited(data=bad_attr_series)
119 assert guarded_eval('data["a"]', context) == 1
151 assert guarded_eval('data["a"]', context) == 1
120 with pytest.raises(GuardRejection):
152 with pytest.raises(GuardRejection):
121 guarded_eval("data.a", context)
153 guarded_eval("data.a", context)
122
154
123
155
124 @dec.skip_without("pandas")
156 @dec.skip_without("pandas")
125 def test_pandas_dataframe_loc():
157 def test_pandas_dataframe_loc():
126 import pandas as pd
158 import pandas as pd
127 from pandas.testing import assert_series_equal
159 from pandas.testing import assert_series_equal
128
160
129 data = pd.DataFrame([{"a": 1}])
161 data = pd.DataFrame([{"a": 1}])
130 context = limited(data=data)
162 context = limited(data=data)
131 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
132
164
133
165
134 def test_named_tuple():
166 def test_named_tuple():
135 class GoodNamedTuple(NamedTuple):
167 class GoodNamedTuple(NamedTuple):
136 a: str
168 a: str
137 pass
169 pass
138
170
139 class BadNamedTuple(NamedTuple):
171 class BadNamedTuple(NamedTuple):
140 a: str
172 a: str
141
173
142 def __getitem__(self, key):
174 def __getitem__(self, key):
143 return None
175 return None
144
176
145 good = GoodNamedTuple(a="x")
177 good = GoodNamedTuple(a="x")
146 bad = BadNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
147
179
148 context = limited(data=good)
180 context = limited(data=good)
149 assert guarded_eval("data[0]", context) == "x"
181 assert guarded_eval("data[0]", context) == "x"
150
182
151 context = limited(data=bad)
183 context = limited(data=bad)
152 with pytest.raises(GuardRejection):
184 with pytest.raises(GuardRejection):
153 guarded_eval("data[0]", context)
185 guarded_eval("data[0]", context)
154
186
155
187
156 def test_dict():
188 def test_dict():
157 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
158 assert guarded_eval('data["a"]', context) == 1
190 assert guarded_eval('data["a"]', context) == 1
159 assert guarded_eval('data["b"]', context) == {"x": 2}
191 assert guarded_eval('data["b"]', context) == {"x": 2}
160 assert guarded_eval('data["b"]["x"]', context) == 2
192 assert guarded_eval('data["b"]["x"]', context) == 2
161 assert guarded_eval('data["x", "y"]', context) == 3
193 assert guarded_eval('data["x", "y"]', context) == 3
162
194
163 assert guarded_eval("data.keys", context)
195 assert guarded_eval("data.keys", context)
164
196
165
197
166 def test_set():
198 def test_set():
167 context = limited(data={"a", "b"})
199 context = limited(data={"a", "b"})
168 assert guarded_eval("data.difference", context)
200 assert guarded_eval("data.difference", context)
169
201
170
202
171 def test_list():
203 def test_list():
172 context = limited(data=[1, 2, 3])
204 context = limited(data=[1, 2, 3])
173 assert guarded_eval("data[1]", context) == 2
205 assert guarded_eval("data[1]", context) == 2
174 assert guarded_eval("data.copy", context)
206 assert guarded_eval("data.copy", context)
175
207
176
208
177 def test_dict_literal():
209 def test_dict_literal():
178 context = limited()
210 context = limited()
179 assert guarded_eval("{}", context) == {}
211 assert guarded_eval("{}", context) == {}
180 assert guarded_eval('{"a": 1}', context) == {"a": 1}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
181
213
182
214
183 def test_list_literal():
215 def test_list_literal():
184 context = limited()
216 context = limited()
185 assert guarded_eval("[]", context) == []
217 assert guarded_eval("[]", context) == []
186 assert guarded_eval('[1, "a"]', context) == [1, "a"]
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
187
219
188
220
189 def test_set_literal():
221 def test_set_literal():
190 context = limited()
222 context = limited()
191 assert guarded_eval("set()", context) == set()
223 assert guarded_eval("set()", context) == set()
192 assert guarded_eval('{"a"}', context) == {"a"}
224 assert guarded_eval('{"a"}', context) == {"a"}
193
225
194
226
195 def test_evaluates_if_expression():
227 def test_evaluates_if_expression():
196 context = limited()
228 context = limited()
197 assert guarded_eval("2 if True else 3", context) == 2
229 assert guarded_eval("2 if True else 3", context) == 2
198 assert guarded_eval("4 if False else 5", context) == 5
230 assert guarded_eval("4 if False else 5", context) == 5
199
231
200
232
201 def test_object():
233 def test_object():
202 obj = object()
234 obj = object()
203 context = limited(obj=obj)
235 context = limited(obj=obj)
204 assert guarded_eval("obj.__dir__", context) == obj.__dir__
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
205
237
206
238
207 @pytest.mark.parametrize(
239 @pytest.mark.parametrize(
208 "code,expected",
240 "code,expected",
209 [
241 [
210 ["int.numerator", int.numerator],
242 ["int.numerator", int.numerator],
211 ["float.is_integer", float.is_integer],
243 ["float.is_integer", float.is_integer],
212 ["complex.real", complex.real],
244 ["complex.real", complex.real],
213 ],
245 ],
214 )
246 )
215 def test_number_attributes(code, expected):
247 def test_number_attributes(code, expected):
216 assert guarded_eval(code, limited()) == expected
248 assert guarded_eval(code, limited()) == expected
217
249
218
250
219 def test_method_descriptor():
251 def test_method_descriptor():
220 context = limited()
252 context = limited()
221 assert guarded_eval("list.copy.__name__", context) == "copy"
253 assert guarded_eval("list.copy.__name__", context) == "copy"
222
254
223
255
224 @pytest.mark.parametrize(
256 @pytest.mark.parametrize(
225 "data,good,bad,expected",
257 "data,good,bad,expected",
226 [
258 [
227 [[1, 2, 3], "data.index(2)", "data.append(4)", 1],
259 [[1, 2, 3], "data.index(2)", "data.append(4)", 1],
228 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True],
260 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True],
229 ],
261 ],
230 )
262 )
231 def test_evaluates_calls(data, good, bad, expected):
263 def test_evaluates_calls(data, good, bad, expected):
232 context = limited(data=data)
264 context = limited(data=data)
233 assert guarded_eval(good, context) == expected
265 assert guarded_eval(good, context) == expected
234
266
235 with pytest.raises(GuardRejection):
267 with pytest.raises(GuardRejection):
236 guarded_eval(bad, context)
268 guarded_eval(bad, context)
237
269
238
270
239 @pytest.mark.parametrize(
271 @pytest.mark.parametrize(
240 "code,expected",
272 "code,expected",
241 [
273 [
242 ["(1\n+\n1)", 2],
274 ["(1\n+\n1)", 2],
243 ["list(range(10))[-1:]", [9]],
275 ["list(range(10))[-1:]", [9]],
244 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
276 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
245 ],
277 ],
246 )
278 )
247 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
279 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
248 def test_evaluates_complex_cases(code, expected, context):
280 def test_evaluates_complex_cases(code, expected, context):
249 assert guarded_eval(code, context()) == expected
281 assert guarded_eval(code, context()) == expected
250
282
251
283
252 @pytest.mark.parametrize(
284 @pytest.mark.parametrize(
253 "code,expected",
285 "code,expected",
254 [
286 [
255 ["1", 1],
287 ["1", 1],
256 ["1.0", 1.0],
288 ["1.0", 1.0],
257 ["0xdeedbeef", 0xDEEDBEEF],
289 ["0xdeedbeef", 0xDEEDBEEF],
258 ["True", True],
290 ["True", True],
259 ["None", None],
291 ["None", None],
260 ["{}", {}],
292 ["{}", {}],
261 ["[]", []],
293 ["[]", []],
262 ],
294 ],
263 )
295 )
264 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
296 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
265 def test_evaluates_literals(code, expected, context):
297 def test_evaluates_literals(code, expected, context):
266 assert guarded_eval(code, context()) == expected
298 assert guarded_eval(code, context()) == expected
267
299
268
300
269 @pytest.mark.parametrize(
301 @pytest.mark.parametrize(
270 "code,expected",
302 "code,expected",
271 [
303 [
272 ["-5", -5],
304 ["-5", -5],
273 ["+5", +5],
305 ["+5", +5],
274 ["~5", -6],
306 ["~5", -6],
275 ],
307 ],
276 )
308 )
277 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
309 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
278 def test_evaluates_unary_operations(code, expected, context):
310 def test_evaluates_unary_operations(code, expected, context):
279 assert guarded_eval(code, context()) == expected
311 assert guarded_eval(code, context()) == expected
280
312
281
313
282 @pytest.mark.parametrize(
314 @pytest.mark.parametrize(
283 "code,expected",
315 "code,expected",
284 [
316 [
285 ["1 + 1", 2],
317 ["1 + 1", 2],
286 ["3 - 1", 2],
318 ["3 - 1", 2],
287 ["2 * 3", 6],
319 ["2 * 3", 6],
288 ["5 // 2", 2],
320 ["5 // 2", 2],
289 ["5 / 2", 2.5],
321 ["5 / 2", 2.5],
290 ["5**2", 25],
322 ["5**2", 25],
291 ["2 >> 1", 1],
323 ["2 >> 1", 1],
292 ["2 << 1", 4],
324 ["2 << 1", 4],
293 ["1 | 2", 3],
325 ["1 | 2", 3],
294 ["1 & 1", 1],
326 ["1 & 1", 1],
295 ["1 & 2", 0],
327 ["1 & 2", 0],
296 ],
328 ],
297 )
329 )
298 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
330 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
299 def test_evaluates_binary_operations(code, expected, context):
331 def test_evaluates_binary_operations(code, expected, context):
300 assert guarded_eval(code, context()) == expected
332 assert guarded_eval(code, context()) == expected
301
333
302
334
303 @pytest.mark.parametrize(
335 @pytest.mark.parametrize(
304 "code,expected",
336 "code,expected",
305 [
337 [
306 ["2 > 1", True],
338 ["2 > 1", True],
307 ["2 < 1", False],
339 ["2 < 1", False],
308 ["2 <= 1", False],
340 ["2 <= 1", False],
309 ["2 <= 2", True],
341 ["2 <= 2", True],
310 ["1 >= 2", False],
342 ["1 >= 2", False],
311 ["2 >= 2", True],
343 ["2 >= 2", True],
312 ["2 == 2", True],
344 ["2 == 2", True],
313 ["1 == 2", False],
345 ["1 == 2", False],
314 ["1 != 2", True],
346 ["1 != 2", True],
315 ["1 != 1", False],
347 ["1 != 1", False],
316 ["1 < 4 < 3", False],
348 ["1 < 4 < 3", False],
317 ["(1 < 4) < 3", True],
349 ["(1 < 4) < 3", True],
318 ["4 > 3 > 2 > 1", True],
350 ["4 > 3 > 2 > 1", True],
319 ["4 > 3 > 2 > 9", False],
351 ["4 > 3 > 2 > 9", False],
320 ["1 < 2 < 3 < 4", True],
352 ["1 < 2 < 3 < 4", True],
321 ["9 < 2 < 3 < 4", False],
353 ["9 < 2 < 3 < 4", False],
322 ["1 < 2 > 1 > 0 > -1 < 1", True],
354 ["1 < 2 > 1 > 0 > -1 < 1", True],
323 ["1 in [1] in [[1]]", True],
355 ["1 in [1] in [[1]]", True],
324 ["1 in [1] in [[2]]", False],
356 ["1 in [1] in [[2]]", False],
325 ["1 in [1]", True],
357 ["1 in [1]", True],
326 ["0 in [1]", False],
358 ["0 in [1]", False],
327 ["1 not in [1]", False],
359 ["1 not in [1]", False],
328 ["0 not in [1]", True],
360 ["0 not in [1]", True],
329 ["True is True", True],
361 ["True is True", True],
330 ["False is False", True],
362 ["False is False", True],
331 ["True is False", False],
363 ["True is False", False],
332 ["True is not True", False],
364 ["True is not True", False],
333 ["False is not True", True],
365 ["False is not True", True],
334 ],
366 ],
335 )
367 )
336 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
337 def test_evaluates_comparisons(code, expected, context):
369 def test_evaluates_comparisons(code, expected, context):
338 assert guarded_eval(code, context()) == expected
370 assert guarded_eval(code, context()) == expected
339
371
340
372
341 def test_guards_comparisons():
373 def test_guards_comparisons():
342 class GoodEq(int):
374 class GoodEq(int):
343 pass
375 pass
344
376
345 class BadEq(int):
377 class BadEq(int):
346 def __eq__(self, other):
378 def __eq__(self, other):
347 assert False
379 assert False
348
380
349 context = limited(bad=BadEq(1), good=GoodEq(1))
381 context = limited(bad=BadEq(1), good=GoodEq(1))
350
382
351 with pytest.raises(GuardRejection):
383 with pytest.raises(GuardRejection):
352 guarded_eval("bad == 1", context)
384 guarded_eval("bad == 1", context)
353
385
354 with pytest.raises(GuardRejection):
386 with pytest.raises(GuardRejection):
355 guarded_eval("bad != 1", context)
387 guarded_eval("bad != 1", context)
356
388
357 with pytest.raises(GuardRejection):
389 with pytest.raises(GuardRejection):
358 guarded_eval("1 == bad", context)
390 guarded_eval("1 == bad", context)
359
391
360 with pytest.raises(GuardRejection):
392 with pytest.raises(GuardRejection):
361 guarded_eval("1 != bad", context)
393 guarded_eval("1 != bad", context)
362
394
363 assert guarded_eval("good == 1", context) is True
395 assert guarded_eval("good == 1", context) is True
364 assert guarded_eval("good != 1", context) is False
396 assert guarded_eval("good != 1", context) is False
365 assert guarded_eval("1 == good", context) is True
397 assert guarded_eval("1 == good", context) is True
366 assert guarded_eval("1 != good", context) is False
398 assert guarded_eval("1 != good", context) is False
367
399
368
400
369 def test_guards_unary_operations():
401 def test_guards_unary_operations():
370 class GoodOp(int):
402 class GoodOp(int):
371 pass
403 pass
372
404
373 class BadOpInv(int):
405 class BadOpInv(int):
374 def __inv__(self, other):
406 def __inv__(self, other):
375 assert False
407 assert False
376
408
377 class BadOpInverse(int):
409 class BadOpInverse(int):
378 def __inv__(self, other):
410 def __inv__(self, other):
379 assert False
411 assert False
380
412
381 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
413 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
382
414
383 with pytest.raises(GuardRejection):
415 with pytest.raises(GuardRejection):
384 guarded_eval("~bad1", context)
416 guarded_eval("~bad1", context)
385
417
386 with pytest.raises(GuardRejection):
418 with pytest.raises(GuardRejection):
387 guarded_eval("~bad2", context)
419 guarded_eval("~bad2", context)
388
420
389
421
390 def test_guards_binary_operations():
422 def test_guards_binary_operations():
391 class GoodOp(int):
423 class GoodOp(int):
392 pass
424 pass
393
425
394 class BadOp(int):
426 class BadOp(int):
395 def __add__(self, other):
427 def __add__(self, other):
396 assert False
428 assert False
397
429
398 context = limited(good=GoodOp(1), bad=BadOp(1))
430 context = limited(good=GoodOp(1), bad=BadOp(1))
399
431
400 with pytest.raises(GuardRejection):
432 with pytest.raises(GuardRejection):
401 guarded_eval("1 + bad", context)
433 guarded_eval("1 + bad", context)
402
434
403 with pytest.raises(GuardRejection):
435 with pytest.raises(GuardRejection):
404 guarded_eval("bad + 1", context)
436 guarded_eval("bad + 1", context)
405
437
406 assert guarded_eval("good + 1", context) == 2
438 assert guarded_eval("good + 1", context) == 2
407 assert guarded_eval("1 + good", context) == 2
439 assert guarded_eval("1 + good", context) == 2
408
440
409
441
410 def test_guards_attributes():
442 def test_guards_attributes():
411 class GoodAttr(float):
443 class GoodAttr(float):
412 pass
444 pass
413
445
414 class BadAttr1(float):
446 class BadAttr1(float):
415 def __getattr__(self, key):
447 def __getattr__(self, key):
416 assert False
448 assert False
417
449
418 class BadAttr2(float):
450 class BadAttr2(float):
419 def __getattribute__(self, key):
451 def __getattribute__(self, key):
420 assert False
452 assert False
421
453
422 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
454 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
423
455
424 with pytest.raises(GuardRejection):
456 with pytest.raises(GuardRejection):
425 guarded_eval("bad1.as_integer_ratio", context)
457 guarded_eval("bad1.as_integer_ratio", context)
426
458
427 with pytest.raises(GuardRejection):
459 with pytest.raises(GuardRejection):
428 guarded_eval("bad2.as_integer_ratio", context)
460 guarded_eval("bad2.as_integer_ratio", context)
429
461
430 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
462 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
431
463
432
464
433 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
465 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
434 def test_access_builtins(context):
466 def test_access_builtins(context):
435 assert guarded_eval("round", context()) == round
467 assert guarded_eval("round", context()) == round
436
468
437
469
438 def test_access_builtins_fails():
470 def test_access_builtins_fails():
439 context = limited()
471 context = limited()
440 with pytest.raises(NameError):
472 with pytest.raises(NameError):
441 guarded_eval("this_is_not_builtin", context)
473 guarded_eval("this_is_not_builtin", context)
442
474
443
475
444 def test_rejects_forbidden():
476 def test_rejects_forbidden():
445 context = forbidden()
477 context = forbidden()
446 with pytest.raises(GuardRejection):
478 with pytest.raises(GuardRejection):
447 guarded_eval("1", context)
479 guarded_eval("1", context)
448
480
449
481
450 def test_guards_locals_and_globals():
482 def test_guards_locals_and_globals():
451 context = EvaluationContext(
483 context = EvaluationContext(
452 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
484 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
453 )
485 )
454
486
455 with pytest.raises(GuardRejection):
487 with pytest.raises(GuardRejection):
456 guarded_eval("local_a", context)
488 guarded_eval("local_a", context)
457
489
458 with pytest.raises(GuardRejection):
490 with pytest.raises(GuardRejection):
459 guarded_eval("global_b", context)
491 guarded_eval("global_b", context)
460
492
461
493
462 def test_access_locals_and_globals():
494 def test_access_locals_and_globals():
463 context = EvaluationContext(
495 context = EvaluationContext(
464 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
496 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
465 )
497 )
466 assert guarded_eval("local_a", context) == "a"
498 assert guarded_eval("local_a", context) == "a"
467 assert guarded_eval("global_b", context) == "b"
499 assert guarded_eval("global_b", context) == "b"
468
500
469
501
470 @pytest.mark.parametrize(
502 @pytest.mark.parametrize(
471 "code",
503 "code",
472 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
504 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
473 )
505 )
474 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
506 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
475 def test_rejects_side_effect_syntax(code, context):
507 def test_rejects_side_effect_syntax(code, context):
476 with pytest.raises(SyntaxError):
508 with pytest.raises(SyntaxError):
477 guarded_eval(code, context)
509 guarded_eval(code, context)
478
510
479
511
480 def test_subscript():
512 def test_subscript():
481 context = EvaluationContext(
513 context = EvaluationContext(
482 locals={}, globals={}, evaluation="limited", in_subscript=True
514 locals={}, globals={}, evaluation="limited", in_subscript=True
483 )
515 )
484 empty_slice = slice(None, None, None)
516 empty_slice = slice(None, None, None)
485 assert guarded_eval("", context) == tuple()
517 assert guarded_eval("", context) == tuple()
486 assert guarded_eval(":", context) == empty_slice
518 assert guarded_eval(":", context) == empty_slice
487 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
519 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
488 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
520 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
489
521
490
522
491 def test_unbind_method():
523 def test_unbind_method():
492 class X(list):
524 class X(list):
493 def index(self, k):
525 def index(self, k):
494 return "CUSTOM"
526 return "CUSTOM"
495
527
496 x = X()
528 x = X()
497 assert _unbind_method(x.index) is X.index
529 assert _unbind_method(x.index) is X.index
498 assert _unbind_method([].index) is list.index
530 assert _unbind_method([].index) is list.index
531 assert _unbind_method(list.index) is None
499
532
500
533
501 def test_assumption_instance_attr_do_not_matter():
534 def test_assumption_instance_attr_do_not_matter():
502 """This is semi-specified in Python documentation.
535 """This is semi-specified in Python documentation.
503
536
504 However, since the specification says 'not guaranted
537 However, since the specification says 'not guaranted
505 to work' rather than 'is forbidden to work', future
538 to work' rather than 'is forbidden to work', future
506 versions could invalidate this assumptions. This test
539 versions could invalidate this assumptions. This test
507 is meant to catch such a change if it ever comes true.
540 is meant to catch such a change if it ever comes true.
508 """
541 """
509
542
510 class T:
543 class T:
511 def __getitem__(self, k):
544 def __getitem__(self, k):
512 return "a"
545 return "a"
513
546
514 def __getattr__(self, k):
547 def __getattr__(self, k):
515 return "a"
548 return "a"
516
549
517 def f(self):
550 def f(self):
518 return "b"
551 return "b"
519
552
520 t = T()
553 t = T()
521 t.__getitem__ = f
554 t.__getitem__ = f
522 t.__getattr__ = f
555 t.__getattr__ = f
523 assert t[1] == "a"
556 assert t[1] == "a"
524 assert t[1] == "a"
557 assert t[1] == "a"
525
558
526
559
527 def test_assumption_named_tuples_share_getitem():
560 def test_assumption_named_tuples_share_getitem():
528 """Check assumption on named tuples sharing __getitem__"""
561 """Check assumption on named tuples sharing __getitem__"""
529 from typing import NamedTuple
562 from typing import NamedTuple
530
563
531 class A(NamedTuple):
564 class A(NamedTuple):
532 pass
565 pass
533
566
534 class B(NamedTuple):
567 class B(NamedTuple):
535 pass
568 pass
536
569
537 assert A.__getitem__ == B.__getitem__
570 assert A.__getitem__ == B.__getitem__
General Comments 0
You need to be logged in to leave comments. Login now