##// END OF EJS Templates
Infer type for `__init__` calls (including `__new__` mods)
krassowski -
Show More
@@ -1,760 +1,793 b''
1 from inspect import signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Any,
3 Any,
4 Callable,
4 Callable,
5 Dict,
5 Dict,
6 Set,
6 Set,
7 Sequence,
7 Sequence,
8 Tuple,
8 Tuple,
9 NamedTuple,
9 NamedTuple,
10 Type,
10 Type,
11 Literal,
11 Literal,
12 Union,
12 Union,
13 TYPE_CHECKING,
13 TYPE_CHECKING,
14 )
14 )
15 import ast
15 import ast
16 import builtins
16 import builtins
17 import collections
17 import collections
18 import operator
18 import operator
19 import sys
19 import sys
20 from functools import cached_property
20 from functools import cached_property
21 from dataclasses import dataclass, field
21 from dataclasses import dataclass, field
22 from types import MethodDescriptorType, ModuleType
22 from types import MethodDescriptorType, ModuleType
23
23
24 from IPython.utils.docs import GENERATING_DOCUMENTATION
24 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 from IPython.utils.decorators import undoc
25 from IPython.utils.decorators import undoc
26
26
27
27
28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
29 from typing_extensions import Protocol
29 from typing_extensions import Protocol
30 else:
30 else:
31 # do not require on runtime
31 # do not require on runtime
32 Protocol = object # requires Python >=3.8
32 Protocol = object # requires Python >=3.8
33
33
34
34
@undoc
class HasGetItem(Protocol):
    """Protocol for objects that implement ``__getitem__``."""

    def __getitem__(self, key) -> None: ...
39
39
40
40
@undoc
class InstancesHaveGetItem(Protocol):
    """Protocol for callables whose call result implements ``__getitem__``."""

    def __call__(self, *args, **kwargs) -> HasGetItem: ...
45
45
46
46
@undoc
class HasGetAttr(Protocol):
    """Protocol for objects that explicitly implement ``__getattr__``."""

    def __getattr__(self, key) -> None: ...
51
51
52
52
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Protocol with no members, for objects that do not define ``__getattr__``."""
56
56
57
57
# Most objects do not implement `__getattr__` explicitly, so the alias
# covers both the case where it is defined and the case where it is not.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
60
60
61
61
62 def _unbind_method(func: Callable) -> Union[Callable, None]:
62 def _unbind_method(func: Callable) -> Union[Callable, None]:
63 """Get unbound method for given bound method.
63 """Get unbound method for given bound method.
64
64
65 Returns None if cannot get unbound method, or method is already unbound.
65 Returns None if cannot get unbound method, or method is already unbound.
66 """
66 """
67 owner = getattr(func, "__self__", None)
67 owner = getattr(func, "__self__", None)
68 owner_class = type(owner)
68 owner_class = type(owner)
69 name = getattr(func, "__name__", None)
69 name = getattr(func, "__name__", None)
70 instance_dict_overrides = getattr(owner, "__dict__", None)
70 instance_dict_overrides = getattr(owner, "__dict__", None)
71 if (
71 if (
72 owner is not None
72 owner is not None
73 and name
73 and name
74 and (
74 and (
75 not instance_dict_overrides
75 not instance_dict_overrides
76 or (instance_dict_overrides and name not in instance_dict_overrides)
76 or (instance_dict_overrides and name not in instance_dict_overrides)
77 )
77 )
78 ):
78 ):
79 return getattr(owner_class, name)
79 return getattr(owner_class, name)
80 return None
80 return None
81
81
82
82
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Every ``allow_*`` flag defaults to ``False``. The ``can_*`` predicates
    of this base class consult only these flags and ``allowed_calls``;
    their ``value``/``item``/``attr``/operand arguments are not inspected
    here.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    # callables which may be called even when `allow_any_calls` is off
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether subscript access is allowed."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether attribute access is allowed."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return ``True`` if operations are allowed, ``False`` otherwise."""
        return bool(self.allow_all_operations)

    def can_call(self, func):
        """Return ``True`` if calling ``func`` is allowed, ``False`` otherwise.

        A call is allowed when any call is allowed, when ``func`` itself is
        in ``allowed_calls``, or when ``func`` is a bound method whose
        unbound counterpart (as found by ``_unbind_method``) is in
        ``allowed_calls``.
        """
        if self.allow_any_calls or func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)
        return bool(owner_method) and owner_method in self.allowed_calls
118
118
119
119
120 def _get_external(module_name: str, access_path: Sequence[str]):
120 def _get_external(module_name: str, access_path: Sequence[str]):
121 """Get value from external module given a dotted access path.
121 """Get value from external module given a dotted access path.
122
122
123 Raises:
123 Raises:
124 * `KeyError` if module is removed not found, and
124 * `KeyError` if module is removed not found, and
125 * `AttributeError` if acess path does not match an exported object
125 * `AttributeError` if acess path does not match an exported object
126 """
126 """
127 member_type = sys.modules[module_name]
127 member_type = sys.modules[module_name]
128 for attr in access_path:
128 for attr in access_path:
129 member_type = getattr(member_type, attr)
129 member_type = getattr(member_type, attr)
130 return member_type
130 return member_type
131
131
132
132
133 def _has_original_dunder_external(
133 def _has_original_dunder_external(
134 value,
134 value,
135 module_name: str,
135 module_name: str,
136 access_path: Sequence[str],
136 access_path: Sequence[str],
137 method_name: str,
137 method_name: str,
138 ):
138 ):
139 if module_name not in sys.modules:
139 if module_name not in sys.modules:
140 # LBYLB as it is faster
140 # LBYLB as it is faster
141 return False
141 return False
142 try:
142 try:
143 member_type = _get_external(module_name, access_path)
143 member_type = _get_external(module_name, access_path)
144 value_type = type(value)
144 value_type = type(value)
145 if type(value) == member_type:
145 if type(value) == member_type:
146 return True
146 return True
147 if method_name == "__getattribute__":
147 if method_name == "__getattribute__":
148 # we have to short-circuit here due to an unresolved issue in
148 # we have to short-circuit here due to an unresolved issue in
149 # `isinstance` implementation: https://bugs.python.org/issue32683
149 # `isinstance` implementation: https://bugs.python.org/issue32683
150 return False
150 return False
151 if isinstance(value, member_type):
151 if isinstance(value, member_type):
152 method = getattr(value_type, method_name, None)
152 method = getattr(value_type, method_name, None)
153 member_method = getattr(member_type, method_name, None)
153 member_method = getattr(member_type, method_name, None)
154 if member_method == method:
154 if member_method == method:
155 return True
155 return True
156 except (AttributeError, KeyError):
156 except (AttributeError, KeyError):
157 return False
157 return False
158
158
159
159
160 def _has_original_dunder(
160 def _has_original_dunder(
161 value, allowed_types, allowed_methods, allowed_external, method_name
161 value, allowed_types, allowed_methods, allowed_external, method_name
162 ):
162 ):
163 # note: Python ignores `__getattr__`/`__getitem__` on instances,
163 # note: Python ignores `__getattr__`/`__getitem__` on instances,
164 # we only need to check at class level
164 # we only need to check at class level
165 value_type = type(value)
165 value_type = type(value)
166
166
167 # strict type check passes β†’ no need to check method
167 # strict type check passes β†’ no need to check method
168 if value_type in allowed_types:
168 if value_type in allowed_types:
169 return True
169 return True
170
170
171 method = getattr(value_type, method_name, None)
171 method = getattr(value_type, method_name, None)
172
172
173 if method is None:
173 if method is None:
174 return None
174 return None
175
175
176 if method in allowed_methods:
176 if method in allowed_methods:
177 return True
177 return True
178
178
179 for module_name, *access_path in allowed_external:
179 for module_name, *access_path in allowed_external:
180 if _has_original_dunder_external(value, module_name, access_path, method_name):
180 if _has_original_dunder_external(value, module_name, access_path, method_name):
181 return True
181 return True
182
182
183 return False
183 return False
184
184
185
185
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Evaluation policy driven by allow-lists of types.

    For each kind of access (item, attribute, operation) there is a set of
    allow-listed types and a set of "external" entries. An external entry
    is a tuple ``(module_name, *access_path)`` which is resolved lazily via
    ``sys.modules``, so the module does not have to be imported up front.
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # dunder name -> implementations of it found on `allowed_operations`
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr):
        """Return whether reading ``attr`` from ``value`` is allowed.

        Access is accepted when ``__getattribute__`` of ``type(value)`` is
        an allow-listed one and ``__getattr__`` is either absent or
        allow-listed as well. Attributes which are properties on the class
        are accepted only if the class itself is allow-listed or the
        property object is the very one exposed by an allow-listed
        external class.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        # Many objects do not have `__getattr__`, this is fine.
        if has_original_attr is None and has_original_attribute:
            accept = True
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external entry is consulted: one which cannot be resolved
        # (module not loaded, no such attribute) must not veto the others.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue
            # identity, so that a `property` subclass cannot influence the
            # outcome through a custom `__eq__`
            if class_attr_val is external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item):
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return _has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name="__getitem__",
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None):
        """Return whether every dunder in ``dunders`` is original on each operand.

        ``b`` takes part in the check only when it is not ``None``.
        """
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=self.allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return (and cache) implementations of ``dunder`` on allowed operation types."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect the truthy ``name`` attributes of ``classes``, skipping missing ones."""
        candidates = (getattr(class_, name, None) for class_ in classes)
        return {method for method in candidates if method}
310
310
311
311
312 class _DummyNamedTuple(NamedTuple):
312 class _DummyNamedTuple(NamedTuple):
313 """Used internally to retrieve methods of named tuple instance."""
313 """Used internally to retrieve methods of named tuple instance."""
314
314
315
315
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which code is evaluated."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: Literal["forbidden", "minimal", "limited", "unsafe", "dangerous"] = (
        "forbidden"
    )
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
328
328
329
329
330 class _IdentitySubscript:
330 class _IdentitySubscript:
331 """Returns the key itself when item is requested via subscript."""
331 """Returns the key itself when item is requested via subscript."""
332
332
333 def __getitem__(self, key):
333 def __getitem__(self, key):
334 return key
334 return key
335
335
336
336
# Shared instance of the helper that returns the subscript key unchanged.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name used for the placeholder object injected when evaluating in a subscript.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Empty signature object.
UNKNOWN_SIGNATURE = Signature()
# Unique sentinel object.
NOT_EVALUATED = object()
340
341
341
342
class GuardRejection(Exception):
    """Raised when the guard refuses to evaluate a piece of code."""
346
347
347
348
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    Depending on ``context.evaluation``:

    - ``forbidden``: nothing is evaluated, :class:`GuardRejection` is raised;
    - ``dangerous``: the code is passed to the standard :func:`eval`;
    - any other policy: the code is parsed and handed to :func:`eval_node`.

    When ``context.in_subscript`` is set, the code is evaluated as if it were
    the content of a subscript, and an empty ``code`` yields an empty tuple.
    """
    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside subscripts, so the code is
        # wrapped into a subscript of a marker object that hands the key
        # back unchanged; this makes the parser accept it while the
        # `__getitem__` operation itself has no effect
        if not code:
            return ()
        subscript_locals = context.locals.copy()
        subscript_locals[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = "".join((SUBSCRIPT_MARKER, "[", code, "]"))
        context = EvaluationContext(
            **{**context._asdict(), "locals": subscript_locals}
        )

    if context.evaluation == "dangerous":
        # deliberately unrestricted: this mode is an explicit opt-in
        return eval(code, context.globals, context.locals)

    return eval_node(ast.parse(code, mode="eval"), context)
382
383
383
384
# AST binary operator node type -> dunder method name(s) implementing it.
BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

# AST comparison operator node type -> dunder method names to check.
COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

# AST unary operator node type -> dunder method names to check.
UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    ast.Not: ("__not__",),
}
418
419
419
420
class Duck:
    """Placeholder class for creating objects of other classes without running their ``__init__``."""
422
423
423
424
424 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
425 dunder = None
426 dunder = None
426 for op, candidate_dunder in dunders.items():
427 for op, candidate_dunder in dunders.items():
427 if isinstance(node_op, op):
428 if isinstance(node_op, op):
428 dunder = candidate_dunder
429 dunder = candidate_dunder
429 return dunder
430 return dunder
430
431
431
432
432 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
433 """Evaluate AST node in provided context.
434 """Evaluate AST node in provided context.
434
435
435 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
436
437
437 Does not evaluate actions that always have side effects:
438 Does not evaluate actions that always have side effects:
438
439
439 - class definitions (``class sth: ...``)
440 - class definitions (``class sth: ...``)
440 - function definitions (``def sth: ...``)
441 - function definitions (``def sth: ...``)
441 - variable assignments (``x = 1``)
442 - variable assignments (``x = 1``)
442 - augmented assignments (``x += 1``)
443 - augmented assignments (``x += 1``)
443 - deletions (``del x``)
444 - deletions (``del x``)
444
445
445 Does not evaluate operations which do not return values:
446 Does not evaluate operations which do not return values:
446
447
447 - assertions (``assert x``)
448 - assertions (``assert x``)
448 - pass (``pass``)
449 - pass (``pass``)
449 - imports (``import x``)
450 - imports (``import x``)
450 - control flow:
451 - control flow:
451
452
452 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
453 - loops (``for`` and ``while``)
454 - loops (``for`` and ``while``)
454 - exception handling
455 - exception handling
455
456
456 The purpose of this function is to guard against unwanted side-effects;
457 The purpose of this function is to guard against unwanted side-effects;
457 it does not give guarantees on protection from malicious code execution.
458 it does not give guarantees on protection from malicious code execution.
458 """
459 """
459 policy = EVALUATION_POLICIES[context.evaluation]
460 policy = EVALUATION_POLICIES[context.evaluation]
460 if node is None:
461 if node is None:
461 return None
462 return None
462 if isinstance(node, ast.Expression):
463 if isinstance(node, ast.Expression):
463 return eval_node(node.body, context)
464 return eval_node(node.body, context)
464 if isinstance(node, ast.BinOp):
465 if isinstance(node, ast.BinOp):
465 left = eval_node(node.left, context)
466 left = eval_node(node.left, context)
466 right = eval_node(node.right, context)
467 right = eval_node(node.right, context)
467 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
468 if dunders:
469 if dunders:
469 if policy.can_operate(dunders, left, right):
470 if policy.can_operate(dunders, left, right):
470 return getattr(left, dunders[0])(right)
471 return getattr(left, dunders[0])(right)
471 else:
472 else:
472 raise GuardRejection(
473 raise GuardRejection(
473 f"Operation (`{dunders}`) for",
474 f"Operation (`{dunders}`) for",
474 type(left),
475 type(left),
475 f"not allowed in {context.evaluation} mode",
476 f"not allowed in {context.evaluation} mode",
476 )
477 )
477 if isinstance(node, ast.Compare):
478 if isinstance(node, ast.Compare):
478 left = eval_node(node.left, context)
479 left = eval_node(node.left, context)
479 all_true = True
480 all_true = True
480 negate = False
481 negate = False
481 for op, right in zip(node.ops, node.comparators):
482 for op, right in zip(node.ops, node.comparators):
482 right = eval_node(right, context)
483 right = eval_node(right, context)
483 dunder = None
484 dunder = None
484 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 dunders = _find_dunder(op, COMP_OP_DUNDERS)
485 if not dunders:
486 if not dunders:
486 if isinstance(op, ast.NotIn):
487 if isinstance(op, ast.NotIn):
487 dunders = COMP_OP_DUNDERS[ast.In]
488 dunders = COMP_OP_DUNDERS[ast.In]
488 negate = True
489 negate = True
489 if isinstance(op, ast.Is):
490 if isinstance(op, ast.Is):
490 dunder = "is_"
491 dunder = "is_"
491 if isinstance(op, ast.IsNot):
492 if isinstance(op, ast.IsNot):
492 dunder = "is_"
493 dunder = "is_"
493 negate = True
494 negate = True
494 if not dunder and dunders:
495 if not dunder and dunders:
495 dunder = dunders[0]
496 dunder = dunders[0]
496 if dunder:
497 if dunder:
497 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 a, b = (right, left) if dunder == "__contains__" else (left, right)
498 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
499 result = getattr(operator, dunder)(a, b)
500 result = getattr(operator, dunder)(a, b)
500 if negate:
501 if negate:
501 result = not result
502 result = not result
502 if not result:
503 if not result:
503 all_true = False
504 all_true = False
504 left = right
505 left = right
505 else:
506 else:
506 raise GuardRejection(
507 raise GuardRejection(
507 f"Comparison (`{dunder}`) for",
508 f"Comparison (`{dunder}`) for",
508 type(left),
509 type(left),
509 f"not allowed in {context.evaluation} mode",
510 f"not allowed in {context.evaluation} mode",
510 )
511 )
511 else:
512 else:
512 raise ValueError(
513 raise ValueError(
513 f"Comparison `{dunder}` not supported"
514 f"Comparison `{dunder}` not supported"
514 ) # pragma: no cover
515 ) # pragma: no cover
515 return all_true
516 return all_true
516 if isinstance(node, ast.Constant):
517 if isinstance(node, ast.Constant):
517 return node.value
518 return node.value
518 if isinstance(node, ast.Tuple):
519 if isinstance(node, ast.Tuple):
519 return tuple(eval_node(e, context) for e in node.elts)
520 return tuple(eval_node(e, context) for e in node.elts)
520 if isinstance(node, ast.List):
521 if isinstance(node, ast.List):
521 return [eval_node(e, context) for e in node.elts]
522 return [eval_node(e, context) for e in node.elts]
522 if isinstance(node, ast.Set):
523 if isinstance(node, ast.Set):
523 return {eval_node(e, context) for e in node.elts}
524 return {eval_node(e, context) for e in node.elts}
524 if isinstance(node, ast.Dict):
525 if isinstance(node, ast.Dict):
525 return dict(
526 return dict(
526 zip(
527 zip(
527 [eval_node(k, context) for k in node.keys],
528 [eval_node(k, context) for k in node.keys],
528 [eval_node(v, context) for v in node.values],
529 [eval_node(v, context) for v in node.values],
529 )
530 )
530 )
531 )
531 if isinstance(node, ast.Slice):
532 if isinstance(node, ast.Slice):
532 return slice(
533 return slice(
533 eval_node(node.lower, context),
534 eval_node(node.lower, context),
534 eval_node(node.upper, context),
535 eval_node(node.upper, context),
535 eval_node(node.step, context),
536 eval_node(node.step, context),
536 )
537 )
537 if isinstance(node, ast.UnaryOp):
538 if isinstance(node, ast.UnaryOp):
538 value = eval_node(node.operand, context)
539 value = eval_node(node.operand, context)
539 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
540 if dunders:
541 if dunders:
541 if policy.can_operate(dunders, value):
542 if policy.can_operate(dunders, value):
542 return getattr(value, dunders[0])()
543 return getattr(value, dunders[0])()
543 else:
544 else:
544 raise GuardRejection(
545 raise GuardRejection(
545 f"Operation (`{dunders}`) for",
546 f"Operation (`{dunders}`) for",
546 type(value),
547 type(value),
547 f"not allowed in {context.evaluation} mode",
548 f"not allowed in {context.evaluation} mode",
548 )
549 )
549 if isinstance(node, ast.Subscript):
550 if isinstance(node, ast.Subscript):
550 value = eval_node(node.value, context)
551 value = eval_node(node.value, context)
551 slice_ = eval_node(node.slice, context)
552 slice_ = eval_node(node.slice, context)
552 if policy.can_get_item(value, slice_):
553 if policy.can_get_item(value, slice_):
553 return value[slice_]
554 return value[slice_]
554 raise GuardRejection(
555 raise GuardRejection(
555 "Subscript access (`__getitem__`) for",
556 "Subscript access (`__getitem__`) for",
556 type(value), # not joined to avoid calling `repr`
557 type(value), # not joined to avoid calling `repr`
557 f" not allowed in {context.evaluation} mode",
558 f" not allowed in {context.evaluation} mode",
558 )
559 )
559 if isinstance(node, ast.Name):
560 if isinstance(node, ast.Name):
560 if policy.allow_locals_access and node.id in context.locals:
561 if policy.allow_locals_access and node.id in context.locals:
561 return context.locals[node.id]
562 return context.locals[node.id]
562 if policy.allow_globals_access and node.id in context.globals:
563 if policy.allow_globals_access and node.id in context.globals:
563 return context.globals[node.id]
564 return context.globals[node.id]
564 if policy.allow_builtins_access and hasattr(builtins, node.id):
565 if policy.allow_builtins_access and hasattr(builtins, node.id):
565 # note: do not use __builtins__, it is implementation detail of cPython
566 # note: do not use __builtins__, it is implementation detail of cPython
566 return getattr(builtins, node.id)
567 return getattr(builtins, node.id)
567 if not policy.allow_globals_access and not policy.allow_locals_access:
568 if not policy.allow_globals_access and not policy.allow_locals_access:
568 raise GuardRejection(
569 raise GuardRejection(
569 f"Namespace access not allowed in {context.evaluation} mode"
570 f"Namespace access not allowed in {context.evaluation} mode"
570 )
571 )
571 else:
572 else:
572 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
573 if isinstance(node, ast.Attribute):
574 if isinstance(node, ast.Attribute):
574 value = eval_node(node.value, context)
575 value = eval_node(node.value, context)
575 if policy.can_get_attr(value, node.attr):
576 if policy.can_get_attr(value, node.attr):
576 return getattr(value, node.attr)
577 return getattr(value, node.attr)
577 raise GuardRejection(
578 raise GuardRejection(
578 "Attribute access (`__getattr__`) for",
579 "Attribute access (`__getattr__`) for",
579 type(value), # not joined to avoid calling `repr`
580 type(value), # not joined to avoid calling `repr`
580 f"not allowed in {context.evaluation} mode",
581 f"not allowed in {context.evaluation} mode",
581 )
582 )
582 if isinstance(node, ast.IfExp):
583 if isinstance(node, ast.IfExp):
583 test = eval_node(node.test, context)
584 test = eval_node(node.test, context)
584 if test:
585 if test:
585 return eval_node(node.body, context)
586 return eval_node(node.body, context)
586 else:
587 else:
587 return eval_node(node.orelse, context)
588 return eval_node(node.orelse, context)
588 if isinstance(node, ast.Call):
589 if isinstance(node, ast.Call):
589 func = eval_node(node.func, context)
590 func = eval_node(node.func, context)
590 if policy.can_call(func) and not node.keywords:
591 if policy.can_call(func) and not node.keywords:
591 args = [eval_node(arg, context) for arg in node.args]
592 args = [eval_node(arg, context) for arg in node.args]
592 return func(*args)
593 return func(*args)
593 try:
594 if isclass(func):
594 sig = signature(func)
595 # this code path gets entered when calling class e.g. `MyClass()`
595 except ValueError:
596 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
596 sig = UNKNOWN_SIGNATURE
597 # Should return `MyClass` if `__new__` is not overridden,
597 # if annotation was not stringized, or it was stringized
598 # otherwise whatever `__new__` return type is.
598 # but resolved by signature call we know the return type
599 overridden_return_type = _eval_return_type(
599 not_empty = sig.return_annotation is not Signature.empty
600 func.__new__, policy, node, context
600 not_stringized = not isinstance(sig.return_annotation, str)
601 )
601 if not_empty and not_stringized:
602 if overridden_return_type is not NOT_EVALUATED:
602 duck = Duck()
603 return overridden_return_type
603 # if allow-listed builtin is on type annotation, instantiate it
604 return _create_duck_for_type(func)
604 if policy.can_call(sig.return_annotation) and not node.keywords:
605 else:
605 args = [eval_node(arg, context) for arg in node.args]
606 return_type = _eval_return_type(func, policy, node, context)
606 return sig.return_annotation(*args)
607 if return_type is not NOT_EVALUATED:
607 try:
608 return return_type
608 # if custom class is in type annotation, mock it;
609 # this only works for heap types, not builtins
610 duck.__class__ = sig.return_annotation
611 return duck
612 except TypeError:
613 pass
614 raise GuardRejection(
609 raise GuardRejection(
615 "Call for",
610 "Call for",
616 func, # not joined to avoid calling `repr`
611 func, # not joined to avoid calling `repr`
617 f"not allowed in {context.evaluation} mode",
612 f"not allowed in {context.evaluation} mode",
618 )
613 )
619 raise ValueError("Unhandled node", ast.dump(node))
614 raise ValueError("Unhandled node", ast.dump(node))
620
615
621
616
617 def _eval_return_type(func, policy, node, context):
618 """Evaluate return type of a given callable function.
619
620 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
621 """
622 try:
623 sig = signature(func)
624 except ValueError:
625 sig = UNKNOWN_SIGNATURE
626 # if annotation was not stringized, or it was stringized
627 # but resolved by signature call we know the return type
628 not_empty = sig.return_annotation is not Signature.empty
629 not_stringized = not isinstance(sig.return_annotation, str)
630 if not_empty and not_stringized:
631 # if allow-listed builtin is on type annotation, instantiate it
632 if policy.can_call(sig.return_annotation) and not node.keywords:
633 args = [eval_node(arg, context) for arg in node.args]
634 # if custom class is in type annotation, mock it;
635 return sig.return_annotation(*args)
636 return _create_duck_for_type(sig.return_annotation)
637 return NOT_EVALUATED
638
639
640 def _create_duck_for_type(duck_type):
641 """Create an imitation of an object of a given type (a duck).
642
643 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
644 """
645 duck = Duck()
646 try:
647 # this only works for heap types, not builtins
648 duck.__class__ = duck_type
649 return duck
650 except TypeError:
651 pass
652 return NOT_EVALUATED
653
654
622 SUPPORTED_EXTERNAL_GETITEM = {
655 SUPPORTED_EXTERNAL_GETITEM = {
623 ("pandas", "core", "indexing", "_iLocIndexer"),
656 ("pandas", "core", "indexing", "_iLocIndexer"),
624 ("pandas", "core", "indexing", "_LocIndexer"),
657 ("pandas", "core", "indexing", "_LocIndexer"),
625 ("pandas", "DataFrame"),
658 ("pandas", "DataFrame"),
626 ("pandas", "Series"),
659 ("pandas", "Series"),
627 ("numpy", "ndarray"),
660 ("numpy", "ndarray"),
628 ("numpy", "void"),
661 ("numpy", "void"),
629 }
662 }
630
663
631
664
632 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
665 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
633 dict,
666 dict,
634 str, # type: ignore[arg-type]
667 str, # type: ignore[arg-type]
635 bytes, # type: ignore[arg-type]
668 bytes, # type: ignore[arg-type]
636 list,
669 list,
637 tuple,
670 tuple,
638 collections.defaultdict,
671 collections.defaultdict,
639 collections.deque,
672 collections.deque,
640 collections.OrderedDict,
673 collections.OrderedDict,
641 collections.ChainMap,
674 collections.ChainMap,
642 collections.UserDict,
675 collections.UserDict,
643 collections.UserList,
676 collections.UserList,
644 collections.UserString, # type: ignore[arg-type]
677 collections.UserString, # type: ignore[arg-type]
645 _DummyNamedTuple,
678 _DummyNamedTuple,
646 _IdentitySubscript,
679 _IdentitySubscript,
647 }
680 }
648
681
649
682
650 def _list_methods(cls, source=None):
683 def _list_methods(cls, source=None):
651 """For use on immutable objects or with methods returning a copy"""
684 """For use on immutable objects or with methods returning a copy"""
652 return [getattr(cls, k) for k in (source if source else dir(cls))]
685 return [getattr(cls, k) for k in (source if source else dir(cls))]
653
686
654
687
655 dict_non_mutating_methods = ("copy", "keys", "values", "items")
688 dict_non_mutating_methods = ("copy", "keys", "values", "items")
656 list_non_mutating_methods = ("copy", "index", "count")
689 list_non_mutating_methods = ("copy", "index", "count")
657 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
690 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
658
691
659
692
660 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
693 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
661
694
662 NUMERICS = {int, float, complex}
695 NUMERICS = {int, float, complex}
663
696
664 ALLOWED_CALLS = {
697 ALLOWED_CALLS = {
665 bytes,
698 bytes,
666 *_list_methods(bytes),
699 *_list_methods(bytes),
667 dict,
700 dict,
668 *_list_methods(dict, dict_non_mutating_methods),
701 *_list_methods(dict, dict_non_mutating_methods),
669 dict_keys.isdisjoint,
702 dict_keys.isdisjoint,
670 list,
703 list,
671 *_list_methods(list, list_non_mutating_methods),
704 *_list_methods(list, list_non_mutating_methods),
672 set,
705 set,
673 *_list_methods(set, set_non_mutating_methods),
706 *_list_methods(set, set_non_mutating_methods),
674 frozenset,
707 frozenset,
675 *_list_methods(frozenset),
708 *_list_methods(frozenset),
676 range,
709 range,
677 str,
710 str,
678 *_list_methods(str),
711 *_list_methods(str),
679 tuple,
712 tuple,
680 *_list_methods(tuple),
713 *_list_methods(tuple),
681 *NUMERICS,
714 *NUMERICS,
682 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
715 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
683 collections.deque,
716 collections.deque,
684 *_list_methods(collections.deque, list_non_mutating_methods),
717 *_list_methods(collections.deque, list_non_mutating_methods),
685 collections.defaultdict,
718 collections.defaultdict,
686 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
719 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
687 collections.OrderedDict,
720 collections.OrderedDict,
688 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
721 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
689 collections.UserDict,
722 collections.UserDict,
690 *_list_methods(collections.UserDict, dict_non_mutating_methods),
723 *_list_methods(collections.UserDict, dict_non_mutating_methods),
691 collections.UserList,
724 collections.UserList,
692 *_list_methods(collections.UserList, list_non_mutating_methods),
725 *_list_methods(collections.UserList, list_non_mutating_methods),
693 collections.UserString,
726 collections.UserString,
694 *_list_methods(collections.UserString, dir(str)),
727 *_list_methods(collections.UserString, dir(str)),
695 collections.Counter,
728 collections.Counter,
696 *_list_methods(collections.Counter, dict_non_mutating_methods),
729 *_list_methods(collections.Counter, dict_non_mutating_methods),
697 collections.Counter.elements,
730 collections.Counter.elements,
698 collections.Counter.most_common,
731 collections.Counter.most_common,
699 }
732 }
700
733
701 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
734 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
702 *BUILTIN_GETITEM,
735 *BUILTIN_GETITEM,
703 set,
736 set,
704 frozenset,
737 frozenset,
705 object,
738 object,
706 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
739 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
707 *NUMERICS,
740 *NUMERICS,
708 dict_keys,
741 dict_keys,
709 MethodDescriptorType,
742 MethodDescriptorType,
710 ModuleType,
743 ModuleType,
711 }
744 }
712
745
713
746
714 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
747 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
715
748
716 EVALUATION_POLICIES = {
749 EVALUATION_POLICIES = {
717 "minimal": EvaluationPolicy(
750 "minimal": EvaluationPolicy(
718 allow_builtins_access=True,
751 allow_builtins_access=True,
719 allow_locals_access=False,
752 allow_locals_access=False,
720 allow_globals_access=False,
753 allow_globals_access=False,
721 allow_item_access=False,
754 allow_item_access=False,
722 allow_attr_access=False,
755 allow_attr_access=False,
723 allowed_calls=set(),
756 allowed_calls=set(),
724 allow_any_calls=False,
757 allow_any_calls=False,
725 allow_all_operations=False,
758 allow_all_operations=False,
726 ),
759 ),
727 "limited": SelectivePolicy(
760 "limited": SelectivePolicy(
728 allowed_getitem=BUILTIN_GETITEM,
761 allowed_getitem=BUILTIN_GETITEM,
729 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
762 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
730 allowed_getattr=BUILTIN_GETATTR,
763 allowed_getattr=BUILTIN_GETATTR,
731 allowed_getattr_external={
764 allowed_getattr_external={
732 # pandas Series/Frame implements custom `__getattr__`
765 # pandas Series/Frame implements custom `__getattr__`
733 ("pandas", "DataFrame"),
766 ("pandas", "DataFrame"),
734 ("pandas", "Series"),
767 ("pandas", "Series"),
735 },
768 },
736 allowed_operations=BUILTIN_OPERATIONS,
769 allowed_operations=BUILTIN_OPERATIONS,
737 allow_builtins_access=True,
770 allow_builtins_access=True,
738 allow_locals_access=True,
771 allow_locals_access=True,
739 allow_globals_access=True,
772 allow_globals_access=True,
740 allowed_calls=ALLOWED_CALLS,
773 allowed_calls=ALLOWED_CALLS,
741 ),
774 ),
742 "unsafe": EvaluationPolicy(
775 "unsafe": EvaluationPolicy(
743 allow_builtins_access=True,
776 allow_builtins_access=True,
744 allow_locals_access=True,
777 allow_locals_access=True,
745 allow_globals_access=True,
778 allow_globals_access=True,
746 allow_attr_access=True,
779 allow_attr_access=True,
747 allow_item_access=True,
780 allow_item_access=True,
748 allow_any_calls=True,
781 allow_any_calls=True,
749 allow_all_operations=True,
782 allow_all_operations=True,
750 ),
783 ),
751 }
784 }
752
785
753
786
754 __all__ = [
787 __all__ = [
755 "guarded_eval",
788 "guarded_eval",
756 "eval_node",
789 "eval_node",
757 "GuardRejection",
790 "GuardRejection",
758 "EvaluationContext",
791 "EvaluationContext",
759 "_unbind_method",
792 "_unbind_method",
760 ]
793 ]
@@ -1,602 +1,631 b''
1 from contextlib import contextmanager
1 from contextlib import contextmanager
2 from typing import NamedTuple
2 from typing import NamedTuple
3 from functools import partial
3 from functools import partial
4 from IPython.core.guarded_eval import (
4 from IPython.core.guarded_eval import (
5 EvaluationContext,
5 EvaluationContext,
6 GuardRejection,
6 GuardRejection,
7 guarded_eval,
7 guarded_eval,
8 _unbind_method,
8 _unbind_method,
9 )
9 )
10 from IPython.testing import decorators as dec
10 from IPython.testing import decorators as dec
11 import pytest
11 import pytest
12
12
13
13
14 def create_context(evaluation: str, **kwargs):
14 def create_context(evaluation: str, **kwargs):
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16
16
17
17
18 forbidden = partial(create_context, "forbidden")
18 forbidden = partial(create_context, "forbidden")
19 minimal = partial(create_context, "minimal")
19 minimal = partial(create_context, "minimal")
20 limited = partial(create_context, "limited")
20 limited = partial(create_context, "limited")
21 unsafe = partial(create_context, "unsafe")
21 unsafe = partial(create_context, "unsafe")
22 dangerous = partial(create_context, "dangerous")
22 dangerous = partial(create_context, "dangerous")
23
23
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26
26
27
27
28 @contextmanager
28 @contextmanager
29 def module_not_installed(module: str):
29 def module_not_installed(module: str):
30 import sys
30 import sys
31
31
32 try:
32 try:
33 to_restore = sys.modules[module]
33 to_restore = sys.modules[module]
34 del sys.modules[module]
34 del sys.modules[module]
35 except KeyError:
35 except KeyError:
36 to_restore = None
36 to_restore = None
37 try:
37 try:
38 yield
38 yield
39 finally:
39 finally:
40 sys.modules[module] = to_restore
40 sys.modules[module] = to_restore
41
41
42
42
43 def test_external_not_installed():
43 def test_external_not_installed():
44 """
44 """
45 Because attribute check requires checking if object is not of allowed
45 Because attribute check requires checking if object is not of allowed
46 external type, this tests logic for absence of external module.
46 external type, this tests logic for absence of external module.
47 """
47 """
48
48
49 class Custom:
49 class Custom:
50 def __init__(self):
50 def __init__(self):
51 self.test = 1
51 self.test = 1
52
52
53 def __getattr__(self, key):
53 def __getattr__(self, key):
54 return key
54 return key
55
55
56 with module_not_installed("pandas"):
56 with module_not_installed("pandas"):
57 context = limited(x=Custom())
57 context = limited(x=Custom())
58 with pytest.raises(GuardRejection):
58 with pytest.raises(GuardRejection):
59 guarded_eval("x.test", context)
59 guarded_eval("x.test", context)
60
60
61
61
62 @dec.skip_without("pandas")
62 @dec.skip_without("pandas")
63 def test_external_changed_api(monkeypatch):
63 def test_external_changed_api(monkeypatch):
64 """Check that the execution rejects if external API changed paths"""
64 """Check that the execution rejects if external API changed paths"""
65 import pandas as pd
65 import pandas as pd
66
66
67 series = pd.Series([1], index=["a"])
67 series = pd.Series([1], index=["a"])
68
68
69 with monkeypatch.context() as m:
69 with monkeypatch.context() as m:
70 m.delattr(pd, "Series")
70 m.delattr(pd, "Series")
71 context = limited(data=series)
71 context = limited(data=series)
72 with pytest.raises(GuardRejection):
72 with pytest.raises(GuardRejection):
73 guarded_eval("data.iloc[0]", context)
73 guarded_eval("data.iloc[0]", context)
74
74
75
75
76 @dec.skip_without("pandas")
76 @dec.skip_without("pandas")
77 def test_pandas_series_iloc():
77 def test_pandas_series_iloc():
78 import pandas as pd
78 import pandas as pd
79
79
80 series = pd.Series([1], index=["a"])
80 series = pd.Series([1], index=["a"])
81 context = limited(data=series)
81 context = limited(data=series)
82 assert guarded_eval("data.iloc[0]", context) == 1
82 assert guarded_eval("data.iloc[0]", context) == 1
83
83
84
84
85 def test_rejects_custom_properties():
85 def test_rejects_custom_properties():
86 class BadProperty:
86 class BadProperty:
87 @property
87 @property
88 def iloc(self):
88 def iloc(self):
89 return [None]
89 return [None]
90
90
91 series = BadProperty()
91 series = BadProperty()
92 context = limited(data=series)
92 context = limited(data=series)
93
93
94 with pytest.raises(GuardRejection):
94 with pytest.raises(GuardRejection):
95 guarded_eval("data.iloc[0]", context)
95 guarded_eval("data.iloc[0]", context)
96
96
97
97
98 @dec.skip_without("pandas")
98 @dec.skip_without("pandas")
99 def test_accepts_non_overriden_properties():
99 def test_accepts_non_overriden_properties():
100 import pandas as pd
100 import pandas as pd
101
101
102 class GoodProperty(pd.Series):
102 class GoodProperty(pd.Series):
103 pass
103 pass
104
104
105 series = GoodProperty([1], index=["a"])
105 series = GoodProperty([1], index=["a"])
106 context = limited(data=series)
106 context = limited(data=series)
107
107
108 assert guarded_eval("data.iloc[0]", context) == 1
108 assert guarded_eval("data.iloc[0]", context) == 1
109
109
110
110
111 @dec.skip_without("pandas")
111 @dec.skip_without("pandas")
112 def test_pandas_series():
112 def test_pandas_series():
113 import pandas as pd
113 import pandas as pd
114
114
115 context = limited(data=pd.Series([1], index=["a"]))
115 context = limited(data=pd.Series([1], index=["a"]))
116 assert guarded_eval('data["a"]', context) == 1
116 assert guarded_eval('data["a"]', context) == 1
117 with pytest.raises(KeyError):
117 with pytest.raises(KeyError):
118 guarded_eval('data["c"]', context)
118 guarded_eval('data["c"]', context)
119
119
120
120
121 @dec.skip_without("pandas")
121 @dec.skip_without("pandas")
122 def test_pandas_bad_series():
122 def test_pandas_bad_series():
123 import pandas as pd
123 import pandas as pd
124
124
125 class BadItemSeries(pd.Series):
125 class BadItemSeries(pd.Series):
126 def __getitem__(self, key):
126 def __getitem__(self, key):
127 return "CUSTOM_ITEM"
127 return "CUSTOM_ITEM"
128
128
129 class BadAttrSeries(pd.Series):
129 class BadAttrSeries(pd.Series):
130 def __getattr__(self, key):
130 def __getattr__(self, key):
131 return "CUSTOM_ATTR"
131 return "CUSTOM_ATTR"
132
132
133 bad_series = BadItemSeries([1], index=["a"])
133 bad_series = BadItemSeries([1], index=["a"])
134 context = limited(data=bad_series)
134 context = limited(data=bad_series)
135
135
136 with pytest.raises(GuardRejection):
136 with pytest.raises(GuardRejection):
137 guarded_eval('data["a"]', context)
137 guarded_eval('data["a"]', context)
138 with pytest.raises(GuardRejection):
138 with pytest.raises(GuardRejection):
139 guarded_eval('data["c"]', context)
139 guarded_eval('data["c"]', context)
140
140
141 # note: here result is a bit unexpected because
141 # note: here result is a bit unexpected because
142 # pandas `__getattr__` calls `__getitem__`;
142 # pandas `__getattr__` calls `__getitem__`;
143 # FIXME - special case to handle it?
143 # FIXME - special case to handle it?
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145
145
146 context = unsafe(data=bad_series)
146 context = unsafe(data=bad_series)
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148
148
149 bad_attr_series = BadAttrSeries([1], index=["a"])
149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 context = limited(data=bad_attr_series)
150 context = limited(data=bad_attr_series)
151 assert guarded_eval('data["a"]', context) == 1
151 assert guarded_eval('data["a"]', context) == 1
152 with pytest.raises(GuardRejection):
152 with pytest.raises(GuardRejection):
153 guarded_eval("data.a", context)
153 guarded_eval("data.a", context)
154
154
155
155
156 @dec.skip_without("pandas")
156 @dec.skip_without("pandas")
157 def test_pandas_dataframe_loc():
157 def test_pandas_dataframe_loc():
158 import pandas as pd
158 import pandas as pd
159 from pandas.testing import assert_series_equal
159 from pandas.testing import assert_series_equal
160
160
161 data = pd.DataFrame([{"a": 1}])
161 data = pd.DataFrame([{"a": 1}])
162 context = limited(data=data)
162 context = limited(data=data)
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164
164
165
165
166 def test_named_tuple():
166 def test_named_tuple():
167 class GoodNamedTuple(NamedTuple):
167 class GoodNamedTuple(NamedTuple):
168 a: str
168 a: str
169 pass
169 pass
170
170
171 class BadNamedTuple(NamedTuple):
171 class BadNamedTuple(NamedTuple):
172 a: str
172 a: str
173
173
174 def __getitem__(self, key):
174 def __getitem__(self, key):
175 return None
175 return None
176
176
177 good = GoodNamedTuple(a="x")
177 good = GoodNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
179
179
180 context = limited(data=good)
180 context = limited(data=good)
181 assert guarded_eval("data[0]", context) == "x"
181 assert guarded_eval("data[0]", context) == "x"
182
182
183 context = limited(data=bad)
183 context = limited(data=bad)
184 with pytest.raises(GuardRejection):
184 with pytest.raises(GuardRejection):
185 guarded_eval("data[0]", context)
185 guarded_eval("data[0]", context)
186
186
187
187
188 def test_dict():
188 def test_dict():
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 assert guarded_eval('data["a"]', context) == 1
190 assert guarded_eval('data["a"]', context) == 1
191 assert guarded_eval('data["b"]', context) == {"x": 2}
191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 assert guarded_eval('data["b"]["x"]', context) == 2
192 assert guarded_eval('data["b"]["x"]', context) == 2
193 assert guarded_eval('data["x", "y"]', context) == 3
193 assert guarded_eval('data["x", "y"]', context) == 3
194
194
195 assert guarded_eval("data.keys", context)
195 assert guarded_eval("data.keys", context)
196
196
197
197
198 def test_set():
198 def test_set():
199 context = limited(data={"a", "b"})
199 context = limited(data={"a", "b"})
200 assert guarded_eval("data.difference", context)
200 assert guarded_eval("data.difference", context)
201
201
202
202
203 def test_list():
203 def test_list():
204 context = limited(data=[1, 2, 3])
204 context = limited(data=[1, 2, 3])
205 assert guarded_eval("data[1]", context) == 2
205 assert guarded_eval("data[1]", context) == 2
206 assert guarded_eval("data.copy", context)
206 assert guarded_eval("data.copy", context)
207
207
208
208
209 def test_dict_literal():
209 def test_dict_literal():
210 context = limited()
210 context = limited()
211 assert guarded_eval("{}", context) == {}
211 assert guarded_eval("{}", context) == {}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213
213
214
214
215 def test_list_literal():
215 def test_list_literal():
216 context = limited()
216 context = limited()
217 assert guarded_eval("[]", context) == []
217 assert guarded_eval("[]", context) == []
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219
219
220
220
221 def test_set_literal():
221 def test_set_literal():
222 context = limited()
222 context = limited()
223 assert guarded_eval("set()", context) == set()
223 assert guarded_eval("set()", context) == set()
224 assert guarded_eval('{"a"}', context) == {"a"}
224 assert guarded_eval('{"a"}', context) == {"a"}
225
225
226
226
227 def test_evaluates_if_expression():
227 def test_evaluates_if_expression():
228 context = limited()
228 context = limited()
229 assert guarded_eval("2 if True else 3", context) == 2
229 assert guarded_eval("2 if True else 3", context) == 2
230 assert guarded_eval("4 if False else 5", context) == 5
230 assert guarded_eval("4 if False else 5", context) == 5
231
231
232
232
233 def test_object():
233 def test_object():
234 obj = object()
234 obj = object()
235 context = limited(obj=obj)
235 context = limited(obj=obj)
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237
237
238
238
239 @pytest.mark.parametrize(
239 @pytest.mark.parametrize(
240 "code,expected",
240 "code,expected",
241 [
241 [
242 ["int.numerator", int.numerator],
242 ["int.numerator", int.numerator],
243 ["float.is_integer", float.is_integer],
243 ["float.is_integer", float.is_integer],
244 ["complex.real", complex.real],
244 ["complex.real", complex.real],
245 ],
245 ],
246 )
246 )
247 def test_number_attributes(code, expected):
247 def test_number_attributes(code, expected):
248 assert guarded_eval(code, limited()) == expected
248 assert guarded_eval(code, limited()) == expected
249
249
250
250
251 def test_method_descriptor():
251 def test_method_descriptor():
252 context = limited()
252 context = limited()
253 assert guarded_eval("list.copy.__name__", context) == "copy"
253 assert guarded_eval("list.copy.__name__", context) == "copy"
254
254
255
255
256 class HeapType:
256 class HeapType:
257 pass
257 pass
258
258
259
259
260 class CallCreatesHeapType:
260 class CallCreatesHeapType:
261 def __call__(self) -> HeapType:
261 def __call__(self) -> HeapType:
262 return HeapType()
262 return HeapType()
263
263
264
264
265 class CallCreatesBuiltin:
265 class CallCreatesBuiltin:
266 def __call__(self) -> frozenset:
266 def __call__(self) -> frozenset:
267 return frozenset()
267 return frozenset()
268
268
269
269
270 class HasStaticMethod:
271 @staticmethod
272 def static_method() -> HeapType:
273 return HeapType()
274
275
276 class InitReturnsFrozenset:
277 def __new__(self) -> frozenset: # type:ignore[misc]
278 return frozenset()
279
280
270 @pytest.mark.parametrize(
281 @pytest.mark.parametrize(
271 "data,good,bad,expected, equality",
282 "data,good,expected,equality",
272 [
283 [
273 [[1, 2, 3], "data.index(2)", "data.append(4)", 1, True],
284 [[1, 2, 3], "data.index(2)", 1, True],
274 [{"a": 1}, "data.keys().isdisjoint({})", "data.update()", True, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
275 [CallCreatesHeapType(), "data()", "data.__class__()", HeapType, False],
286 # test cases for `__call__`
276 [CallCreatesBuiltin(), "data()", "data.__class__()", frozenset, False],
287 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
289 # Test cases for `__init__`
290 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
277 ],
295 ],
278 )
296 )
279 def test_evaluates_calls(data, good, bad, expected, equality):
297 def test_evaluates_calls(data, good, expected, equality):
280 context = limited(data=data)
298 context = limited(data=data)
281 value = guarded_eval(good, context)
299 value = guarded_eval(good, context)
282 if equality:
300 if equality:
283 assert value == expected
301 assert value == expected
284 else:
302 else:
285 assert isinstance(value, expected)
303 assert isinstance(value, expected)
286
304
305
306 @pytest.mark.parametrize(
307 "data,bad",
308 [
309 [[1, 2, 3], "data.append(4)"],
310 [{"a": 1}, "data.update()"],
311 ],
312 )
313 def test_rejects_calls_with_side_effects(data, bad):
314 context = limited(data=data)
315
287 with pytest.raises(GuardRejection):
316 with pytest.raises(GuardRejection):
288 guarded_eval(bad, context)
317 guarded_eval(bad, context)
289
318
290
319
291 @pytest.mark.parametrize(
320 @pytest.mark.parametrize(
292 "code,expected",
321 "code,expected",
293 [
322 [
294 ["(1\n+\n1)", 2],
323 ["(1\n+\n1)", 2],
295 ["list(range(10))[-1:]", [9]],
324 ["list(range(10))[-1:]", [9]],
296 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
325 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
297 ],
326 ],
298 )
327 )
299 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
328 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
300 def test_evaluates_complex_cases(code, expected, context):
329 def test_evaluates_complex_cases(code, expected, context):
301 assert guarded_eval(code, context()) == expected
330 assert guarded_eval(code, context()) == expected
302
331
303
332
304 @pytest.mark.parametrize(
333 @pytest.mark.parametrize(
305 "code,expected",
334 "code,expected",
306 [
335 [
307 ["1", 1],
336 ["1", 1],
308 ["1.0", 1.0],
337 ["1.0", 1.0],
309 ["0xdeedbeef", 0xDEEDBEEF],
338 ["0xdeedbeef", 0xDEEDBEEF],
310 ["True", True],
339 ["True", True],
311 ["None", None],
340 ["None", None],
312 ["{}", {}],
341 ["{}", {}],
313 ["[]", []],
342 ["[]", []],
314 ],
343 ],
315 )
344 )
316 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
345 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
317 def test_evaluates_literals(code, expected, context):
346 def test_evaluates_literals(code, expected, context):
318 assert guarded_eval(code, context()) == expected
347 assert guarded_eval(code, context()) == expected
319
348
320
349
321 @pytest.mark.parametrize(
350 @pytest.mark.parametrize(
322 "code,expected",
351 "code,expected",
323 [
352 [
324 ["-5", -5],
353 ["-5", -5],
325 ["+5", +5],
354 ["+5", +5],
326 ["~5", -6],
355 ["~5", -6],
327 ],
356 ],
328 )
357 )
329 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
358 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
330 def test_evaluates_unary_operations(code, expected, context):
359 def test_evaluates_unary_operations(code, expected, context):
331 assert guarded_eval(code, context()) == expected
360 assert guarded_eval(code, context()) == expected
332
361
333
362
334 @pytest.mark.parametrize(
363 @pytest.mark.parametrize(
335 "code,expected",
364 "code,expected",
336 [
365 [
337 ["1 + 1", 2],
366 ["1 + 1", 2],
338 ["3 - 1", 2],
367 ["3 - 1", 2],
339 ["2 * 3", 6],
368 ["2 * 3", 6],
340 ["5 // 2", 2],
369 ["5 // 2", 2],
341 ["5 / 2", 2.5],
370 ["5 / 2", 2.5],
342 ["5**2", 25],
371 ["5**2", 25],
343 ["2 >> 1", 1],
372 ["2 >> 1", 1],
344 ["2 << 1", 4],
373 ["2 << 1", 4],
345 ["1 | 2", 3],
374 ["1 | 2", 3],
346 ["1 & 1", 1],
375 ["1 & 1", 1],
347 ["1 & 2", 0],
376 ["1 & 2", 0],
348 ],
377 ],
349 )
378 )
350 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
379 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
351 def test_evaluates_binary_operations(code, expected, context):
380 def test_evaluates_binary_operations(code, expected, context):
352 assert guarded_eval(code, context()) == expected
381 assert guarded_eval(code, context()) == expected
353
382
354
383
355 @pytest.mark.parametrize(
384 @pytest.mark.parametrize(
356 "code,expected",
385 "code,expected",
357 [
386 [
358 ["2 > 1", True],
387 ["2 > 1", True],
359 ["2 < 1", False],
388 ["2 < 1", False],
360 ["2 <= 1", False],
389 ["2 <= 1", False],
361 ["2 <= 2", True],
390 ["2 <= 2", True],
362 ["1 >= 2", False],
391 ["1 >= 2", False],
363 ["2 >= 2", True],
392 ["2 >= 2", True],
364 ["2 == 2", True],
393 ["2 == 2", True],
365 ["1 == 2", False],
394 ["1 == 2", False],
366 ["1 != 2", True],
395 ["1 != 2", True],
367 ["1 != 1", False],
396 ["1 != 1", False],
368 ["1 < 4 < 3", False],
397 ["1 < 4 < 3", False],
369 ["(1 < 4) < 3", True],
398 ["(1 < 4) < 3", True],
370 ["4 > 3 > 2 > 1", True],
399 ["4 > 3 > 2 > 1", True],
371 ["4 > 3 > 2 > 9", False],
400 ["4 > 3 > 2 > 9", False],
372 ["1 < 2 < 3 < 4", True],
401 ["1 < 2 < 3 < 4", True],
373 ["9 < 2 < 3 < 4", False],
402 ["9 < 2 < 3 < 4", False],
374 ["1 < 2 > 1 > 0 > -1 < 1", True],
403 ["1 < 2 > 1 > 0 > -1 < 1", True],
375 ["1 in [1] in [[1]]", True],
404 ["1 in [1] in [[1]]", True],
376 ["1 in [1] in [[2]]", False],
405 ["1 in [1] in [[2]]", False],
377 ["1 in [1]", True],
406 ["1 in [1]", True],
378 ["0 in [1]", False],
407 ["0 in [1]", False],
379 ["1 not in [1]", False],
408 ["1 not in [1]", False],
380 ["0 not in [1]", True],
409 ["0 not in [1]", True],
381 ["True is True", True],
410 ["True is True", True],
382 ["False is False", True],
411 ["False is False", True],
383 ["True is False", False],
412 ["True is False", False],
384 ["True is not True", False],
413 ["True is not True", False],
385 ["False is not True", True],
414 ["False is not True", True],
386 ],
415 ],
387 )
416 )
388 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
417 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
389 def test_evaluates_comparisons(code, expected, context):
418 def test_evaluates_comparisons(code, expected, context):
390 assert guarded_eval(code, context()) == expected
419 assert guarded_eval(code, context()) == expected
391
420
392
421
393 def test_guards_comparisons():
422 def test_guards_comparisons():
394 class GoodEq(int):
423 class GoodEq(int):
395 pass
424 pass
396
425
397 class BadEq(int):
426 class BadEq(int):
398 def __eq__(self, other):
427 def __eq__(self, other):
399 assert False
428 assert False
400
429
401 context = limited(bad=BadEq(1), good=GoodEq(1))
430 context = limited(bad=BadEq(1), good=GoodEq(1))
402
431
403 with pytest.raises(GuardRejection):
432 with pytest.raises(GuardRejection):
404 guarded_eval("bad == 1", context)
433 guarded_eval("bad == 1", context)
405
434
406 with pytest.raises(GuardRejection):
435 with pytest.raises(GuardRejection):
407 guarded_eval("bad != 1", context)
436 guarded_eval("bad != 1", context)
408
437
409 with pytest.raises(GuardRejection):
438 with pytest.raises(GuardRejection):
410 guarded_eval("1 == bad", context)
439 guarded_eval("1 == bad", context)
411
440
412 with pytest.raises(GuardRejection):
441 with pytest.raises(GuardRejection):
413 guarded_eval("1 != bad", context)
442 guarded_eval("1 != bad", context)
414
443
415 assert guarded_eval("good == 1", context) is True
444 assert guarded_eval("good == 1", context) is True
416 assert guarded_eval("good != 1", context) is False
445 assert guarded_eval("good != 1", context) is False
417 assert guarded_eval("1 == good", context) is True
446 assert guarded_eval("1 == good", context) is True
418 assert guarded_eval("1 != good", context) is False
447 assert guarded_eval("1 != good", context) is False
419
448
420
449
421 def test_guards_unary_operations():
450 def test_guards_unary_operations():
422 class GoodOp(int):
451 class GoodOp(int):
423 pass
452 pass
424
453
425 class BadOpInv(int):
454 class BadOpInv(int):
426 def __inv__(self, other):
455 def __inv__(self, other):
427 assert False
456 assert False
428
457
429 class BadOpInverse(int):
458 class BadOpInverse(int):
430 def __inv__(self, other):
459 def __inv__(self, other):
431 assert False
460 assert False
432
461
433 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
462 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
434
463
435 with pytest.raises(GuardRejection):
464 with pytest.raises(GuardRejection):
436 guarded_eval("~bad1", context)
465 guarded_eval("~bad1", context)
437
466
438 with pytest.raises(GuardRejection):
467 with pytest.raises(GuardRejection):
439 guarded_eval("~bad2", context)
468 guarded_eval("~bad2", context)
440
469
441
470
442 def test_guards_binary_operations():
471 def test_guards_binary_operations():
443 class GoodOp(int):
472 class GoodOp(int):
444 pass
473 pass
445
474
446 class BadOp(int):
475 class BadOp(int):
447 def __add__(self, other):
476 def __add__(self, other):
448 assert False
477 assert False
449
478
450 context = limited(good=GoodOp(1), bad=BadOp(1))
479 context = limited(good=GoodOp(1), bad=BadOp(1))
451
480
452 with pytest.raises(GuardRejection):
481 with pytest.raises(GuardRejection):
453 guarded_eval("1 + bad", context)
482 guarded_eval("1 + bad", context)
454
483
455 with pytest.raises(GuardRejection):
484 with pytest.raises(GuardRejection):
456 guarded_eval("bad + 1", context)
485 guarded_eval("bad + 1", context)
457
486
458 assert guarded_eval("good + 1", context) == 2
487 assert guarded_eval("good + 1", context) == 2
459 assert guarded_eval("1 + good", context) == 2
488 assert guarded_eval("1 + good", context) == 2
460
489
461
490
462 def test_guards_attributes():
491 def test_guards_attributes():
463 class GoodAttr(float):
492 class GoodAttr(float):
464 pass
493 pass
465
494
466 class BadAttr1(float):
495 class BadAttr1(float):
467 def __getattr__(self, key):
496 def __getattr__(self, key):
468 assert False
497 assert False
469
498
470 class BadAttr2(float):
499 class BadAttr2(float):
471 def __getattribute__(self, key):
500 def __getattribute__(self, key):
472 assert False
501 assert False
473
502
474 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
503 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
475
504
476 with pytest.raises(GuardRejection):
505 with pytest.raises(GuardRejection):
477 guarded_eval("bad1.as_integer_ratio", context)
506 guarded_eval("bad1.as_integer_ratio", context)
478
507
479 with pytest.raises(GuardRejection):
508 with pytest.raises(GuardRejection):
480 guarded_eval("bad2.as_integer_ratio", context)
509 guarded_eval("bad2.as_integer_ratio", context)
481
510
482 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
511 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
483
512
484
513
485 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
514 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
486 def test_access_builtins(context):
515 def test_access_builtins(context):
487 assert guarded_eval("round", context()) == round
516 assert guarded_eval("round", context()) == round
488
517
489
518
490 def test_access_builtins_fails():
519 def test_access_builtins_fails():
491 context = limited()
520 context = limited()
492 with pytest.raises(NameError):
521 with pytest.raises(NameError):
493 guarded_eval("this_is_not_builtin", context)
522 guarded_eval("this_is_not_builtin", context)
494
523
495
524
496 def test_rejects_forbidden():
525 def test_rejects_forbidden():
497 context = forbidden()
526 context = forbidden()
498 with pytest.raises(GuardRejection):
527 with pytest.raises(GuardRejection):
499 guarded_eval("1", context)
528 guarded_eval("1", context)
500
529
501
530
502 def test_guards_locals_and_globals():
531 def test_guards_locals_and_globals():
503 context = EvaluationContext(
532 context = EvaluationContext(
504 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
533 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
505 )
534 )
506
535
507 with pytest.raises(GuardRejection):
536 with pytest.raises(GuardRejection):
508 guarded_eval("local_a", context)
537 guarded_eval("local_a", context)
509
538
510 with pytest.raises(GuardRejection):
539 with pytest.raises(GuardRejection):
511 guarded_eval("global_b", context)
540 guarded_eval("global_b", context)
512
541
513
542
514 def test_access_locals_and_globals():
543 def test_access_locals_and_globals():
515 context = EvaluationContext(
544 context = EvaluationContext(
516 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
545 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
517 )
546 )
518 assert guarded_eval("local_a", context) == "a"
547 assert guarded_eval("local_a", context) == "a"
519 assert guarded_eval("global_b", context) == "b"
548 assert guarded_eval("global_b", context) == "b"
520
549
521
550
522 @pytest.mark.parametrize(
551 @pytest.mark.parametrize(
523 "code",
552 "code",
524 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
553 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
525 )
554 )
526 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
555 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
527 def test_rejects_side_effect_syntax(code, context):
556 def test_rejects_side_effect_syntax(code, context):
528 with pytest.raises(SyntaxError):
557 with pytest.raises(SyntaxError):
529 guarded_eval(code, context)
558 guarded_eval(code, context)
530
559
531
560
532 def test_subscript():
561 def test_subscript():
533 context = EvaluationContext(
562 context = EvaluationContext(
534 locals={}, globals={}, evaluation="limited", in_subscript=True
563 locals={}, globals={}, evaluation="limited", in_subscript=True
535 )
564 )
536 empty_slice = slice(None, None, None)
565 empty_slice = slice(None, None, None)
537 assert guarded_eval("", context) == tuple()
566 assert guarded_eval("", context) == tuple()
538 assert guarded_eval(":", context) == empty_slice
567 assert guarded_eval(":", context) == empty_slice
539 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
568 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
540 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
569 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
541
570
542
571
543 def test_unbind_method():
572 def test_unbind_method():
544 class X(list):
573 class X(list):
545 def index(self, k):
574 def index(self, k):
546 return "CUSTOM"
575 return "CUSTOM"
547
576
548 x = X()
577 x = X()
549 assert _unbind_method(x.index) is X.index
578 assert _unbind_method(x.index) is X.index
550 assert _unbind_method([].index) is list.index
579 assert _unbind_method([].index) is list.index
551 assert _unbind_method(list.index) is None
580 assert _unbind_method(list.index) is None
552
581
553
582
554 def test_assumption_instance_attr_do_not_matter():
583 def test_assumption_instance_attr_do_not_matter():
555 """This is semi-specified in Python documentation.
584 """This is semi-specified in Python documentation.
556
585
557 However, since the specification says 'not guaranteed
586 However, since the specification says 'not guaranteed
558 to work' rather than 'is forbidden to work', future
587 to work' rather than 'is forbidden to work', future
559 versions could invalidate this assumptions. This test
588 versions could invalidate this assumptions. This test
560 is meant to catch such a change if it ever comes true.
589 is meant to catch such a change if it ever comes true.
561 """
590 """
562
591
563 class T:
592 class T:
564 def __getitem__(self, k):
593 def __getitem__(self, k):
565 return "a"
594 return "a"
566
595
567 def __getattr__(self, k):
596 def __getattr__(self, k):
568 return "a"
597 return "a"
569
598
570 def f(self):
599 def f(self):
571 return "b"
600 return "b"
572
601
573 t = T()
602 t = T()
574 t.__getitem__ = f
603 t.__getitem__ = f
575 t.__getattr__ = f
604 t.__getattr__ = f
576 assert t[1] == "a"
605 assert t[1] == "a"
577 assert t[1] == "a"
606 assert t[1] == "a"
578
607
579
608
580 def test_assumption_named_tuples_share_getitem():
609 def test_assumption_named_tuples_share_getitem():
581 """Check assumption on named tuples sharing __getitem__"""
610 """Check assumption on named tuples sharing __getitem__"""
582 from typing import NamedTuple
611 from typing import NamedTuple
583
612
584 class A(NamedTuple):
613 class A(NamedTuple):
585 pass
614 pass
586
615
587 class B(NamedTuple):
616 class B(NamedTuple):
588 pass
617 pass
589
618
590 assert A.__getitem__ == B.__getitem__
619 assert A.__getitem__ == B.__getitem__
591
620
592
621
593 @dec.skip_without("numpy")
622 @dec.skip_without("numpy")
594 def test_module_access():
623 def test_module_access():
595 import numpy
624 import numpy
596
625
597 context = limited(numpy=numpy)
626 context = limited(numpy=numpy)
598 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
627 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
599
628
600 context = minimal(numpy=numpy)
629 context = minimal(numpy=numpy)
601 with pytest.raises(GuardRejection):
630 with pytest.raises(GuardRejection):
602 guarded_eval("np.linalg.norm", context)
631 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now