##// END OF EJS Templates
Add version guards
krassowski -
Show More
@@ -1,816 +1,822 b''
1 from inspect import isclass, signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Callable,
3 Callable,
4 Dict,
4 Dict,
5 Literal,
5 Literal,
6 NamedTuple,
6 NamedTuple,
7 NewType,
7 NewType,
8 Set,
8 Set,
9 Sequence,
9 Sequence,
10 Tuple,
10 Tuple,
11 Type,
11 Type,
12 Protocol,
12 Protocol,
13 Union,
13 Union,
14 get_args,
14 get_args,
15 get_origin,
15 get_origin,
16 )
16 )
17 from typing_extensions import (
18 Self, # Python >=3.10
19 TypeAliasType, # Python >=3.12
20 )
21 import ast
17 import ast
22 import builtins
18 import builtins
23 import collections
19 import collections
24 import operator
20 import operator
25 import sys
21 import sys
26 from functools import cached_property
22 from functools import cached_property
27 from dataclasses import dataclass, field
23 from dataclasses import dataclass, field
28 from types import MethodDescriptorType, ModuleType
24 from types import MethodDescriptorType, ModuleType
29
25
30
31 from IPython.utils.decorators import undoc
26 from IPython.utils.decorators import undoc
32
27
33
28
29 if sys.version_info < (3, 11):
30 from typing_extensions import Self
31 else:
32 from typing import Self
33
34 if sys.version_info < (3, 12):
35 from typing_extensions import TypeAliasType
36 else:
37 from typing import TypeAliasType
38
39
@undoc
class HasGetItem(Protocol):
    """Structural type for objects that support subscript access."""

    def __getitem__(self, key) -> None:
        """Return the item stored under ``key`` (signature-only stub)."""
        ...
38
44
39
45
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables (e.g. classes) returning subscriptable objects."""

    def __call__(self, *args, **kwargs) -> HasGetItem:
        """Return an object implementing ``__getitem__`` (signature-only stub)."""
        ...
44
50
45
51
@undoc
class HasGetAttr(Protocol):
    """Structural type for objects that explicitly implement ``__getattr__``."""

    def __getattr__(self, key) -> None:
        """Return the attribute named ``key`` (signature-only stub)."""
        ...
50
56
51
57
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Structural type with no required members, i.e. no ``__getattr__`` is demanded."""

    pass
55
61
56
62
# By default `__getattr__` is not explicitly implemented on most objects,
# so the alias accepts both objects that define it and objects that do not.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
59
65
60
66
61 def _unbind_method(func: Callable) -> Union[Callable, None]:
67 def _unbind_method(func: Callable) -> Union[Callable, None]:
62 """Get unbound method for given bound method.
68 """Get unbound method for given bound method.
63
69
64 Returns None if cannot get unbound method, or method is already unbound.
70 Returns None if cannot get unbound method, or method is already unbound.
65 """
71 """
66 owner = getattr(func, "__self__", None)
72 owner = getattr(func, "__self__", None)
67 owner_class = type(owner)
73 owner_class = type(owner)
68 name = getattr(func, "__name__", None)
74 name = getattr(func, "__name__", None)
69 instance_dict_overrides = getattr(owner, "__dict__", None)
75 instance_dict_overrides = getattr(owner, "__dict__", None)
70 if (
76 if (
71 owner is not None
77 owner is not None
72 and name
78 and name
73 and (
79 and (
74 not instance_dict_overrides
80 not instance_dict_overrides
75 or (instance_dict_overrides and name not in instance_dict_overrides)
81 or (instance_dict_overrides and name not in instance_dict_overrides)
76 )
82 )
77 ):
83 ):
78 return getattr(owner_class, name)
84 return getattr(owner_class, name)
79 return None
85 return None
80
86
81
87
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``allow_*`` flag defaults to ``False``. The ``can_*`` methods
    translate the flags (and ``allowed_calls``) into a boolean decision.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    # Callables that may be invoked even when `allow_any_calls` is off.
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether subscript access is permitted (ignores the arguments)."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether attribute access is permitted (ignores the arguments)."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return whether operator evaluation is permitted.

        This base policy only honours ``allow_all_operations``; the operands
        and dunder names are not inspected.
        """
        if self.allow_all_operations:
            return True
        return False

    def can_call(self, func):
        """Return whether ``func`` may be called.

        A call is permitted when any call is allowed, when ``func`` itself is
        in ``allowed_calls``, or when the unbound counterpart of a bound
        method (as resolved by ``_unbind_method``) is in ``allowed_calls``.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)

        if owner_method and owner_method in self.allowed_calls:
            return True

        return False
117
123
118
124
119 def _get_external(module_name: str, access_path: Sequence[str]):
125 def _get_external(module_name: str, access_path: Sequence[str]):
120 """Get value from external module given a dotted access path.
126 """Get value from external module given a dotted access path.
121
127
122 Raises:
128 Raises:
123 * `KeyError` if module is removed not found, and
129 * `KeyError` if module is removed not found, and
124 * `AttributeError` if acess path does not match an exported object
130 * `AttributeError` if acess path does not match an exported object
125 """
131 """
126 member_type = sys.modules[module_name]
132 member_type = sys.modules[module_name]
127 for attr in access_path:
133 for attr in access_path:
128 member_type = getattr(member_type, attr)
134 member_type = getattr(member_type, attr)
129 return member_type
135 return member_type
130
136
131
137
132 def _has_original_dunder_external(
138 def _has_original_dunder_external(
133 value,
139 value,
134 module_name: str,
140 module_name: str,
135 access_path: Sequence[str],
141 access_path: Sequence[str],
136 method_name: str,
142 method_name: str,
137 ):
143 ):
138 if module_name not in sys.modules:
144 if module_name not in sys.modules:
139 # LBYLB as it is faster
145 # LBYLB as it is faster
140 return False
146 return False
141 try:
147 try:
142 member_type = _get_external(module_name, access_path)
148 member_type = _get_external(module_name, access_path)
143 value_type = type(value)
149 value_type = type(value)
144 if type(value) == member_type:
150 if type(value) == member_type:
145 return True
151 return True
146 if method_name == "__getattribute__":
152 if method_name == "__getattribute__":
147 # we have to short-circuit here due to an unresolved issue in
153 # we have to short-circuit here due to an unresolved issue in
148 # `isinstance` implementation: https://bugs.python.org/issue32683
154 # `isinstance` implementation: https://bugs.python.org/issue32683
149 return False
155 return False
150 if isinstance(value, member_type):
156 if isinstance(value, member_type):
151 method = getattr(value_type, method_name, None)
157 method = getattr(value_type, method_name, None)
152 member_method = getattr(member_type, method_name, None)
158 member_method = getattr(member_type, method_name, None)
153 if member_method == method:
159 if member_method == method:
154 return True
160 return True
155 except (AttributeError, KeyError):
161 except (AttributeError, KeyError):
156 return False
162 return False
157
163
158
164
159 def _has_original_dunder(
165 def _has_original_dunder(
160 value, allowed_types, allowed_methods, allowed_external, method_name
166 value, allowed_types, allowed_methods, allowed_external, method_name
161 ):
167 ):
162 # note: Python ignores `__getattr__`/`__getitem__` on instances,
168 # note: Python ignores `__getattr__`/`__getitem__` on instances,
163 # we only need to check at class level
169 # we only need to check at class level
164 value_type = type(value)
170 value_type = type(value)
165
171
166 # strict type check passes β†’ no need to check method
172 # strict type check passes β†’ no need to check method
167 if value_type in allowed_types:
173 if value_type in allowed_types:
168 return True
174 return True
169
175
170 method = getattr(value_type, method_name, None)
176 method = getattr(value_type, method_name, None)
171
177
172 if method is None:
178 if method is None:
173 return None
179 return None
174
180
175 if method in allowed_methods:
181 if method in allowed_methods:
176 return True
182 return True
177
183
178 for module_name, *access_path in allowed_external:
184 for module_name, *access_path in allowed_external:
179 if _has_original_dunder_external(value, module_name, access_path, method_name):
185 if _has_original_dunder_external(value, module_name, access_path, method_name):
180 return True
186 return True
181
187
182 return False
188 return False
183
189
184
190
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy that allows access only through allow-listed, unmodified dunders.

    For each kind of access there is a set of allow-listed types and a set of
    "external" types given as ``(module_name, *access_path)`` tuples, which
    are resolved lazily from ``sys.modules``.

    The derived method sets (``_getitem_methods`` etc.) are cached on first
    use, so later mutation of the ``allowed_*`` sets is not reflected in them.
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # Per-dunder cache of the methods collected from `allowed_operations`.
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Return whether reading ``attr`` from ``value`` is permitted.

        Requires an allow-listed ``__getattribute__`` and either no
        ``__getattr__`` at all or an allow-listed one. If ``attr`` is a
        property on the class, it is additionally required that the type is
        allow-listed or that the property is the very one defined on some
        allow-listed external class.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external entry is consulted: the set has no defined order and
        # some of the listed modules may simply not be imported.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue  # pragma: no cover
            if class_attr_val == external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return whether every dunder in ``dunders`` is allow-listed on each operand.

        ``b`` is only checked when it is not ``None``.
        """
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=self.allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return (and cache) the ``dunder`` methods of ``allowed_operations`` types."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        """``__getitem__`` implementations of the ``allowed_getitem`` types."""
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        """``__getattr__`` implementations of the ``allowed_getattr`` types."""
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        """``__getattribute__`` implementations of the ``allowed_getattr`` types."""
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect the truthy ``name`` attributes of ``classes``, skipping missing ones."""
        candidates = (getattr(class_, name, None) for class_ in classes)
        return {method for method in candidates if method}
309
315
310
316
class _DummyNamedTuple(NamedTuple):
    """Used internally to retrieve methods of named tuple instance.

    The class declares no fields of its own.
    """
313
319
314
320
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which a piece of code is evaluated."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: Literal[
        "forbidden", "minimal", "limited", "unsafe", "dangerous"
    ] = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
327
333
328
334
329 class _IdentitySubscript:
335 class _IdentitySubscript:
330 """Returns the key itself when item is requested via subscript."""
336 """Returns the key itself when item is requested via subscript."""
331
337
332 def __getitem__(self, key):
338 def __getitem__(self, key):
333 return key
339 return key
334
340
335
341
# Shared instance injected into the local namespace when evaluating code
# that is the inside of a subscript.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which `IDENTITY_SUBSCRIPT` is exposed to the evaluated code.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Empty signature; presumably a placeholder for callables whose signature
# could not be determined — usage is outside this excerpt, TODO confirm.
UNKNOWN_SIGNATURE = Signature()
# Unique sentinel object; presumably marks values that were not evaluated —
# usage is outside this excerpt, TODO confirm.
NOT_EVALUATED = object()
340
346
341
347
class GuardRejection(Exception):
    """Exception raised when guard rejects evaluation attempt.

    Signals that the active evaluation policy refused to evaluate the code.
    """
346
352
347
353
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    The behaviour depends on ``context.evaluation``:

    - ``forbidden``: nothing is evaluated and :class:`GuardRejection` is raised;
    - ``dangerous``: the code is passed to the standard :func:`eval`;
    - any other policy: the code is parsed and handed to :func:`eval_node`.

    When ``context.in_subscript`` is set, ``code`` is treated as the inside
    of a subscript; empty code then evaluates to an empty tuple.
    """
    mode = context.evaluation

    if mode == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # syntactic sugar for slices (`:`) is only available in subscripts
        # so we need to trick the ast parser into thinking that we have
        # a subscript, but we need to be able to later recognise that we did
        # it so we can ignore the actual __getitem__ operation
        if not code:
            return tuple()
        # work on a copy so the caller's namespace is left untouched
        patched_locals = context.locals.copy()
        patched_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = SUBSCRIPT_MARKER + "[" + code + "]"
        context = EvaluationContext(**dict(context._asdict(), locals=patched_locals))

    if mode == "dangerous":
        # unrestricted by design: this mode is documented as plain `eval`
        return eval(code, context.globals, context.locals)

    return eval_node(ast.parse(code, mode="eval"), context)
382
388
383
389
# Binary operator AST node type → names of the dunder methods implementing it.
BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

# Comparison operator AST node type → dunder names. The first name is the
# primary one; where a second name is listed it is the counterpart method
# (`__eq__` for `__ne__`, the mirrored ordering method otherwise).
COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

# Unary operator AST node type → dunder names.
UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    # NOTE(review): `__not__` exists on the `operator` module but is not a
    # special method looked up on objects — confirm how consumers resolve it.
    ast.Not: ("__not__",),
}
418
424
419
425
class Duck:
    """A dummy class used to create objects of other classes without calling their ``__init__``

    The class itself defines no attributes or methods.
    """
422
428
423
429
424 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
430 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 dunder = None
431 dunder = None
426 for op, candidate_dunder in dunders.items():
432 for op, candidate_dunder in dunders.items():
427 if isinstance(node_op, op):
433 if isinstance(node_op, op):
428 dunder = candidate_dunder
434 dunder = candidate_dunder
429 return dunder
435 return dunder
430
436
431
437
432 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
438 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 """Evaluate AST node in provided context.
439 """Evaluate AST node in provided context.
434
440
435 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
441 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436
442
437 Does not evaluate actions that always have side effects:
443 Does not evaluate actions that always have side effects:
438
444
439 - class definitions (``class sth: ...``)
445 - class definitions (``class sth: ...``)
440 - function definitions (``def sth: ...``)
446 - function definitions (``def sth: ...``)
441 - variable assignments (``x = 1``)
447 - variable assignments (``x = 1``)
442 - augmented assignments (``x += 1``)
448 - augmented assignments (``x += 1``)
443 - deletions (``del x``)
449 - deletions (``del x``)
444
450
445 Does not evaluate operations which do not return values:
451 Does not evaluate operations which do not return values:
446
452
447 - assertions (``assert x``)
453 - assertions (``assert x``)
448 - pass (``pass``)
454 - pass (``pass``)
449 - imports (``import x``)
455 - imports (``import x``)
450 - control flow:
456 - control flow:
451
457
452 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
458 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 - loops (``for`` and ``while``)
459 - loops (``for`` and ``while``)
454 - exception handling
460 - exception handling
455
461
456 The purpose of this function is to guard against unwanted side-effects;
462 The purpose of this function is to guard against unwanted side-effects;
457 it does not give guarantees on protection from malicious code execution.
463 it does not give guarantees on protection from malicious code execution.
458 """
464 """
459 policy = EVALUATION_POLICIES[context.evaluation]
465 policy = EVALUATION_POLICIES[context.evaluation]
460 if node is None:
466 if node is None:
461 return None
467 return None
462 if isinstance(node, ast.Expression):
468 if isinstance(node, ast.Expression):
463 return eval_node(node.body, context)
469 return eval_node(node.body, context)
464 if isinstance(node, ast.BinOp):
470 if isinstance(node, ast.BinOp):
465 left = eval_node(node.left, context)
471 left = eval_node(node.left, context)
466 right = eval_node(node.right, context)
472 right = eval_node(node.right, context)
467 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
473 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 if dunders:
474 if dunders:
469 if policy.can_operate(dunders, left, right):
475 if policy.can_operate(dunders, left, right):
470 return getattr(left, dunders[0])(right)
476 return getattr(left, dunders[0])(right)
471 else:
477 else:
472 raise GuardRejection(
478 raise GuardRejection(
473 f"Operation (`{dunders}`) for",
479 f"Operation (`{dunders}`) for",
474 type(left),
480 type(left),
475 f"not allowed in {context.evaluation} mode",
481 f"not allowed in {context.evaluation} mode",
476 )
482 )
477 if isinstance(node, ast.Compare):
483 if isinstance(node, ast.Compare):
478 left = eval_node(node.left, context)
484 left = eval_node(node.left, context)
479 all_true = True
485 all_true = True
480 negate = False
486 negate = False
481 for op, right in zip(node.ops, node.comparators):
487 for op, right in zip(node.ops, node.comparators):
482 right = eval_node(right, context)
488 right = eval_node(right, context)
483 dunder = None
489 dunder = None
484 dunders = _find_dunder(op, COMP_OP_DUNDERS)
490 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 if not dunders:
491 if not dunders:
486 if isinstance(op, ast.NotIn):
492 if isinstance(op, ast.NotIn):
487 dunders = COMP_OP_DUNDERS[ast.In]
493 dunders = COMP_OP_DUNDERS[ast.In]
488 negate = True
494 negate = True
489 if isinstance(op, ast.Is):
495 if isinstance(op, ast.Is):
490 dunder = "is_"
496 dunder = "is_"
491 if isinstance(op, ast.IsNot):
497 if isinstance(op, ast.IsNot):
492 dunder = "is_"
498 dunder = "is_"
493 negate = True
499 negate = True
494 if not dunder and dunders:
500 if not dunder and dunders:
495 dunder = dunders[0]
501 dunder = dunders[0]
496 if dunder:
502 if dunder:
497 a, b = (right, left) if dunder == "__contains__" else (left, right)
503 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
504 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 result = getattr(operator, dunder)(a, b)
505 result = getattr(operator, dunder)(a, b)
500 if negate:
506 if negate:
501 result = not result
507 result = not result
502 if not result:
508 if not result:
503 all_true = False
509 all_true = False
504 left = right
510 left = right
505 else:
511 else:
506 raise GuardRejection(
512 raise GuardRejection(
507 f"Comparison (`{dunder}`) for",
513 f"Comparison (`{dunder}`) for",
508 type(left),
514 type(left),
509 f"not allowed in {context.evaluation} mode",
515 f"not allowed in {context.evaluation} mode",
510 )
516 )
511 else:
517 else:
512 raise ValueError(
518 raise ValueError(
513 f"Comparison `{dunder}` not supported"
519 f"Comparison `{dunder}` not supported"
514 ) # pragma: no cover
520 ) # pragma: no cover
515 return all_true
521 return all_true
516 if isinstance(node, ast.Constant):
522 if isinstance(node, ast.Constant):
517 return node.value
523 return node.value
518 if isinstance(node, ast.Tuple):
524 if isinstance(node, ast.Tuple):
519 return tuple(eval_node(e, context) for e in node.elts)
525 return tuple(eval_node(e, context) for e in node.elts)
520 if isinstance(node, ast.List):
526 if isinstance(node, ast.List):
521 return [eval_node(e, context) for e in node.elts]
527 return [eval_node(e, context) for e in node.elts]
522 if isinstance(node, ast.Set):
528 if isinstance(node, ast.Set):
523 return {eval_node(e, context) for e in node.elts}
529 return {eval_node(e, context) for e in node.elts}
524 if isinstance(node, ast.Dict):
530 if isinstance(node, ast.Dict):
525 return dict(
531 return dict(
526 zip(
532 zip(
527 [eval_node(k, context) for k in node.keys],
533 [eval_node(k, context) for k in node.keys],
528 [eval_node(v, context) for v in node.values],
534 [eval_node(v, context) for v in node.values],
529 )
535 )
530 )
536 )
531 if isinstance(node, ast.Slice):
537 if isinstance(node, ast.Slice):
532 return slice(
538 return slice(
533 eval_node(node.lower, context),
539 eval_node(node.lower, context),
534 eval_node(node.upper, context),
540 eval_node(node.upper, context),
535 eval_node(node.step, context),
541 eval_node(node.step, context),
536 )
542 )
537 if isinstance(node, ast.UnaryOp):
543 if isinstance(node, ast.UnaryOp):
538 value = eval_node(node.operand, context)
544 value = eval_node(node.operand, context)
539 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
545 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 if dunders:
546 if dunders:
541 if policy.can_operate(dunders, value):
547 if policy.can_operate(dunders, value):
542 return getattr(value, dunders[0])()
548 return getattr(value, dunders[0])()
543 else:
549 else:
544 raise GuardRejection(
550 raise GuardRejection(
545 f"Operation (`{dunders}`) for",
551 f"Operation (`{dunders}`) for",
546 type(value),
552 type(value),
547 f"not allowed in {context.evaluation} mode",
553 f"not allowed in {context.evaluation} mode",
548 )
554 )
549 if isinstance(node, ast.Subscript):
555 if isinstance(node, ast.Subscript):
550 value = eval_node(node.value, context)
556 value = eval_node(node.value, context)
551 slice_ = eval_node(node.slice, context)
557 slice_ = eval_node(node.slice, context)
552 if policy.can_get_item(value, slice_):
558 if policy.can_get_item(value, slice_):
553 return value[slice_]
559 return value[slice_]
554 raise GuardRejection(
560 raise GuardRejection(
555 "Subscript access (`__getitem__`) for",
561 "Subscript access (`__getitem__`) for",
556 type(value), # not joined to avoid calling `repr`
562 type(value), # not joined to avoid calling `repr`
557 f" not allowed in {context.evaluation} mode",
563 f" not allowed in {context.evaluation} mode",
558 )
564 )
559 if isinstance(node, ast.Name):
565 if isinstance(node, ast.Name):
560 return _eval_node_name(node.id, context)
566 return _eval_node_name(node.id, context)
561 if isinstance(node, ast.Attribute):
567 if isinstance(node, ast.Attribute):
562 value = eval_node(node.value, context)
568 value = eval_node(node.value, context)
563 if policy.can_get_attr(value, node.attr):
569 if policy.can_get_attr(value, node.attr):
564 return getattr(value, node.attr)
570 return getattr(value, node.attr)
565 raise GuardRejection(
571 raise GuardRejection(
566 "Attribute access (`__getattr__`) for",
572 "Attribute access (`__getattr__`) for",
567 type(value), # not joined to avoid calling `repr`
573 type(value), # not joined to avoid calling `repr`
568 f"not allowed in {context.evaluation} mode",
574 f"not allowed in {context.evaluation} mode",
569 )
575 )
570 if isinstance(node, ast.IfExp):
576 if isinstance(node, ast.IfExp):
571 test = eval_node(node.test, context)
577 test = eval_node(node.test, context)
572 if test:
578 if test:
573 return eval_node(node.body, context)
579 return eval_node(node.body, context)
574 else:
580 else:
575 return eval_node(node.orelse, context)
581 return eval_node(node.orelse, context)
576 if isinstance(node, ast.Call):
582 if isinstance(node, ast.Call):
577 func = eval_node(node.func, context)
583 func = eval_node(node.func, context)
578 if policy.can_call(func) and not node.keywords:
584 if policy.can_call(func) and not node.keywords:
579 args = [eval_node(arg, context) for arg in node.args]
585 args = [eval_node(arg, context) for arg in node.args]
580 return func(*args)
586 return func(*args)
581 if isclass(func):
587 if isclass(func):
582 # this code path gets entered when calling class e.g. `MyClass()`
588 # this code path gets entered when calling class e.g. `MyClass()`
583 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
589 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
584 # Should return `MyClass` if `__new__` is not overridden,
590 # Should return `MyClass` if `__new__` is not overridden,
585 # otherwise whatever `__new__` return type is.
591 # otherwise whatever `__new__` return type is.
586 overridden_return_type = _eval_return_type(func.__new__, node, context)
592 overridden_return_type = _eval_return_type(func.__new__, node, context)
587 if overridden_return_type is not NOT_EVALUATED:
593 if overridden_return_type is not NOT_EVALUATED:
588 return overridden_return_type
594 return overridden_return_type
589 return _create_duck_for_heap_type(func)
595 return _create_duck_for_heap_type(func)
590 else:
596 else:
591 return_type = _eval_return_type(func, node, context)
597 return_type = _eval_return_type(func, node, context)
592 if return_type is not NOT_EVALUATED:
598 if return_type is not NOT_EVALUATED:
593 return return_type
599 return return_type
594 raise GuardRejection(
600 raise GuardRejection(
595 "Call for",
601 "Call for",
596 func, # not joined to avoid calling `repr`
602 func, # not joined to avoid calling `repr`
597 f"not allowed in {context.evaluation} mode",
603 f"not allowed in {context.evaluation} mode",
598 )
604 )
599 raise ValueError("Unhandled node", ast.dump(node))
605 raise ValueError("Unhandled node", ast.dump(node))
600
606
601
607
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
    """Infer the result of calling ``func`` from its return annotation.

    ``func`` itself is never called here; only its signature is inspected.

    Returns ``func.__self__`` for a ``Self`` annotation on an object that
    has ``__self__``, the literal value for a single-argument ``Literal``,
    whatever `_eval_or_create_duck` produces for any other annotation
    (``NewType`` and ``TypeAliasType`` are unwrapped first), or the
    NOT_EVALUATED sentinel when nothing could be inferred.
    """
    try:
        sig = signature(func)
    except ValueError:
        sig = UNKNOWN_SIGNATURE

    annotation = sig.return_annotation
    if annotation is Signature.empty:
        return NOT_EVALUATED

    # An annotation that is still a string was not resolved by the
    # `signature` call, so it is looked up by name in the context.
    if isinstance(annotation, str):
        return_type = _eval_node_name(annotation, context)
    else:
        return_type = annotation

    if return_type is Self and hasattr(func, "__self__"):
        return func.__self__

    if get_origin(return_type) is Literal:
        literal_values = get_args(return_type)
        if len(literal_values) == 1:
            return literal_values[0]
        # more than one possible literal value: nothing to pick from
        return NOT_EVALUATED

    if isinstance(return_type, NewType):
        unwrapped = return_type.__supertype__
    elif isinstance(return_type, TypeAliasType):
        unwrapped = return_type.__value__
    else:
        unwrapped = return_type
    return _eval_or_create_duck(unwrapped, node, context)
634
640
635
641
def _eval_node_name(node_id: str, context: EvaluationContext):
    """Look up ``node_id`` in the namespaces the active policy allows.

    The search order is locals, globals, then builtins; each namespace is
    consulted only when the policy for ``context.evaluation`` permits it.

    Raises `GuardRejection` if the name was not found and the policy allows
    neither locals nor globals access, and `NameError` otherwise.
    """
    policy = EVALUATION_POLICIES[context.evaluation]

    if policy.allow_locals_access and node_id in context.locals:
        return context.locals[node_id]

    if policy.allow_globals_access and node_id in context.globals:
        return context.globals[node_id]

    # note: do not use __builtins__, it is implementation detail of cPython
    if policy.allow_builtins_access and hasattr(builtins, node_id):
        return getattr(builtins, node_id)

    if policy.allow_globals_access or policy.allow_locals_access:
        raise NameError(f"{node_id} not found in locals, globals, nor builtins")

    raise GuardRejection(
        f"Namespace access not allowed in {context.evaluation} mode"
    )
651
657
652
658
def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
    """Instantiate ``duck_type`` if the policy allows calling it, else mock it.

    When the policy permits the call and the call node has no keyword
    arguments, the positional arguments of ``node`` are evaluated and passed
    to ``duck_type``. In every other case the result of
    `_create_duck_for_heap_type` is returned.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    may_instantiate = policy.can_call(duck_type) and not node.keywords
    if not may_instantiate:
        # custom class in the type annotation: imitate it instead
        return _create_duck_for_heap_type(duck_type)
    # allow-listed callable in the type annotation: really call it
    call_args = [eval_node(arg, context) for arg in node.args]
    return duck_type(*call_args)
661
667
662
668
def _create_duck_for_heap_type(duck_type):
    """Build a stand-in object (a duck) that reports ``duck_type`` as its class.

    Returns the duck, or the NOT_EVALUATED sentinel when assigning
    ``__class__`` raised `TypeError`.
    """
    imitation = Duck()
    try:
        # this only works for heap types, not builtins
        imitation.__class__ = duck_type
    except TypeError:
        return NOT_EVALUATED
    return imitation
676
682
677
683
# Each entry is a tuple of a module name followed by the attribute path of a
# type inside that module; these are passed to the "limited" policy as
# `allowed_getitem_external`.
SUPPORTED_EXTERNAL_GETITEM = {
    # pandas
    ("pandas", "DataFrame"),
    ("pandas", "Series"),
    ("pandas", "core", "indexing", "_LocIndexer"),
    ("pandas", "core", "indexing", "_iLocIndexer"),
    # numpy
    ("numpy", "ndarray"),
    ("numpy", "void"),
}
686
692
687
693
# Types passed to the "limited" policy as `allowed_getitem`.
BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
    # built-in containers
    dict,
    list,
    tuple,
    str,  # type: ignore[arg-type]
    bytes,  # type: ignore[arg-type]
    # `collections` containers
    collections.ChainMap,
    collections.OrderedDict,
    collections.UserDict,
    collections.UserList,
    collections.UserString,  # type: ignore[arg-type]
    collections.defaultdict,
    collections.deque,
    # helper types defined in this module
    _DummyNamedTuple,
    _IdentitySubscript,
}
704
710
705
711
706 def _list_methods(cls, source=None):
712 def _list_methods(cls, source=None):
707 """For use on immutable objects or with methods returning a copy"""
713 """For use on immutable objects or with methods returning a copy"""
708 return [getattr(cls, k) for k in (source if source else dir(cls))]
714 return [getattr(cls, k) for k in (source if source else dir(cls))]
709
715
710
716
# Names of methods selected from dict-like and list-like types below.
dict_non_mutating_methods = ("copy", "keys", "values", "items")
list_non_mutating_methods = ("copy", "index", "count")
# Every attribute name that `set` shares with `frozenset`.
set_non_mutating_methods = set(dir(set)).intersection(dir(frozenset))


# The view type returned by `dict.keys()`.
dict_keys: Type[collections.abc.KeysView] = type(dict().keys())

NUMERICS = {int, float, complex}
719
725
# Callables passed to the "limited" policy as `allowed_calls`.
ALLOWED_CALLS = {
    # bytes and str: every attribute listed by `dir`
    bytes,
    *_list_methods(bytes),
    str,
    *_list_methods(str),
    # dict: selected methods only
    dict,
    *_list_methods(dict, dict_non_mutating_methods),
    dict_keys.isdisjoint,
    # list: selected methods only
    list,
    *_list_methods(list, list_non_mutating_methods),
    # set: only what it shares with frozenset
    set,
    *_list_methods(set, set_non_mutating_methods),
    frozenset,
    *_list_methods(frozenset),
    range,
    tuple,
    *_list_methods(tuple),
    # numbers
    *NUMERICS,
    *(
        method
        for numeric_cls in NUMERICS
        for method in _list_methods(numeric_cls)
    ),
    # `collections` containers
    collections.deque,
    *_list_methods(collections.deque, list_non_mutating_methods),
    collections.defaultdict,
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.OrderedDict,
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    collections.UserDict,
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    collections.UserList,
    *_list_methods(collections.UserList, list_non_mutating_methods),
    collections.UserString,
    *_list_methods(collections.UserString, dir(str)),
    collections.Counter,
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    collections.Counter.elements,
    collections.Counter.most_common,
}
756
762
# Types passed to the "limited" policy as `allowed_getattr`: everything in
# BUILTIN_GETITEM plus the types listed here.
BUILTIN_GETATTR: Set[MayHaveGetattr] = {
    *BUILTIN_GETITEM,
    *NUMERICS,
    set,
    frozenset,
    dict_keys,
    object,
    type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
    MethodDescriptorType,
    ModuleType,
}


# Passed to the "limited" policy as `allowed_operations`; a separate set with
# the same members as BUILTIN_GETATTR.
BUILTIN_OPERATIONS = set(BUILTIN_GETATTR)
771
777
# Maps the name of an evaluation mode to the policy object enforcing it.
EVALUATION_POLICIES = {
    # only builtins are reachable; no item/attribute access, calls or operations
    "minimal": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=False,
        allow_globals_access=False,
        allow_attr_access=False,
        allow_item_access=False,
        allow_any_calls=False,
        allowed_calls=set(),
        allow_all_operations=False,
    ),
    # all namespaces are reachable; access is restricted to allow-listed types
    "limited": SelectivePolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr=BUILTIN_GETATTR,
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ("pandas", "DataFrame"),
            ("pandas", "Series"),
        },
        allowed_operations=BUILTIN_OPERATIONS,
        allowed_calls=ALLOWED_CALLS,
    ),
    # every switch is turned on
    "unsafe": EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_attr_access=True,
        allow_item_access=True,
        allow_any_calls=True,
        allow_all_operations=True,
    ),
}
808
814
809
815
# Names exported by `from ... import *`.
__all__ = [
    "guarded_eval",
    "eval_node",
    "GuardRejection",
    "EvaluationContext",
    # private helper, listed here although underscore-prefixed
    "_unbind_method",
]
@@ -1,681 +1,689 b''
1 import sys
1 from contextlib import contextmanager
2 from contextlib import contextmanager
2 from typing import NamedTuple, Literal, NewType
3 from typing import NamedTuple, Literal, NewType
3 from functools import partial
4 from functools import partial
4 from IPython.core.guarded_eval import (
5 from IPython.core.guarded_eval import (
5 EvaluationContext,
6 EvaluationContext,
6 GuardRejection,
7 GuardRejection,
7 guarded_eval,
8 guarded_eval,
8 _unbind_method,
9 _unbind_method,
9 )
10 )
10 from typing_extensions import (
11 Self, # Python >=3.10
12 TypeAliasType, # Python >=3.12
13 )
14 from IPython.testing import decorators as dec
11 from IPython.testing import decorators as dec
15 import pytest
12 import pytest
16
13
17
14
15 if sys.version_info < (3, 11):
16 from typing_extensions import Self
17 else:
18 from typing import Self
19
20 if sys.version_info < (3, 12):
21 from typing_extensions import TypeAliasType
22 else:
23 from typing import TypeAliasType
24
25
def create_context(evaluation: str, **kwargs):
    """Build an `EvaluationContext` for the given evaluation mode.

    The keyword arguments become the context's locals; globals are empty.
    """
    return EvaluationContext(evaluation=evaluation, locals=kwargs, globals={})
20
28
21
29
# One context factory per evaluation mode; call as e.g. `limited(data=...)`.
forbidden = partial(create_context, "forbidden")
minimal = partial(create_context, "minimal")
limited = partial(create_context, "limited")
unsafe = partial(create_context, "unsafe")
dangerous = partial(create_context, "dangerous")

# Factory groups for tests that apply to several modes.
LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
MINIMAL_OR_HIGHER = [minimal] + LIMITED_OR_HIGHER
30
38
31
39
@contextmanager
def module_not_installed(module: str):
    """Temporarily remove ``module`` from `sys.modules`.

    On exit the entry that was removed (if any) is put back. If there was no
    entry on entry to the block, `sys.modules` is left untouched on exit:
    writing ``None`` there would make every later ``import module`` fail
    with `ModuleNotFoundError` for the rest of the process.

    Parameters
    ----------
    module : str
        Key to remove from `sys.modules`.
    """
    import sys

    # Sentinel so that "no entry" is not confused with an entry that is
    # legitimately `None` (the marker used to block an import).
    absent = object()
    to_restore = sys.modules.pop(module, absent)
    try:
        yield
    finally:
        if to_restore is not absent:
            sys.modules[module] = to_restore
45
53
46
54
def test_external_not_installed():
    """
    Attribute access must be rejected when the external module is absent.

    Because attribute check requires checking if object is not of allowed
    external type, this tests logic for absence of external module.
    """

    class WithCustomGetattr:
        def __init__(self):
            self.test = 1

        def __getattr__(self, key):
            return key

    with module_not_installed("pandas"):
        ctx = limited(x=WithCustomGetattr())
        with pytest.raises(GuardRejection):
            guarded_eval("x.test", ctx)
64
72
65
73
@dec.skip_without("pandas")
def test_external_changed_api(monkeypatch):
    """Check that the execution rejects if external API changed paths"""
    import pandas as pd

    # create the instance before `pd.Series` is removed below
    series = pd.Series([1], index=["a"])

    with monkeypatch.context() as patcher:
        patcher.delattr(pd, "Series")
        ctx = limited(data=series)
        with pytest.raises(GuardRejection):
            guarded_eval("data.iloc[0]", ctx)
78
86
79
87
@dec.skip_without("pandas")
def test_pandas_series_iloc():
    """`iloc` indexing of a pandas Series evaluates in limited mode."""
    import pandas as pd

    ctx = limited(data=pd.Series([1], index=["a"]))
    result = guarded_eval("data.iloc[0]", ctx)
    assert result == 1
87
95
88
96
def test_rejects_custom_properties():
    """A user-defined `iloc` property is rejected in limited mode."""

    class BadProperty:
        @property
        def iloc(self):
            return [None]

    ctx = limited(data=BadProperty())

    with pytest.raises(GuardRejection):
        guarded_eval("data.iloc[0]", ctx)
100
108
101
109
@dec.skip_without("pandas")
def test_accepts_non_overriden_properties():
    """A Series subclass that overrides nothing still evaluates `iloc`."""
    import pandas as pd

    class GoodProperty(pd.Series):
        pass

    ctx = limited(data=GoodProperty([1], index=["a"]))

    result = guarded_eval("data.iloc[0]", ctx)
    assert result == 1
113
121
114
122
@dec.skip_without("pandas")
def test_pandas_series():
    """Item access on a Series works; a missing key raises `KeyError`."""
    import pandas as pd

    series = pd.Series([1], index=["a"])
    ctx = limited(data=series)

    assert guarded_eval('data["a"]', ctx) == 1
    with pytest.raises(KeyError):
        guarded_eval('data["c"]', ctx)
123
131
124
132
@dec.skip_without("pandas")
def test_pandas_bad_series():
    """Series subclasses overriding `__getitem__`/`__getattr__` are guarded."""
    import pandas as pd

    class BadItemSeries(pd.Series):
        def __getitem__(self, key):
            return "CUSTOM_ITEM"

    class BadAttrSeries(pd.Series):
        def __getattr__(self, key):
            return "CUSTOM_ATTR"

    # --- overridden `__getitem__` ---
    item_overriding = BadItemSeries([1], index=["a"])
    ctx = limited(data=item_overriding)

    for code in ('data["a"]', 'data["c"]'):
        with pytest.raises(GuardRejection):
            guarded_eval(code, ctx)

    # note: here result is a bit unexpected because
    # pandas `__getattr__` calls `__getitem__`;
    # FIXME - special case to handle it?
    assert guarded_eval("data.a", ctx) == "CUSTOM_ITEM"

    # unsafe mode runs the overridden method
    ctx = unsafe(data=item_overriding)
    assert guarded_eval('data["a"]', ctx) == "CUSTOM_ITEM"

    # --- overridden `__getattr__` ---
    attr_overriding = BadAttrSeries([1], index=["a"])
    ctx = limited(data=attr_overriding)
    assert guarded_eval('data["a"]', ctx) == 1
    with pytest.raises(GuardRejection):
        guarded_eval("data.a", ctx)
158
166
159
167
@dec.skip_without("pandas")
def test_pandas_dataframe_loc():
    """`loc` indexing of a DataFrame returns the selected column."""
    import pandas as pd
    from pandas.testing import assert_series_equal

    frame = pd.DataFrame([{"a": 1}])
    ctx = limited(data=frame)

    selected = guarded_eval('data.loc[:, "a"]', ctx)
    assert_series_equal(selected, frame["a"])
168
176
169
177
def test_named_tuple():
    """Indexing a named tuple is allowed unless `__getitem__` is overridden."""

    class GoodNamedTuple(NamedTuple):
        a: str

    class BadNamedTuple(NamedTuple):
        a: str

        def __getitem__(self, key):
            return None

    good = GoodNamedTuple(a="x")
    bad = BadNamedTuple(a="x")

    assert guarded_eval("data[0]", limited(data=good)) == "x"

    rejecting_ctx = limited(data=bad)
    with pytest.raises(GuardRejection):
        guarded_eval("data[0]", rejecting_ctx)
190
198
191
199
def test_dict():
    """Item access (including nested and tuple keys) works on a dict."""
    ctx = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})

    expectations = (
        ('data["a"]', 1),
        ('data["b"]', {"x": 2}),
        ('data["b"]["x"]', 2),
        ('data["x", "y"]', 3),
    )
    for code, expected in expectations:
        assert guarded_eval(code, ctx) == expected

    # attribute access yields the (truthy) bound method
    assert guarded_eval("data.keys", ctx)
200
208
201
209
def test_set():
    """Attribute access on a set is allowed in limited mode."""
    ctx = limited(data={"a", "b"})
    method = guarded_eval("data.difference", ctx)
    assert method
205
213
206
214
def test_list():
    """Indexing and attribute access on a list work in limited mode."""
    ctx = limited(data=[1, 2, 3])
    second = guarded_eval("data[1]", ctx)
    assert second == 2
    assert guarded_eval("data.copy", ctx)
211
219
212
220
def test_dict_literal():
    """Dict displays evaluate to the corresponding dict."""
    ctx = limited()
    for code, expected in (("{}", {}), ('{"a": 1}', {"a": 1})):
        assert guarded_eval(code, ctx) == expected
217
225
218
226
def test_list_literal():
    """List displays evaluate to the corresponding list."""
    ctx = limited()
    for code, expected in (("[]", []), ('[1, "a"]', [1, "a"])):
        assert guarded_eval(code, ctx) == expected
223
231
224
232
def test_set_literal():
    """`set()` and set displays evaluate to the corresponding set."""
    ctx = limited()
    for code, expected in (("set()", set()), ('{"a"}', {"a"})):
        assert guarded_eval(code, ctx) == expected
229
237
230
238
def test_evaluates_if_expression():
    """Conditional expressions pick the branch selected by the test."""
    ctx = limited()
    taken_when_true = guarded_eval("2 if True else 3", ctx)
    assert taken_when_true == 2
    taken_when_false = guarded_eval("4 if False else 5", ctx)
    assert taken_when_false == 5
235
243
236
244
def test_object():
    """Attribute access on a plain `object` instance is allowed."""
    instance = object()
    ctx = limited(obj=instance)
    assert guarded_eval("obj.__dir__", ctx) == instance.__dir__
241
249
242
250
@pytest.mark.parametrize(
    "code,expected",
    [
        ("int.numerator", int.numerator),
        ("float.is_integer", float.is_integer),
        ("complex.real", complex.real),
    ],
)
def test_number_attributes(code, expected):
    """Attributes of the numeric types resolve to the real attribute."""
    evaluated = guarded_eval(code, limited())
    assert evaluated == expected
253
261
254
262
def test_method_descriptor():
    """Attributes of a method looked up on a builtin type can be read."""
    ctx = limited()
    name = guarded_eval("list.copy.__name__", ctx)
    assert name == "copy"
258
266
259
267
class HeapType:
    """Empty user-defined class used as a return type by the fixtures below."""
262
270
263
271
class CallCreatesHeapType:
    """Callable fixture: calling an instance returns a new ``HeapType``."""

    def __call__(self) -> HeapType:
        created = HeapType()
        return created
267
275
268
276
class CallCreatesBuiltin:
    """Callable fixture: calling an instance returns an empty ``frozenset``."""

    def __call__(self) -> frozenset:
        created = frozenset()
        return created
272
280
273
281
class HasStaticMethod:
    """Fixture exposing a static method annotated to return ``HeapType``."""

    @staticmethod
    def static_method() -> HeapType:
        created = HeapType()
        return created
278
286
279
287
class InitReturnsFrozenset:
    """Fixture whose construction yields a ``frozenset`` rather than an instance.

    ``__new__`` is annotated with (and returns) ``frozenset``, so
    ``InitReturnsFrozenset()`` evaluates to an empty ``frozenset``.
    """

    # ``__new__`` receives the class, not an instance, hence ``cls``.
    def __new__(cls) -> frozenset:  # type:ignore[misc]
        return frozenset()
283
291
284
292
class StringAnnotation:
    """Fixture whose return annotations are written as strings."""

    def heap(self) -> "HeapType":
        """Return a new ``HeapType``."""
        created = HeapType()
        return created

    def copy(self) -> "StringAnnotation":
        """Return a new ``StringAnnotation`` (self-referential string annotation)."""
        duplicate = StringAnnotation()
        return duplicate
291
299
292
300
# ``NewType`` wrappers around a builtin type and around a user-defined class.
CustomIntType = NewType("CustomIntType", int)
CustomHeapType = NewType("CustomHeapType", HeapType)

# Explicit ``TypeAliasType`` aliases for the same two underlying types.
IntTypeAlias = TypeAliasType("IntTypeAlias", int)
HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
297
305
298
306
class SpecialTyping:
    """Fixture whose methods are annotated with special ``typing`` constructs.

    Each method returns a value consistent with its annotation so that the
    fixture is also correct if a method is actually invoked.
    """

    def custom_int_type(self) -> CustomIntType:
        return CustomIntType(1)

    def custom_heap_type(self) -> CustomHeapType:
        return CustomHeapType(HeapType())

    # TODO: remove type:ignore comment once mypy
    # supports explicit calls to `TypeAliasType`, see:
    # https://github.com/python/mypy/issues/16614
    def int_type_alias(self) -> IntTypeAlias:  # type:ignore[valid-type]
        return 1

    def heap_type_alias(self) -> HeapTypeAlias:  # type:ignore[valid-type]
        # ``HeapTypeAlias`` aliases ``HeapType``; returning ``1`` here
        # contradicted the annotation.
        return HeapType()

    def literal(self) -> Literal[False]:
        return False

    def self(self) -> Self:
        return self
320
328
321
329
@pytest.mark.parametrize(
    "data,good,expected,equality",
    (
        ([1, 2, 3], "data.index(2)", 1, True),
        ({"a": 1}, "data.keys().isdisjoint({})", True, True),
        (StringAnnotation(), "data.heap()", HeapType, False),
        (StringAnnotation(), "data.copy()", StringAnnotation, False),
        # calling instances that define `__call__`
        (CallCreatesHeapType(), "data()", HeapType, False),
        (CallCreatesBuiltin(), "data()", frozenset, False),
        # calling classes (instantiation)
        (HeapType, "data()", HeapType, False),
        (InitReturnsFrozenset, "data()", frozenset, False),
        (HeapType(), "data.__class__()", HeapType, False),
        # supported special cases for typing
        (SpecialTyping(), "data.custom_int_type()", int, False),
        (SpecialTyping(), "data.custom_heap_type()", HeapType, False),
        (SpecialTyping(), "data.int_type_alias()", int, False),
        (SpecialTyping(), "data.heap_type_alias()", HeapType, False),
        (SpecialTyping(), "data.self()", SpecialTyping, False),
        (SpecialTyping(), "data.literal()", False, True),
        # static methods
        (HasStaticMethod, "data.static_method()", HeapType, False),
    ),
)
def test_evaluates_calls(data, good, expected, equality):
    """Call expressions evaluate to the expected value or to the expected type.

    When ``equality`` is true the result is compared with ``expected``;
    otherwise ``expected`` is a type and the result must be an instance of it.
    """
    namespace = {
        "data": data,
        "HeapType": HeapType,
        "StringAnnotation": StringAnnotation,
    }
    result = guarded_eval(good, limited(**namespace))
    if not equality:
        assert isinstance(result, expected)
        return
    assert result == expected
354
362
355
363
@pytest.mark.parametrize(
    "data,bad",
    (
        ([1, 2, 3], "data.append(4)"),
        ({"a": 1}, "data.update()"),
    ),
)
def test_rejects_calls_with_side_effects(data, bad):
    """Calls to mutating container methods raise ``GuardRejection``."""
    ctx = limited(data=data)
    with pytest.raises(GuardRejection):
        guarded_eval(bad, ctx)
368
376
369
377
@pytest.mark.parametrize(
    "code,expected",
    (
        ("(1\n+\n1)", 2),
        ("list(range(10))[-1:]", [9]),
        ("list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]),
    ),
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_complex_cases(code, expected, context):
    """Multi-line expressions and slicing evaluate for every ``LIMITED_OR_HIGHER`` context."""
    result = guarded_eval(code, context())
    assert result == expected
381
389
382
390
@pytest.mark.parametrize(
    "code,expected",
    (
        ("1", 1),
        ("1.0", 1.0),
        ("0xdeedbeef", 0xDEEDBEEF),
        ("True", True),
        ("None", None),
        ("{}", {}),
        ("[]", []),
    ),
)
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_evaluates_literals(code, expected, context):
    """Literals evaluate to their value for every ``MINIMAL_OR_HIGHER`` context."""
    result = guarded_eval(code, context())
    assert result == expected
398
406
399
407
@pytest.mark.parametrize(
    "code,expected",
    (
        ("-5", -5),
        ("+5", +5),
        ("~5", -6),
    ),
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_unary_operations(code, expected, context):
    """Unary ``-``, ``+`` and ``~`` on integers evaluate for every ``LIMITED_OR_HIGHER`` context."""
    result = guarded_eval(code, context())
    assert result == expected
411
419
412
420
@pytest.mark.parametrize(
    "code,expected",
    (
        ("1 + 1", 2),
        ("3 - 1", 2),
        ("2 * 3", 6),
        ("5 // 2", 2),
        ("5 / 2", 2.5),
        ("5**2", 25),
        ("2 >> 1", 1),
        ("2 << 1", 4),
        ("1 | 2", 3),
        ("1 & 1", 1),
        ("1 & 2", 0),
    ),
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_binary_operations(code, expected, context):
    """Arithmetic and bitwise binary operators evaluate for every ``LIMITED_OR_HIGHER`` context."""
    result = guarded_eval(code, context())
    assert result == expected
432
440
433
441
@pytest.mark.parametrize(
    "code,expected",
    (
        # simple ordering and equality
        ("2 > 1", True),
        ("2 < 1", False),
        ("2 <= 1", False),
        ("2 <= 2", True),
        ("1 >= 2", False),
        ("2 >= 2", True),
        ("2 == 2", True),
        ("1 == 2", False),
        ("1 != 2", True),
        ("1 != 1", False),
        # chained comparisons
        ("1 < 4 < 3", False),
        ("(1 < 4) < 3", True),
        ("4 > 3 > 2 > 1", True),
        ("4 > 3 > 2 > 9", False),
        ("1 < 2 < 3 < 4", True),
        ("9 < 2 < 3 < 4", False),
        ("1 < 2 > 1 > 0 > -1 < 1", True),
        # membership
        ("1 in [1] in [[1]]", True),
        ("1 in [1] in [[2]]", False),
        ("1 in [1]", True),
        ("0 in [1]", False),
        ("1 not in [1]", False),
        ("0 not in [1]", True),
        # identity
        ("True is True", True),
        ("False is False", True),
        ("True is False", False),
        ("True is not True", False),
        ("False is not True", True),
    ),
)
@pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
def test_evaluates_comparisons(code, expected, context):
    """Comparison, membership and identity operators, including chains, evaluate correctly."""
    result = guarded_eval(code, context())
    assert result == expected
470
478
471
479
def test_guards_comparisons():
    """Comparisons are rejected for a subclass overriding ``__eq__`` and allowed otherwise."""

    class GoodEq(int):
        pass

    class BadEq(int):
        def __eq__(self, other):
            # Must never run: the guard has to reject before dispatching here.
            assert False

    context = limited(bad=BadEq(1), good=GoodEq(1))

    # The overriding subclass is rejected on either side of `==` / `!=`.
    for rejected in ("bad == 1", "bad != 1", "1 == bad", "1 != bad"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, context)

    # The plain subclass compares as usual.
    accepted = (
        ("good == 1", True),
        ("good != 1", False),
        ("1 == good", True),
        ("1 != good", False),
    )
    for code, outcome in accepted:
        assert guarded_eval(code, context) is outcome
498
506
499
507
def test_guards_unary_operations():
    """``~`` is rejected for ``int`` subclasses that customise inversion.

    Python dispatches ``~x`` to ``type(x).__invert__(x)``; ``__inv__`` is only
    a legacy spelling exposed by the ``operator`` module. Both spellings are
    covered, one per class. Previously both classes defined ``__inv__`` with a
    spurious ``other`` parameter, so an overridden ``__invert__`` was never
    exercised.
    """

    class GoodOp(int):
        pass

    class BadOpInv(int):
        def __inv__(self):
            # Must never run: the guard has to reject first.
            assert False

    class BadOpInverse(int):
        def __invert__(self):
            # Must never run: the guard has to reject first.
            assert False

    # NOTE(review): `good` is bound but no `~good` assertion exists; confirm
    # the intended behaviour for a plain subclass before adding one.
    context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))

    with pytest.raises(GuardRejection):
        guarded_eval("~bad1", context)

    with pytest.raises(GuardRejection):
        guarded_eval("~bad2", context)
519
527
520
528
def test_guards_binary_operations():
    """``+`` is rejected for a subclass overriding ``__add__`` and allowed otherwise."""

    class GoodOp(int):
        pass

    class BadOp(int):
        def __add__(self, other):
            # Must never run: the guard has to reject before dispatching here.
            assert False

    context = limited(good=GoodOp(1), bad=BadOp(1))

    # The overriding subclass is rejected as either operand.
    for rejected in ("1 + bad", "bad + 1"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, context)

    # The plain subclass adds as usual, in either position.
    for accepted in ("good + 1", "1 + good"):
        assert guarded_eval(accepted, context) == 2
539
547
540
548
def test_guards_attributes():
    """Attribute access is rejected when ``__getattr__``/``__getattribute__`` is overridden."""

    class GoodAttr(float):
        pass

    class BadAttr1(float):
        def __getattr__(self, key):
            assert False

    class BadAttr2(float):
        def __getattribute__(self, key):
            assert False

    plain = GoodAttr(0.5)
    custom_getattr = BadAttr1(0.5)
    custom_getattribute = BadAttr2(0.5)
    context = limited(good=plain, bad1=custom_getattr, bad2=custom_getattribute)

    for rejected in ("bad1.as_integer_ratio", "bad2.as_integer_ratio"):
        with pytest.raises(GuardRejection):
            guarded_eval(rejected, context)

    # The plain subclass may both access and call the method.
    ratio = guarded_eval("good.as_integer_ratio()", context)
    assert ratio == (1, 2)
562
570
563
571
@pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
def test_access_builtins(context):
    """The builtin ``round`` is resolvable for every ``MINIMAL_OR_HIGHER`` context."""
    resolved = guarded_eval("round", context())
    assert resolved == round
567
575
568
576
def test_access_builtins_fails():
    """An unknown name raises ``NameError`` in a ``limited()`` context."""
    ctx = limited()
    with pytest.raises(NameError):
        guarded_eval("this_is_not_builtin", ctx)
573
581
574
582
def test_rejects_forbidden():
    """A ``forbidden()`` context rejects even an integer literal."""
    ctx = forbidden()
    with pytest.raises(GuardRejection):
        guarded_eval("1", ctx)
579
587
580
588
def test_guards_locals_and_globals():
    """With ``evaluation="minimal"`` neither locals nor globals may be read."""
    context = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="minimal",
    )

    for name in ("local_a", "global_b"):
        with pytest.raises(GuardRejection):
            guarded_eval(name, context)
591
599
592
600
def test_access_locals_and_globals():
    """With ``evaluation="limited"`` both locals and globals are readable."""
    context = EvaluationContext(
        locals={"local_a": "a"},
        globals={"global_b": "b"},
        evaluation="limited",
    )

    for name, value in (("local_a", "a"), ("global_b", "b")):
        assert guarded_eval(name, context) == value
599
607
600
608
@pytest.mark.parametrize(
    "code",
    (
        "def func(): pass",
        "class C: pass",
        "x = 1",
        "x += 1",
        "del x",
        "import ast",
    ),
)
@pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
def test_rejects_side_effect_syntax(code, context):
    """Statements raise ``SyntaxError`` in minimal, limited and unsafe contexts."""
    with pytest.raises(SyntaxError):
        guarded_eval(code, context)
609
617
610
618
def test_subscript():
    """With ``in_subscript=True`` the code is evaluated as subscript contents."""
    context = EvaluationContext(
        locals={},
        globals={},
        evaluation="limited",
        in_subscript=True,
    )
    whole = slice(None, None, None)
    cases = (
        ("", tuple()),  # empty subscript
        (":", whole),
        ("1:2:3", slice(1, 2, 3)),
        (':, "a"', (whole, "a")),  # slice mixed with a plain index
    )
    for code, expected in cases:
        assert guarded_eval(code, context) == expected
620
628
621
629
def test_unbind_method():
    """``_unbind_method`` maps a bound method to the function on its class.

    A method that is already unbound yields ``None``.
    """

    class X(list):
        def index(self, k):
            return "CUSTOM"

    instance = X()

    # Bound method of a subclass overriding `index`.
    assert _unbind_method(instance.index) is X.index
    # Bound method of a builtin instance.
    assert _unbind_method([].index) is list.index
    # Already unbound.
    assert _unbind_method(list.index) is None
631
639
632
640
def test_assumption_instance_attr_do_not_matter():
    """Check that dunder attributes set on an instance are ignored.

    This is semi-specified in the Python documentation. However, since the
    specification says 'not guaranteed to work' rather than 'is forbidden
    to work', future versions could invalidate this assumption. This test
    is meant to catch such a change if it ever happens.
    """

    class T:
        def __getitem__(self, k):
            return "a"

        def __getattr__(self, k):
            return "a"

    def f(self):
        return "b"

    t = T()
    # Shadow both dunders on the instance; implicit special-method lookup
    # goes through the type, so the class versions must still be used.
    t.__getitem__ = f
    t.__getattr__ = f

    # Subscription still uses `T.__getitem__`.
    assert t[1] == "a"
    # Missing-attribute lookup still uses `T.__getattr__` (this line used to
    # repeat the subscription check, leaving `__getattr__` untested).
    assert t.a == "a"
657
665
658
666
def test_assumption_named_tuples_share_getitem():
    """Two distinct ``NamedTuple`` classes expose an equal ``__getitem__``."""
    from typing import NamedTuple

    class A(NamedTuple):
        pass

    class B(NamedTuple):
        pass

    getitem_of_a = A.__getitem__
    getitem_of_b = B.__getitem__
    assert getitem_of_a == getitem_of_b
670
678
671
679
@dec.skip_without("numpy")
def test_module_access():
    """Module attributes are readable in ``limited()`` and rejected in ``minimal()``."""
    import numpy

    context = limited(numpy=numpy)
    assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm

    context = minimal(numpy=numpy)
    with pytest.raises(GuardRejection):
        # Use the name actually bound in the context (`numpy`, not `np`), so
        # the rejection concerns access to the supplied module rather than
        # an undefined name.
        guarded_eval("numpy.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now