##// END OF EJS Templates
Support stringized return type annotations
krassowski -
Show More
@@ -1,793 +1,803 b''
1 from inspect import isclass, signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Any,
4 Callable,
3 Callable,
5 Dict,
4 Dict,
6 Set,
5 Set,
7 Sequence,
6 Sequence,
8 Tuple,
7 Tuple,
9 NamedTuple,
8 NamedTuple,
10 Type,
9 Type,
11 Literal,
10 Literal,
12 Union,
11 Union,
13 TYPE_CHECKING,
12 TYPE_CHECKING,
14 )
13 )
15 import ast
14 import ast
16 import builtins
15 import builtins
17 import collections
16 import collections
18 import operator
17 import operator
19 import sys
18 import sys
20 from functools import cached_property
19 from functools import cached_property
21 from dataclasses import dataclass, field
20 from dataclasses import dataclass, field
22 from types import MethodDescriptorType, ModuleType
21 from types import MethodDescriptorType, ModuleType
23
22
24 from IPython.utils.docs import GENERATING_DOCUMENTATION
23 from IPython.utils.docs import GENERATING_DOCUMENTATION
25 from IPython.utils.decorators import undoc
24 from IPython.utils.decorators import undoc
26
25
27
26
28 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
27 if TYPE_CHECKING or GENERATING_DOCUMENTATION:
29 from typing_extensions import Protocol
28 from typing_extensions import Protocol
30 else:
29 else:
31 # do not require on runtime
30 # do not require on runtime
32 Protocol = object # requires Python >=3.8
31 Protocol = object # requires Python >=3.8
33
32
34
33
@undoc
class HasGetItem(Protocol):
    """Structural type for objects that implement ``__getitem__``."""

    def __getitem__(self, key) -> None:
        ...
39
38
40
39
@undoc
class InstancesHaveGetItem(Protocol):
    """Structural type for callables whose result implements ``__getitem__``.

    Calling the object (with any arguments) yields a :class:`HasGetItem`.
    """

    def __call__(self, *args, **kwargs) -> HasGetItem:
        ...
45
44
46
45
@undoc
class HasGetAttr(Protocol):
    """Structural type for objects that explicitly implement ``__getattr__``."""

    def __getattr__(self, key) -> None:
        ...
51
50
52
51
@undoc
class DoesNotHaveGetAttr(Protocol):
    """Structural type with no members; counterpart of :class:`HasGetAttr`."""
56
55
57
56
# By default `__getattr__` is not explicitly implemented on most objects,
# so the alias accepts both kinds of object.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
60
59
61
60
def _unbind_method(func: Callable) -> Union[Callable, None]:
    """Get unbound method for given bound method.

    The lookup is done by name on the class of ``func.__self__``.

    Returns None if cannot get unbound method, or method is already unbound.
    In particular None is returned when:

    * ``func`` has no ``__self__`` (or it is None) or no ``__name__``,
    * the owner's instance ``__dict__`` contains an entry of the same name
      (the attribute was overridden on the instance), or
    * the owner's class has no attribute called ``func.__name__``.
    """
    owner = getattr(func, "__self__", None)
    name = getattr(func, "__name__", None)
    if owner is None or not name:
        return None

    instance_dict_overrides = getattr(owner, "__dict__", None)
    if instance_dict_overrides and name in instance_dict_overrides:
        # The instance shadows the class attribute: the class-level
        # callable is not what would actually be invoked.
        return None

    # `__name__` does not have to match the attribute name under which the
    # function is stored on the class, hence the explicit default.
    return getattr(type(owner), name, None)
81
80
82
81
@undoc
@dataclass
class EvaluationPolicy:
    """Definition of evaluation policy.

    Each ``allow_*`` flag enables one category of action; all default to
    ``False``. The ``can_*`` methods are the hooks queried for individual
    actions and may be overridden by subclasses.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_all_operations: bool = False
    allow_any_calls: bool = False
    # Callables which may be called even if `allow_any_calls` is False.
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item) -> bool:
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr) -> bool:
        """Return whether attribute ``attr`` of ``value`` may be accessed."""
        return self.allow_attr_access

    def can_operate(self, dunders: Tuple[str, ...], a, b=None) -> bool:
        """Return whether the operation given by ``dunders`` is allowed.

        The base policy ignores the operands and only consults
        ``allow_all_operations``.
        """
        return bool(self.allow_all_operations)

    def can_call(self, func) -> bool:
        """Return whether ``func`` may be called.

        A call is allowed when any calls are allowed, when ``func`` is in
        ``allowed_calls``, or when ``func`` is a bound method whose unbound
        counterpart (see ``_unbind_method``) is in ``allowed_calls``.
        """
        if self.allow_any_calls:
            return True

        if func in self.allowed_calls:
            return True

        owner_method = _unbind_method(func)
        return bool(owner_method and owner_method in self.allowed_calls)
118
117
119
118
def _get_external(module_name: str, access_path: Sequence[str]):
    """Get value from external module given a dotted access path.

    The module is looked up in ``sys.modules`` only; it is never imported.
    An empty ``access_path`` returns the module itself.

    Raises:
    * `KeyError` if module is not found in ``sys.modules``, and
    * `AttributeError` if access path does not match an exported object
    """
    member_type = sys.modules[module_name]
    for attr in access_path:
        member_type = getattr(member_type, attr)
    return member_type
131
130
132
131
def _has_original_dunder_external(
    value,
    module_name: str,
    access_path: Sequence[str],
    method_name: str,
) -> bool:
    """Check ``value`` against a type exported by an already imported module.

    Returns True if the type of ``value`` is exactly the external type, or if
    ``value`` is an instance of it and ``method_name`` resolves to the same
    object on both classes. Returns False otherwise, including when the
    module is not in ``sys.modules`` or the access path cannot be resolved.
    """
    if module_name not in sys.modules:
        # LBYL as it is faster
        return False
    try:
        member_type = _get_external(module_name, access_path)
        value_type = type(value)
        # Identity rather than `==`: equality between classes is dispatched
        # to the metaclass, which the inspected object controls.
        if value_type is member_type:
            return True
        if method_name == "__getattribute__":
            # we have to short-circuit here due to an unresolved issue in
            # `isinstance` implementation: https://bugs.python.org/issue32683
            return False
        if isinstance(value, member_type):
            method = getattr(value_type, method_name, None)
            member_method = getattr(member_type, method_name, None)
            if member_method == method:
                return True
    except (AttributeError, KeyError):
        return False
    return False
158
157
159
158
def _has_original_dunder(
    value, allowed_types, allowed_methods, allowed_external, method_name
) -> Union[bool, None]:
    """Check whether ``method_name`` of ``value`` is an allow-listed one.

    Returns:
    * True if the exact type of ``value`` is in ``allowed_types``, if the
      method found on its class is in ``allowed_methods``, or if any entry
      of ``allowed_external`` (tuples of module name followed by an access
      path) accepts it,
    * None if the class of ``value`` does not define ``method_name``,
    * False otherwise.
    """
    # note: Python ignores `__getattr__`/`__getitem__` on instances,
    # we only need to check at class level
    value_type = type(value)

    # strict type check passes → no need to check method
    if value_type in allowed_types:
        return True

    method = getattr(value_type, method_name, None)

    if method is None:
        # callers distinguish "not defined" (None) from "modified" (False)
        return None

    if method in allowed_methods:
        return True

    for module_name, *access_path in allowed_external:
        if _has_original_dunder_external(value, module_name, access_path, method_name):
            return True

    return False
184
183
185
184
@undoc
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy permitting actions only on allow-listed types.

    Item access, attribute access and operators are allowed when the
    relevant dunder method on the class of the value is the one defined by
    an allow-listed type. Types are given directly (``allowed_*``) or as
    ``(module_name, *access_path)`` tuples (``allowed_*_external``), which
    are resolved from ``sys.modules`` without importing anything.
    """

    allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    allowed_operations: Set = field(default_factory=set)
    allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)

    # Per-dunder cache of methods collected from `allowed_operations`.
    _operation_methods_cache: Dict[str, Set[Callable]] = field(
        default_factory=dict, init=False
    )

    def can_get_attr(self, value, attr) -> bool:
        """Return whether attribute ``attr`` of ``value`` may be accessed.

        Access is allowed only if ``__getattribute__`` (and ``__getattr__``,
        when defined) on the class of ``value`` are allow-listed ones, and
        ``attr`` is not a property other than one found on an allow-listed
        type.
        """
        has_original_attribute = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattribute__",
        )
        has_original_attr = _has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name="__getattr__",
        )

        if has_original_attr is None:
            # Many objects do not have `__getattr__`, this is fine.
            accept = has_original_attribute
        else:
            # Accept objects without modifications to `__getattr__` and `__getattribute__`
            accept = has_original_attr and has_original_attribute

        if not accept:
            return False

        # We still need to check for overridden properties.
        value_class = type(value)
        if not hasattr(value_class, attr):
            return True

        class_attr_val = getattr(value_class, attr)
        if not isinstance(class_attr_val, property):
            return True

        # Properties in allowed types are ok (although we do not include any
        # properties in our default allow list currently).
        if value_class in self.allowed_getattr:
            return True  # pragma: no cover

        # Properties in subclasses of allowed types may be ok if not changed.
        # Every external entry has to be consulted: the set is unordered, so
        # deciding on the first element alone would be arbitrary.
        for module_name, *access_path in self.allowed_getattr_external:
            try:
                external_class = _get_external(module_name, access_path)
                external_class_attr_val = getattr(external_class, attr)
            except (KeyError, AttributeError):
                continue  # pragma: no cover
            # An unchanged property is the very same property object.
            if class_attr_val is external_class_attr_val:
                return True

        return False

    def can_get_item(self, value, item) -> bool:
        """Allow accessing `__getitem__` of allow-listed instances unless it was modified."""
        return bool(
            _has_original_dunder(
                value,
                allowed_types=self.allowed_getitem,
                allowed_methods=self._getitem_methods,
                allowed_external=self.allowed_getitem_external,
                method_name="__getitem__",
            )
        )

    def can_operate(self, dunders: Tuple[str, ...], a, b=None) -> bool:
        """Return whether every dunder in ``dunders`` is allow-listed for ``a`` (and ``b``, if given)."""
        objects = [a]
        if b is not None:
            objects.append(b)
        return all(
            _has_original_dunder(
                obj,
                allowed_types=self.allowed_operations,
                allowed_methods=self._operator_dunder_methods(dunder),
                allowed_external=self.allowed_operations_external,
                method_name=dunder,
            )
            for dunder in dunders
            for obj in objects
        )

    def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
        """Return (and cache) implementations of ``dunder`` on ``allowed_operations``."""
        if dunder not in self._operation_methods_cache:
            self._operation_methods_cache[dunder] = self._safe_get_methods(
                self.allowed_operations, dunder
            )
        return self._operation_methods_cache[dunder]

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getitem, "__getitem__")

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattr__")

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        return self._safe_get_methods(self.allowed_getattr, "__getattribute__")

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect truthy attributes called ``name`` from ``classes``, skipping classes without one."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
310
309
311
310
class _DummyNamedTuple(NamedTuple):
    """Used internally to retrieve methods of named tuple instance."""
314
313
315
314
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which code is evaluated."""

    #: Local namespace
    locals: dict
    #: Global namespace
    globals: dict
    #: Evaluation policy identifier
    evaluation: Literal[
        "forbidden", "minimal", "limited", "unsafe", "dangerous"
    ] = "forbidden"
    #: Whether the evaluation of code takes place inside of a subscript.
    #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
    in_subscript: bool = False
328
327
329
328
class _IdentitySubscript:
    """Returns the key itself when item is requested via subscript."""

    def __getitem__(self, key):
        return key
335
334
336
335
# Shared instance which echoes back whatever key it is subscripted with.
IDENTITY_SUBSCRIPT = _IdentitySubscript()
# Name under which `IDENTITY_SUBSCRIPT` is injected into locals by `guarded_eval`.
SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
# Empty signature. NOTE(review): presumably a placeholder for callables whose
# signature cannot be determined; usage is outside this part of the file.
UNKNOWN_SIGNATURE = Signature()
# Unique sentinel object. NOTE(review): presumably marks values which were
# not evaluated; usage is outside this part of the file.
NOT_EVALUATED = object()
341
340
342
341
class GuardRejection(Exception):
    """Exception raised when guard rejects evaluation attempt."""
347
346
348
347
def guarded_eval(code: str, context: EvaluationContext):
    """Evaluate provided code in the evaluation context.

    If evaluation policy given by context is set to ``forbidden``
    no evaluation will be performed; if it is set to ``dangerous``
    standard :func:`eval` will be used; finally, for any other,
    policy :func:`eval_node` will be called on parsed AST.

    When ``context.in_subscript`` is set, ``code`` is treated as the inside
    of a subscript (so slice syntax is accepted); empty ``code`` then
    evaluates to an empty tuple.

    Raises:
    * `GuardRejection` if the evaluation policy is ``forbidden``.
    """
    locals_ = context.locals

    if context.evaluation == "forbidden":
        raise GuardRejection("Forbidden mode")

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # syntactic sugar for slices (`:`) is only available in subscripts
        # so we need to trick the ast parser into thinking that we have
        # a subscript, but we need to be able to later recognise that we did
        # it so we can ignore the actual __getitem__ operation
        if not code:
            return tuple()
        # work on a copy so that the marker does not leak into the
        # namespace supplied by the caller
        locals_ = locals_.copy()
        locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = SUBSCRIPT_MARKER + "[" + code + "]"
        context = context._replace(locals=locals_)

    if context.evaluation == "dangerous":
        # NOTE: unrestricted `eval`; only reached when the caller explicitly
        # selected the ``dangerous`` policy.
        return eval(code, context.globals, context.locals)

    expression = ast.parse(code, mode="eval")

    return eval_node(expression, context)
383
382
384
383
# Each table maps an AST operator class to the dunder methods implementing
# it; the first entry is the one which gets invoked.
BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str, ...]] = {
    ast.Add: ("__add__",),
    ast.Sub: ("__sub__",),
    ast.Mult: ("__mul__",),
    ast.Div: ("__truediv__",),
    ast.FloorDiv: ("__floordiv__",),
    ast.Mod: ("__mod__",),
    ast.Pow: ("__pow__",),
    ast.LShift: ("__lshift__",),
    ast.RShift: ("__rshift__",),
    ast.BitOr: ("__or__",),
    ast.BitXor: ("__xor__",),
    ast.BitAnd: ("__and__",),
    ast.MatMult: ("__matmul__",),
}

COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
    ast.Eq: ("__eq__",),
    ast.NotEq: ("__ne__", "__eq__"),
    ast.Lt: ("__lt__", "__gt__"),
    ast.LtE: ("__le__", "__ge__"),
    ast.Gt: ("__gt__", "__lt__"),
    ast.GtE: ("__ge__", "__le__"),
    ast.In: ("__contains__",),
    # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
}

UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
    ast.USub: ("__neg__",),
    ast.UAdd: ("__pos__",),
    # we have to check both __inv__ and __invert__!
    ast.Invert: ("__invert__", "__inv__"),
    # NOTE(review): `__not__` exists as `operator.__not__`, but objects do
    # not normally define it as a method, while `eval_node` calls
    # `getattr(value, dunders[0])()` for unary operators; confirm that
    # `not x` can actually be evaluated.
    ast.Not: ("__not__",),
}
419
418
420
419
class Duck:
    """A dummy class used to create objects of other classes without calling their ``__init__``"""
423
422
424
423
def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
    """Return the dunder names registered for the class of ``node_op``.

    ``dunders`` maps AST operator classes to tuples of dunder names; the
    entry whose key ``node_op`` is an instance of is returned, or None when
    there is no such entry. The operator tables in this module have
    disjoint keys, so at most one entry can match.
    """
    for op, candidate_dunder in dunders.items():
        if isinstance(node_op, op):
            return candidate_dunder
    return None
431
430
432
431
433 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
432 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
434 """Evaluate AST node in provided context.
433 """Evaluate AST node in provided context.
435
434
436 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
435 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
437
436
438 Does not evaluate actions that always have side effects:
437 Does not evaluate actions that always have side effects:
439
438
440 - class definitions (``class sth: ...``)
439 - class definitions (``class sth: ...``)
441 - function definitions (``def sth: ...``)
440 - function definitions (``def sth: ...``)
442 - variable assignments (``x = 1``)
441 - variable assignments (``x = 1``)
443 - augmented assignments (``x += 1``)
442 - augmented assignments (``x += 1``)
444 - deletions (``del x``)
443 - deletions (``del x``)
445
444
446 Does not evaluate operations which do not return values:
445 Does not evaluate operations which do not return values:
447
446
448 - assertions (``assert x``)
447 - assertions (``assert x``)
449 - pass (``pass``)
448 - pass (``pass``)
450 - imports (``import x``)
449 - imports (``import x``)
451 - control flow:
450 - control flow:
452
451
453 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
452 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
454 - loops (``for`` and ``while``)
453 - loops (``for`` and ``while``)
455 - exception handling
454 - exception handling
456
455
457 The purpose of this function is to guard against unwanted side-effects;
456 The purpose of this function is to guard against unwanted side-effects;
458 it does not give guarantees on protection from malicious code execution.
457 it does not give guarantees on protection from malicious code execution.
459 """
458 """
460 policy = EVALUATION_POLICIES[context.evaluation]
459 policy = EVALUATION_POLICIES[context.evaluation]
461 if node is None:
460 if node is None:
462 return None
461 return None
463 if isinstance(node, ast.Expression):
462 if isinstance(node, ast.Expression):
464 return eval_node(node.body, context)
463 return eval_node(node.body, context)
465 if isinstance(node, ast.BinOp):
464 if isinstance(node, ast.BinOp):
466 left = eval_node(node.left, context)
465 left = eval_node(node.left, context)
467 right = eval_node(node.right, context)
466 right = eval_node(node.right, context)
468 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
467 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
469 if dunders:
468 if dunders:
470 if policy.can_operate(dunders, left, right):
469 if policy.can_operate(dunders, left, right):
471 return getattr(left, dunders[0])(right)
470 return getattr(left, dunders[0])(right)
472 else:
471 else:
473 raise GuardRejection(
472 raise GuardRejection(
474 f"Operation (`{dunders}`) for",
473 f"Operation (`{dunders}`) for",
475 type(left),
474 type(left),
476 f"not allowed in {context.evaluation} mode",
475 f"not allowed in {context.evaluation} mode",
477 )
476 )
478 if isinstance(node, ast.Compare):
477 if isinstance(node, ast.Compare):
479 left = eval_node(node.left, context)
478 left = eval_node(node.left, context)
480 all_true = True
479 all_true = True
481 negate = False
480 negate = False
482 for op, right in zip(node.ops, node.comparators):
481 for op, right in zip(node.ops, node.comparators):
483 right = eval_node(right, context)
482 right = eval_node(right, context)
484 dunder = None
483 dunder = None
485 dunders = _find_dunder(op, COMP_OP_DUNDERS)
484 dunders = _find_dunder(op, COMP_OP_DUNDERS)
486 if not dunders:
485 if not dunders:
487 if isinstance(op, ast.NotIn):
486 if isinstance(op, ast.NotIn):
488 dunders = COMP_OP_DUNDERS[ast.In]
487 dunders = COMP_OP_DUNDERS[ast.In]
489 negate = True
488 negate = True
490 if isinstance(op, ast.Is):
489 if isinstance(op, ast.Is):
491 dunder = "is_"
490 dunder = "is_"
492 if isinstance(op, ast.IsNot):
491 if isinstance(op, ast.IsNot):
493 dunder = "is_"
492 dunder = "is_"
494 negate = True
493 negate = True
495 if not dunder and dunders:
494 if not dunder and dunders:
496 dunder = dunders[0]
495 dunder = dunders[0]
497 if dunder:
496 if dunder:
498 a, b = (right, left) if dunder == "__contains__" else (left, right)
497 a, b = (right, left) if dunder == "__contains__" else (left, right)
499 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
498 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
500 result = getattr(operator, dunder)(a, b)
499 result = getattr(operator, dunder)(a, b)
501 if negate:
500 if negate:
502 result = not result
501 result = not result
503 if not result:
502 if not result:
504 all_true = False
503 all_true = False
505 left = right
504 left = right
506 else:
505 else:
507 raise GuardRejection(
506 raise GuardRejection(
508 f"Comparison (`{dunder}`) for",
507 f"Comparison (`{dunder}`) for",
509 type(left),
508 type(left),
510 f"not allowed in {context.evaluation} mode",
509 f"not allowed in {context.evaluation} mode",
511 )
510 )
512 else:
511 else:
513 raise ValueError(
512 raise ValueError(
514 f"Comparison `{dunder}` not supported"
513 f"Comparison `{dunder}` not supported"
515 ) # pragma: no cover
514 ) # pragma: no cover
516 return all_true
515 return all_true
517 if isinstance(node, ast.Constant):
516 if isinstance(node, ast.Constant):
518 return node.value
517 return node.value
519 if isinstance(node, ast.Tuple):
518 if isinstance(node, ast.Tuple):
520 return tuple(eval_node(e, context) for e in node.elts)
519 return tuple(eval_node(e, context) for e in node.elts)
521 if isinstance(node, ast.List):
520 if isinstance(node, ast.List):
522 return [eval_node(e, context) for e in node.elts]
521 return [eval_node(e, context) for e in node.elts]
523 if isinstance(node, ast.Set):
522 if isinstance(node, ast.Set):
524 return {eval_node(e, context) for e in node.elts}
523 return {eval_node(e, context) for e in node.elts}
525 if isinstance(node, ast.Dict):
524 if isinstance(node, ast.Dict):
526 return dict(
525 return dict(
527 zip(
526 zip(
528 [eval_node(k, context) for k in node.keys],
527 [eval_node(k, context) for k in node.keys],
529 [eval_node(v, context) for v in node.values],
528 [eval_node(v, context) for v in node.values],
530 )
529 )
531 )
530 )
532 if isinstance(node, ast.Slice):
531 if isinstance(node, ast.Slice):
533 return slice(
532 return slice(
534 eval_node(node.lower, context),
533 eval_node(node.lower, context),
535 eval_node(node.upper, context),
534 eval_node(node.upper, context),
536 eval_node(node.step, context),
535 eval_node(node.step, context),
537 )
536 )
538 if isinstance(node, ast.UnaryOp):
537 if isinstance(node, ast.UnaryOp):
539 value = eval_node(node.operand, context)
538 value = eval_node(node.operand, context)
540 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
539 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
541 if dunders:
540 if dunders:
542 if policy.can_operate(dunders, value):
541 if policy.can_operate(dunders, value):
543 return getattr(value, dunders[0])()
542 return getattr(value, dunders[0])()
544 else:
543 else:
545 raise GuardRejection(
544 raise GuardRejection(
546 f"Operation (`{dunders}`) for",
545 f"Operation (`{dunders}`) for",
547 type(value),
546 type(value),
548 f"not allowed in {context.evaluation} mode",
547 f"not allowed in {context.evaluation} mode",
549 )
548 )
550 if isinstance(node, ast.Subscript):
549 if isinstance(node, ast.Subscript):
551 value = eval_node(node.value, context)
550 value = eval_node(node.value, context)
552 slice_ = eval_node(node.slice, context)
551 slice_ = eval_node(node.slice, context)
553 if policy.can_get_item(value, slice_):
552 if policy.can_get_item(value, slice_):
554 return value[slice_]
553 return value[slice_]
555 raise GuardRejection(
554 raise GuardRejection(
556 "Subscript access (`__getitem__`) for",
555 "Subscript access (`__getitem__`) for",
557 type(value), # not joined to avoid calling `repr`
556 type(value), # not joined to avoid calling `repr`
558 f" not allowed in {context.evaluation} mode",
557 f" not allowed in {context.evaluation} mode",
559 )
558 )
560 if isinstance(node, ast.Name):
559 if isinstance(node, ast.Name):
561 if policy.allow_locals_access and node.id in context.locals:
560 return _eval_node_name(node.id, policy, context)
562 return context.locals[node.id]
563 if policy.allow_globals_access and node.id in context.globals:
564 return context.globals[node.id]
565 if policy.allow_builtins_access and hasattr(builtins, node.id):
566 # note: do not use __builtins__, it is implementation detail of cPython
567 return getattr(builtins, node.id)
568 if not policy.allow_globals_access and not policy.allow_locals_access:
569 raise GuardRejection(
570 f"Namespace access not allowed in {context.evaluation} mode"
571 )
572 else:
573 raise NameError(f"{node.id} not found in locals, globals, nor builtins")
574 if isinstance(node, ast.Attribute):
561 if isinstance(node, ast.Attribute):
575 value = eval_node(node.value, context)
562 value = eval_node(node.value, context)
576 if policy.can_get_attr(value, node.attr):
563 if policy.can_get_attr(value, node.attr):
577 return getattr(value, node.attr)
564 return getattr(value, node.attr)
578 raise GuardRejection(
565 raise GuardRejection(
579 "Attribute access (`__getattr__`) for",
566 "Attribute access (`__getattr__`) for",
580 type(value), # not joined to avoid calling `repr`
567 type(value), # not joined to avoid calling `repr`
581 f"not allowed in {context.evaluation} mode",
568 f"not allowed in {context.evaluation} mode",
582 )
569 )
583 if isinstance(node, ast.IfExp):
570 if isinstance(node, ast.IfExp):
584 test = eval_node(node.test, context)
571 test = eval_node(node.test, context)
585 if test:
572 if test:
586 return eval_node(node.body, context)
573 return eval_node(node.body, context)
587 else:
574 else:
588 return eval_node(node.orelse, context)
575 return eval_node(node.orelse, context)
589 if isinstance(node, ast.Call):
576 if isinstance(node, ast.Call):
590 func = eval_node(node.func, context)
577 func = eval_node(node.func, context)
591 if policy.can_call(func) and not node.keywords:
578 if policy.can_call(func) and not node.keywords:
592 args = [eval_node(arg, context) for arg in node.args]
579 args = [eval_node(arg, context) for arg in node.args]
593 return func(*args)
580 return func(*args)
594 if isclass(func):
581 if isclass(func):
595 # this code path gets entered when calling class e.g. `MyClass()`
582 # this code path gets entered when calling class e.g. `MyClass()`
596 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
583 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
597 # Should return `MyClass` if `__new__` is not overridden,
584 # Should return `MyClass` if `__new__` is not overridden,
598 # otherwise whatever `__new__` return type is.
585 # otherwise whatever `__new__` return type is.
599 overridden_return_type = _eval_return_type(
586 overridden_return_type = _eval_return_type(
600 func.__new__, policy, node, context
587 func.__new__, policy, node, context
601 )
588 )
602 if overridden_return_type is not NOT_EVALUATED:
589 if overridden_return_type is not NOT_EVALUATED:
603 return overridden_return_type
590 return overridden_return_type
604 return _create_duck_for_type(func)
591 return _create_duck_for_type(func)
605 else:
592 else:
606 return_type = _eval_return_type(func, policy, node, context)
593 return_type = _eval_return_type(func, policy, node, context)
607 if return_type is not NOT_EVALUATED:
594 if return_type is not NOT_EVALUATED:
608 return return_type
595 return return_type
609 raise GuardRejection(
596 raise GuardRejection(
610 "Call for",
597 "Call for",
611 func, # not joined to avoid calling `repr`
598 func, # not joined to avoid calling `repr`
612 f"not allowed in {context.evaluation} mode",
599 f"not allowed in {context.evaluation} mode",
613 )
600 )
614 raise ValueError("Unhandled node", ast.dump(node))
601 raise ValueError("Unhandled node", ast.dump(node))
615
602
616
603
def _eval_return_type(
    func: Callable, policy: EvaluationPolicy, node: ast.Call, context: EvaluationContext
):
    """Evaluate return type of a given callable function.

    The return annotation is read from the signature of ``func``. If the
    annotation is still a string (stringized and not resolved by the
    ``signature`` call), it is looked up by name with `_eval_node_name`,
    i.e. in locals, globals and builtins as far as ``policy`` permits.

    If ``policy`` allows calling the annotated type and the call in ``node``
    has no keyword arguments, the type is called with the evaluated
    positional arguments of ``node``; otherwise an imitation (a duck) of
    the annotated type is created.

    Returns the built-in type, a duck or NOT_EVALUATED sentinel.
    NOT_EVALUATED is also returned when there is no return annotation or
    when a stringized annotation cannot be found by name. A rejection
    raised by `_eval_node_name` because the policy forbids namespace
    access is not caught here.
    """
    try:
        sig = signature(func)
    except ValueError:
        sig = UNKNOWN_SIGNATURE
    return_type = sig.return_annotation
    if return_type is Signature.empty:
        return NOT_EVALUATED
    if isinstance(return_type, str):
        # the annotation was stringized and `signature` did not resolve it
        try:
            return_type = _eval_node_name(return_type, policy, context)
        except NameError:
            # The annotation is not a plain name available in the evaluation
            # namespace; treat the callable as if the return type was unknown
            # instead of reporting a name the evaluated code never mentioned.
            return NOT_EVALUATED
    # if allow-listed builtin is on type annotation, instantiate it
    if policy.can_call(return_type) and not node.keywords:
        args = [eval_node(arg, context) for arg in node.args]
        return return_type(*args)
    # if custom class is in type annotation, mock it
    return _create_duck_for_type(return_type)
638
632
639
633
634 def _eval_node_name(node_id: str, policy: EvaluationPolicy, context: EvaluationContext):
635 if policy.allow_locals_access and node_id in context.locals:
636 return context.locals[node_id]
637 if policy.allow_globals_access and node_id in context.globals:
638 return context.globals[node_id]
639 if policy.allow_builtins_access and hasattr(builtins, node_id):
640 # note: do not use __builtins__, it is implementation detail of cPython
641 return getattr(builtins, node_id)
642 if not policy.allow_globals_access and not policy.allow_locals_access:
643 raise GuardRejection(
644 f"Namespace access not allowed in {context.evaluation} mode"
645 )
646 else:
647 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
648
649
def _create_duck_for_type(duck_type):
    """Build a stand-in object (a duck) that reports ``duck_type`` as its class.

    A fresh `Duck` is created and its ``__class__`` is reassigned to
    ``duck_type``.

    Returns the duck, or the NOT_EVALUATED sentinel if the reassignment
    is refused with a `TypeError`.
    """
    imitation = Duck()
    try:
        imitation.__class__ = duck_type
    except TypeError:
        # reassigning `__class__` only works for heap types, not builtins
        return NOT_EVALUATED
    return imitation
653
663
654
664
655 SUPPORTED_EXTERNAL_GETITEM = {
665 SUPPORTED_EXTERNAL_GETITEM = {
656 ("pandas", "core", "indexing", "_iLocIndexer"),
666 ("pandas", "core", "indexing", "_iLocIndexer"),
657 ("pandas", "core", "indexing", "_LocIndexer"),
667 ("pandas", "core", "indexing", "_LocIndexer"),
658 ("pandas", "DataFrame"),
668 ("pandas", "DataFrame"),
659 ("pandas", "Series"),
669 ("pandas", "Series"),
660 ("numpy", "ndarray"),
670 ("numpy", "ndarray"),
661 ("numpy", "void"),
671 ("numpy", "void"),
662 }
672 }
663
673
664
674
665 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
675 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
666 dict,
676 dict,
667 str, # type: ignore[arg-type]
677 str, # type: ignore[arg-type]
668 bytes, # type: ignore[arg-type]
678 bytes, # type: ignore[arg-type]
669 list,
679 list,
670 tuple,
680 tuple,
671 collections.defaultdict,
681 collections.defaultdict,
672 collections.deque,
682 collections.deque,
673 collections.OrderedDict,
683 collections.OrderedDict,
674 collections.ChainMap,
684 collections.ChainMap,
675 collections.UserDict,
685 collections.UserDict,
676 collections.UserList,
686 collections.UserList,
677 collections.UserString, # type: ignore[arg-type]
687 collections.UserString, # type: ignore[arg-type]
678 _DummyNamedTuple,
688 _DummyNamedTuple,
679 _IdentitySubscript,
689 _IdentitySubscript,
680 }
690 }
681
691
682
692
683 def _list_methods(cls, source=None):
693 def _list_methods(cls, source=None):
684 """For use on immutable objects or with methods returning a copy"""
694 """For use on immutable objects or with methods returning a copy"""
685 return [getattr(cls, k) for k in (source if source else dir(cls))]
695 return [getattr(cls, k) for k in (source if source else dir(cls))]
686
696
687
697
688 dict_non_mutating_methods = ("copy", "keys", "values", "items")
698 dict_non_mutating_methods = ("copy", "keys", "values", "items")
689 list_non_mutating_methods = ("copy", "index", "count")
699 list_non_mutating_methods = ("copy", "index", "count")
690 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
700 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
691
701
692
702
693 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
703 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
694
704
695 NUMERICS = {int, float, complex}
705 NUMERICS = {int, float, complex}
696
706
697 ALLOWED_CALLS = {
707 ALLOWED_CALLS = {
698 bytes,
708 bytes,
699 *_list_methods(bytes),
709 *_list_methods(bytes),
700 dict,
710 dict,
701 *_list_methods(dict, dict_non_mutating_methods),
711 *_list_methods(dict, dict_non_mutating_methods),
702 dict_keys.isdisjoint,
712 dict_keys.isdisjoint,
703 list,
713 list,
704 *_list_methods(list, list_non_mutating_methods),
714 *_list_methods(list, list_non_mutating_methods),
705 set,
715 set,
706 *_list_methods(set, set_non_mutating_methods),
716 *_list_methods(set, set_non_mutating_methods),
707 frozenset,
717 frozenset,
708 *_list_methods(frozenset),
718 *_list_methods(frozenset),
709 range,
719 range,
710 str,
720 str,
711 *_list_methods(str),
721 *_list_methods(str),
712 tuple,
722 tuple,
713 *_list_methods(tuple),
723 *_list_methods(tuple),
714 *NUMERICS,
724 *NUMERICS,
715 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
725 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
716 collections.deque,
726 collections.deque,
717 *_list_methods(collections.deque, list_non_mutating_methods),
727 *_list_methods(collections.deque, list_non_mutating_methods),
718 collections.defaultdict,
728 collections.defaultdict,
719 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
729 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
720 collections.OrderedDict,
730 collections.OrderedDict,
721 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
731 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
722 collections.UserDict,
732 collections.UserDict,
723 *_list_methods(collections.UserDict, dict_non_mutating_methods),
733 *_list_methods(collections.UserDict, dict_non_mutating_methods),
724 collections.UserList,
734 collections.UserList,
725 *_list_methods(collections.UserList, list_non_mutating_methods),
735 *_list_methods(collections.UserList, list_non_mutating_methods),
726 collections.UserString,
736 collections.UserString,
727 *_list_methods(collections.UserString, dir(str)),
737 *_list_methods(collections.UserString, dir(str)),
728 collections.Counter,
738 collections.Counter,
729 *_list_methods(collections.Counter, dict_non_mutating_methods),
739 *_list_methods(collections.Counter, dict_non_mutating_methods),
730 collections.Counter.elements,
740 collections.Counter.elements,
731 collections.Counter.most_common,
741 collections.Counter.most_common,
732 }
742 }
733
743
734 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
744 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
735 *BUILTIN_GETITEM,
745 *BUILTIN_GETITEM,
736 set,
746 set,
737 frozenset,
747 frozenset,
738 object,
748 object,
739 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
749 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
740 *NUMERICS,
750 *NUMERICS,
741 dict_keys,
751 dict_keys,
742 MethodDescriptorType,
752 MethodDescriptorType,
743 ModuleType,
753 ModuleType,
744 }
754 }
745
755
746
756
747 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
757 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
748
758
749 EVALUATION_POLICIES = {
759 EVALUATION_POLICIES = {
750 "minimal": EvaluationPolicy(
760 "minimal": EvaluationPolicy(
751 allow_builtins_access=True,
761 allow_builtins_access=True,
752 allow_locals_access=False,
762 allow_locals_access=False,
753 allow_globals_access=False,
763 allow_globals_access=False,
754 allow_item_access=False,
764 allow_item_access=False,
755 allow_attr_access=False,
765 allow_attr_access=False,
756 allowed_calls=set(),
766 allowed_calls=set(),
757 allow_any_calls=False,
767 allow_any_calls=False,
758 allow_all_operations=False,
768 allow_all_operations=False,
759 ),
769 ),
760 "limited": SelectivePolicy(
770 "limited": SelectivePolicy(
761 allowed_getitem=BUILTIN_GETITEM,
771 allowed_getitem=BUILTIN_GETITEM,
762 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
772 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
763 allowed_getattr=BUILTIN_GETATTR,
773 allowed_getattr=BUILTIN_GETATTR,
764 allowed_getattr_external={
774 allowed_getattr_external={
765 # pandas Series/Frame implements custom `__getattr__`
775 # pandas Series/Frame implements custom `__getattr__`
766 ("pandas", "DataFrame"),
776 ("pandas", "DataFrame"),
767 ("pandas", "Series"),
777 ("pandas", "Series"),
768 },
778 },
769 allowed_operations=BUILTIN_OPERATIONS,
779 allowed_operations=BUILTIN_OPERATIONS,
770 allow_builtins_access=True,
780 allow_builtins_access=True,
771 allow_locals_access=True,
781 allow_locals_access=True,
772 allow_globals_access=True,
782 allow_globals_access=True,
773 allowed_calls=ALLOWED_CALLS,
783 allowed_calls=ALLOWED_CALLS,
774 ),
784 ),
775 "unsafe": EvaluationPolicy(
785 "unsafe": EvaluationPolicy(
776 allow_builtins_access=True,
786 allow_builtins_access=True,
777 allow_locals_access=True,
787 allow_locals_access=True,
778 allow_globals_access=True,
788 allow_globals_access=True,
779 allow_attr_access=True,
789 allow_attr_access=True,
780 allow_item_access=True,
790 allow_item_access=True,
781 allow_any_calls=True,
791 allow_any_calls=True,
782 allow_all_operations=True,
792 allow_all_operations=True,
783 ),
793 ),
784 }
794 }
785
795
786
796
787 __all__ = [
797 __all__ = [
788 "guarded_eval",
798 "guarded_eval",
789 "eval_node",
799 "eval_node",
790 "GuardRejection",
800 "GuardRejection",
791 "EvaluationContext",
801 "EvaluationContext",
792 "_unbind_method",
802 "_unbind_method",
793 ]
803 ]
@@ -1,631 +1,641 b''
1 from contextlib import contextmanager
1 from contextlib import contextmanager
2 from typing import NamedTuple
2 from typing import NamedTuple
3 from functools import partial
3 from functools import partial
4 from IPython.core.guarded_eval import (
4 from IPython.core.guarded_eval import (
5 EvaluationContext,
5 EvaluationContext,
6 GuardRejection,
6 GuardRejection,
7 guarded_eval,
7 guarded_eval,
8 _unbind_method,
8 _unbind_method,
9 )
9 )
10 from IPython.testing import decorators as dec
10 from IPython.testing import decorators as dec
11 import pytest
11 import pytest
12
12
13
13
14 def create_context(evaluation: str, **kwargs):
14 def create_context(evaluation: str, **kwargs):
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
15 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
16
16
17
17
18 forbidden = partial(create_context, "forbidden")
18 forbidden = partial(create_context, "forbidden")
19 minimal = partial(create_context, "minimal")
19 minimal = partial(create_context, "minimal")
20 limited = partial(create_context, "limited")
20 limited = partial(create_context, "limited")
21 unsafe = partial(create_context, "unsafe")
21 unsafe = partial(create_context, "unsafe")
22 dangerous = partial(create_context, "dangerous")
22 dangerous = partial(create_context, "dangerous")
23
23
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
24 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
25 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
26
26
27
27
@contextmanager
def module_not_installed(module: str):
    """Temporarily remove ``module`` from ``sys.modules``.

    Inside the ``with`` block ``sys.modules`` has no entry for ``module``.
    On exit the previous state is restored: the original entry is put back
    if there was one, otherwise the key is removed again.
    """
    import sys

    missing = object()
    to_restore = sys.modules.pop(module, missing)
    try:
        yield
    finally:
        if to_restore is missing:
            # There was no entry before. Do not store `None` here: a `None`
            # entry in `sys.modules` makes later imports of the module fail.
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = to_restore
41
41
42
42
43 def test_external_not_installed():
43 def test_external_not_installed():
44 """
44 """
45 Because attribute check requires checking if object is not of allowed
45 Because attribute check requires checking if object is not of allowed
46 external type, this tests logic for absence of external module.
46 external type, this tests logic for absence of external module.
47 """
47 """
48
48
49 class Custom:
49 class Custom:
50 def __init__(self):
50 def __init__(self):
51 self.test = 1
51 self.test = 1
52
52
53 def __getattr__(self, key):
53 def __getattr__(self, key):
54 return key
54 return key
55
55
56 with module_not_installed("pandas"):
56 with module_not_installed("pandas"):
57 context = limited(x=Custom())
57 context = limited(x=Custom())
58 with pytest.raises(GuardRejection):
58 with pytest.raises(GuardRejection):
59 guarded_eval("x.test", context)
59 guarded_eval("x.test", context)
60
60
61
61
62 @dec.skip_without("pandas")
62 @dec.skip_without("pandas")
63 def test_external_changed_api(monkeypatch):
63 def test_external_changed_api(monkeypatch):
64 """Check that the execution rejects if external API changed paths"""
64 """Check that the execution rejects if external API changed paths"""
65 import pandas as pd
65 import pandas as pd
66
66
67 series = pd.Series([1], index=["a"])
67 series = pd.Series([1], index=["a"])
68
68
69 with monkeypatch.context() as m:
69 with monkeypatch.context() as m:
70 m.delattr(pd, "Series")
70 m.delattr(pd, "Series")
71 context = limited(data=series)
71 context = limited(data=series)
72 with pytest.raises(GuardRejection):
72 with pytest.raises(GuardRejection):
73 guarded_eval("data.iloc[0]", context)
73 guarded_eval("data.iloc[0]", context)
74
74
75
75
76 @dec.skip_without("pandas")
76 @dec.skip_without("pandas")
77 def test_pandas_series_iloc():
77 def test_pandas_series_iloc():
78 import pandas as pd
78 import pandas as pd
79
79
80 series = pd.Series([1], index=["a"])
80 series = pd.Series([1], index=["a"])
81 context = limited(data=series)
81 context = limited(data=series)
82 assert guarded_eval("data.iloc[0]", context) == 1
82 assert guarded_eval("data.iloc[0]", context) == 1
83
83
84
84
85 def test_rejects_custom_properties():
85 def test_rejects_custom_properties():
86 class BadProperty:
86 class BadProperty:
87 @property
87 @property
88 def iloc(self):
88 def iloc(self):
89 return [None]
89 return [None]
90
90
91 series = BadProperty()
91 series = BadProperty()
92 context = limited(data=series)
92 context = limited(data=series)
93
93
94 with pytest.raises(GuardRejection):
94 with pytest.raises(GuardRejection):
95 guarded_eval("data.iloc[0]", context)
95 guarded_eval("data.iloc[0]", context)
96
96
97
97
98 @dec.skip_without("pandas")
98 @dec.skip_without("pandas")
99 def test_accepts_non_overriden_properties():
99 def test_accepts_non_overriden_properties():
100 import pandas as pd
100 import pandas as pd
101
101
102 class GoodProperty(pd.Series):
102 class GoodProperty(pd.Series):
103 pass
103 pass
104
104
105 series = GoodProperty([1], index=["a"])
105 series = GoodProperty([1], index=["a"])
106 context = limited(data=series)
106 context = limited(data=series)
107
107
108 assert guarded_eval("data.iloc[0]", context) == 1
108 assert guarded_eval("data.iloc[0]", context) == 1
109
109
110
110
111 @dec.skip_without("pandas")
111 @dec.skip_without("pandas")
112 def test_pandas_series():
112 def test_pandas_series():
113 import pandas as pd
113 import pandas as pd
114
114
115 context = limited(data=pd.Series([1], index=["a"]))
115 context = limited(data=pd.Series([1], index=["a"]))
116 assert guarded_eval('data["a"]', context) == 1
116 assert guarded_eval('data["a"]', context) == 1
117 with pytest.raises(KeyError):
117 with pytest.raises(KeyError):
118 guarded_eval('data["c"]', context)
118 guarded_eval('data["c"]', context)
119
119
120
120
121 @dec.skip_without("pandas")
121 @dec.skip_without("pandas")
122 def test_pandas_bad_series():
122 def test_pandas_bad_series():
123 import pandas as pd
123 import pandas as pd
124
124
125 class BadItemSeries(pd.Series):
125 class BadItemSeries(pd.Series):
126 def __getitem__(self, key):
126 def __getitem__(self, key):
127 return "CUSTOM_ITEM"
127 return "CUSTOM_ITEM"
128
128
129 class BadAttrSeries(pd.Series):
129 class BadAttrSeries(pd.Series):
130 def __getattr__(self, key):
130 def __getattr__(self, key):
131 return "CUSTOM_ATTR"
131 return "CUSTOM_ATTR"
132
132
133 bad_series = BadItemSeries([1], index=["a"])
133 bad_series = BadItemSeries([1], index=["a"])
134 context = limited(data=bad_series)
134 context = limited(data=bad_series)
135
135
136 with pytest.raises(GuardRejection):
136 with pytest.raises(GuardRejection):
137 guarded_eval('data["a"]', context)
137 guarded_eval('data["a"]', context)
138 with pytest.raises(GuardRejection):
138 with pytest.raises(GuardRejection):
139 guarded_eval('data["c"]', context)
139 guarded_eval('data["c"]', context)
140
140
141 # note: here result is a bit unexpected because
141 # note: here result is a bit unexpected because
142 # pandas `__getattr__` calls `__getitem__`;
142 # pandas `__getattr__` calls `__getitem__`;
143 # FIXME - special case to handle it?
143 # FIXME - special case to handle it?
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
144 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
145
145
146 context = unsafe(data=bad_series)
146 context = unsafe(data=bad_series)
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
147 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
148
148
149 bad_attr_series = BadAttrSeries([1], index=["a"])
149 bad_attr_series = BadAttrSeries([1], index=["a"])
150 context = limited(data=bad_attr_series)
150 context = limited(data=bad_attr_series)
151 assert guarded_eval('data["a"]', context) == 1
151 assert guarded_eval('data["a"]', context) == 1
152 with pytest.raises(GuardRejection):
152 with pytest.raises(GuardRejection):
153 guarded_eval("data.a", context)
153 guarded_eval("data.a", context)
154
154
155
155
156 @dec.skip_without("pandas")
156 @dec.skip_without("pandas")
157 def test_pandas_dataframe_loc():
157 def test_pandas_dataframe_loc():
158 import pandas as pd
158 import pandas as pd
159 from pandas.testing import assert_series_equal
159 from pandas.testing import assert_series_equal
160
160
161 data = pd.DataFrame([{"a": 1}])
161 data = pd.DataFrame([{"a": 1}])
162 context = limited(data=data)
162 context = limited(data=data)
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
163 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
164
164
165
165
166 def test_named_tuple():
166 def test_named_tuple():
167 class GoodNamedTuple(NamedTuple):
167 class GoodNamedTuple(NamedTuple):
168 a: str
168 a: str
169 pass
169 pass
170
170
171 class BadNamedTuple(NamedTuple):
171 class BadNamedTuple(NamedTuple):
172 a: str
172 a: str
173
173
174 def __getitem__(self, key):
174 def __getitem__(self, key):
175 return None
175 return None
176
176
177 good = GoodNamedTuple(a="x")
177 good = GoodNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
178 bad = BadNamedTuple(a="x")
179
179
180 context = limited(data=good)
180 context = limited(data=good)
181 assert guarded_eval("data[0]", context) == "x"
181 assert guarded_eval("data[0]", context) == "x"
182
182
183 context = limited(data=bad)
183 context = limited(data=bad)
184 with pytest.raises(GuardRejection):
184 with pytest.raises(GuardRejection):
185 guarded_eval("data[0]", context)
185 guarded_eval("data[0]", context)
186
186
187
187
188 def test_dict():
188 def test_dict():
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
189 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
190 assert guarded_eval('data["a"]', context) == 1
190 assert guarded_eval('data["a"]', context) == 1
191 assert guarded_eval('data["b"]', context) == {"x": 2}
191 assert guarded_eval('data["b"]', context) == {"x": 2}
192 assert guarded_eval('data["b"]["x"]', context) == 2
192 assert guarded_eval('data["b"]["x"]', context) == 2
193 assert guarded_eval('data["x", "y"]', context) == 3
193 assert guarded_eval('data["x", "y"]', context) == 3
194
194
195 assert guarded_eval("data.keys", context)
195 assert guarded_eval("data.keys", context)
196
196
197
197
198 def test_set():
198 def test_set():
199 context = limited(data={"a", "b"})
199 context = limited(data={"a", "b"})
200 assert guarded_eval("data.difference", context)
200 assert guarded_eval("data.difference", context)
201
201
202
202
203 def test_list():
203 def test_list():
204 context = limited(data=[1, 2, 3])
204 context = limited(data=[1, 2, 3])
205 assert guarded_eval("data[1]", context) == 2
205 assert guarded_eval("data[1]", context) == 2
206 assert guarded_eval("data.copy", context)
206 assert guarded_eval("data.copy", context)
207
207
208
208
209 def test_dict_literal():
209 def test_dict_literal():
210 context = limited()
210 context = limited()
211 assert guarded_eval("{}", context) == {}
211 assert guarded_eval("{}", context) == {}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
212 assert guarded_eval('{"a": 1}', context) == {"a": 1}
213
213
214
214
215 def test_list_literal():
215 def test_list_literal():
216 context = limited()
216 context = limited()
217 assert guarded_eval("[]", context) == []
217 assert guarded_eval("[]", context) == []
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
218 assert guarded_eval('[1, "a"]', context) == [1, "a"]
219
219
220
220
221 def test_set_literal():
221 def test_set_literal():
222 context = limited()
222 context = limited()
223 assert guarded_eval("set()", context) == set()
223 assert guarded_eval("set()", context) == set()
224 assert guarded_eval('{"a"}', context) == {"a"}
224 assert guarded_eval('{"a"}', context) == {"a"}
225
225
226
226
227 def test_evaluates_if_expression():
227 def test_evaluates_if_expression():
228 context = limited()
228 context = limited()
229 assert guarded_eval("2 if True else 3", context) == 2
229 assert guarded_eval("2 if True else 3", context) == 2
230 assert guarded_eval("4 if False else 5", context) == 5
230 assert guarded_eval("4 if False else 5", context) == 5
231
231
232
232
233 def test_object():
233 def test_object():
234 obj = object()
234 obj = object()
235 context = limited(obj=obj)
235 context = limited(obj=obj)
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
236 assert guarded_eval("obj.__dir__", context) == obj.__dir__
237
237
238
238
239 @pytest.mark.parametrize(
239 @pytest.mark.parametrize(
240 "code,expected",
240 "code,expected",
241 [
241 [
242 ["int.numerator", int.numerator],
242 ["int.numerator", int.numerator],
243 ["float.is_integer", float.is_integer],
243 ["float.is_integer", float.is_integer],
244 ["complex.real", complex.real],
244 ["complex.real", complex.real],
245 ],
245 ],
246 )
246 )
247 def test_number_attributes(code, expected):
247 def test_number_attributes(code, expected):
248 assert guarded_eval(code, limited()) == expected
248 assert guarded_eval(code, limited()) == expected
249
249
250
250
251 def test_method_descriptor():
251 def test_method_descriptor():
252 context = limited()
252 context = limited()
253 assert guarded_eval("list.copy.__name__", context) == "copy"
253 assert guarded_eval("list.copy.__name__", context) == "copy"
254
254
255
255
256 class HeapType:
256 class HeapType:
257 pass
257 pass
258
258
259
259
260 class CallCreatesHeapType:
260 class CallCreatesHeapType:
261 def __call__(self) -> HeapType:
261 def __call__(self) -> HeapType:
262 return HeapType()
262 return HeapType()
263
263
264
264
265 class CallCreatesBuiltin:
265 class CallCreatesBuiltin:
266 def __call__(self) -> frozenset:
266 def __call__(self) -> frozenset:
267 return frozenset()
267 return frozenset()
268
268
269
269
270 class HasStaticMethod:
270 class HasStaticMethod:
271 @staticmethod
271 @staticmethod
272 def static_method() -> HeapType:
272 def static_method() -> HeapType:
273 return HeapType()
273 return HeapType()
274
274
275
275
276 class InitReturnsFrozenset:
276 class InitReturnsFrozenset:
277 def __new__(self) -> frozenset: # type:ignore[misc]
277 def __new__(self) -> frozenset: # type:ignore[misc]
278 return frozenset()
278 return frozenset()
279
279
280
280
281 class StringAnnotation:
282 def heap(self) -> "HeapType":
283 return HeapType()
284
285 def copy(self) -> "StringAnnotation":
286 return StringAnnotation()
287
288
281 @pytest.mark.parametrize(
289 @pytest.mark.parametrize(
282 "data,good,expected,equality",
290 "data,good,expected,equality",
283 [
291 [
284 [[1, 2, 3], "data.index(2)", 1, True],
292 [[1, 2, 3], "data.index(2)", 1, True],
285 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
293 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
294 [StringAnnotation(), "data.heap()", HeapType, False],
295 [StringAnnotation(), "data.copy()", StringAnnotation, False],
286 # test cases for `__call__`
296 # test cases for `__call__`
287 [CallCreatesHeapType(), "data()", HeapType, False],
297 [CallCreatesHeapType(), "data()", HeapType, False],
288 [CallCreatesBuiltin(), "data()", frozenset, False],
298 [CallCreatesBuiltin(), "data()", frozenset, False],
289 # Test cases for `__init__`
299 # Test cases for `__init__`
290 [HeapType, "data()", HeapType, False],
300 [HeapType, "data()", HeapType, False],
291 [InitReturnsFrozenset, "data()", frozenset, False],
301 [InitReturnsFrozenset, "data()", frozenset, False],
292 [HeapType(), "data.__class__()", HeapType, False],
302 [HeapType(), "data.__class__()", HeapType, False],
293 # test cases for static and class methods
303 # test cases for static methods
294 [HasStaticMethod, "data.static_method()", HeapType, False],
304 [HasStaticMethod, "data.static_method()", HeapType, False],
295 ],
305 ],
296 )
306 )
297 def test_evaluates_calls(data, good, expected, equality):
307 def test_evaluates_calls(data, good, expected, equality):
298 context = limited(data=data)
308 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
299 value = guarded_eval(good, context)
309 value = guarded_eval(good, context)
300 if equality:
310 if equality:
301 assert value == expected
311 assert value == expected
302 else:
312 else:
303 assert isinstance(value, expected)
313 assert isinstance(value, expected)
304
314
305
315
306 @pytest.mark.parametrize(
316 @pytest.mark.parametrize(
307 "data,bad",
317 "data,bad",
308 [
318 [
309 [[1, 2, 3], "data.append(4)"],
319 [[1, 2, 3], "data.append(4)"],
310 [{"a": 1}, "data.update()"],
320 [{"a": 1}, "data.update()"],
311 ],
321 ],
312 )
322 )
313 def test_rejects_calls_with_side_effects(data, bad):
323 def test_rejects_calls_with_side_effects(data, bad):
314 context = limited(data=data)
324 context = limited(data=data)
315
325
316 with pytest.raises(GuardRejection):
326 with pytest.raises(GuardRejection):
317 guarded_eval(bad, context)
327 guarded_eval(bad, context)
318
328
319
329
320 @pytest.mark.parametrize(
330 @pytest.mark.parametrize(
321 "code,expected",
331 "code,expected",
322 [
332 [
323 ["(1\n+\n1)", 2],
333 ["(1\n+\n1)", 2],
324 ["list(range(10))[-1:]", [9]],
334 ["list(range(10))[-1:]", [9]],
325 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
335 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
326 ],
336 ],
327 )
337 )
328 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
338 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
329 def test_evaluates_complex_cases(code, expected, context):
339 def test_evaluates_complex_cases(code, expected, context):
330 assert guarded_eval(code, context()) == expected
340 assert guarded_eval(code, context()) == expected
331
341
332
342
333 @pytest.mark.parametrize(
343 @pytest.mark.parametrize(
334 "code,expected",
344 "code,expected",
335 [
345 [
336 ["1", 1],
346 ["1", 1],
337 ["1.0", 1.0],
347 ["1.0", 1.0],
338 ["0xdeedbeef", 0xDEEDBEEF],
348 ["0xdeedbeef", 0xDEEDBEEF],
339 ["True", True],
349 ["True", True],
340 ["None", None],
350 ["None", None],
341 ["{}", {}],
351 ["{}", {}],
342 ["[]", []],
352 ["[]", []],
343 ],
353 ],
344 )
354 )
345 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
355 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
346 def test_evaluates_literals(code, expected, context):
356 def test_evaluates_literals(code, expected, context):
347 assert guarded_eval(code, context()) == expected
357 assert guarded_eval(code, context()) == expected
348
358
349
359
350 @pytest.mark.parametrize(
360 @pytest.mark.parametrize(
351 "code,expected",
361 "code,expected",
352 [
362 [
353 ["-5", -5],
363 ["-5", -5],
354 ["+5", +5],
364 ["+5", +5],
355 ["~5", -6],
365 ["~5", -6],
356 ],
366 ],
357 )
367 )
358 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
368 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
359 def test_evaluates_unary_operations(code, expected, context):
369 def test_evaluates_unary_operations(code, expected, context):
360 assert guarded_eval(code, context()) == expected
370 assert guarded_eval(code, context()) == expected
361
371
362
372
363 @pytest.mark.parametrize(
373 @pytest.mark.parametrize(
364 "code,expected",
374 "code,expected",
365 [
375 [
366 ["1 + 1", 2],
376 ["1 + 1", 2],
367 ["3 - 1", 2],
377 ["3 - 1", 2],
368 ["2 * 3", 6],
378 ["2 * 3", 6],
369 ["5 // 2", 2],
379 ["5 // 2", 2],
370 ["5 / 2", 2.5],
380 ["5 / 2", 2.5],
371 ["5**2", 25],
381 ["5**2", 25],
372 ["2 >> 1", 1],
382 ["2 >> 1", 1],
373 ["2 << 1", 4],
383 ["2 << 1", 4],
374 ["1 | 2", 3],
384 ["1 | 2", 3],
375 ["1 & 1", 1],
385 ["1 & 1", 1],
376 ["1 & 2", 0],
386 ["1 & 2", 0],
377 ],
387 ],
378 )
388 )
379 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
389 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
380 def test_evaluates_binary_operations(code, expected, context):
390 def test_evaluates_binary_operations(code, expected, context):
381 assert guarded_eval(code, context()) == expected
391 assert guarded_eval(code, context()) == expected
382
392
383
393
384 @pytest.mark.parametrize(
394 @pytest.mark.parametrize(
385 "code,expected",
395 "code,expected",
386 [
396 [
387 ["2 > 1", True],
397 ["2 > 1", True],
388 ["2 < 1", False],
398 ["2 < 1", False],
389 ["2 <= 1", False],
399 ["2 <= 1", False],
390 ["2 <= 2", True],
400 ["2 <= 2", True],
391 ["1 >= 2", False],
401 ["1 >= 2", False],
392 ["2 >= 2", True],
402 ["2 >= 2", True],
393 ["2 == 2", True],
403 ["2 == 2", True],
394 ["1 == 2", False],
404 ["1 == 2", False],
395 ["1 != 2", True],
405 ["1 != 2", True],
396 ["1 != 1", False],
406 ["1 != 1", False],
397 ["1 < 4 < 3", False],
407 ["1 < 4 < 3", False],
398 ["(1 < 4) < 3", True],
408 ["(1 < 4) < 3", True],
399 ["4 > 3 > 2 > 1", True],
409 ["4 > 3 > 2 > 1", True],
400 ["4 > 3 > 2 > 9", False],
410 ["4 > 3 > 2 > 9", False],
401 ["1 < 2 < 3 < 4", True],
411 ["1 < 2 < 3 < 4", True],
402 ["9 < 2 < 3 < 4", False],
412 ["9 < 2 < 3 < 4", False],
403 ["1 < 2 > 1 > 0 > -1 < 1", True],
413 ["1 < 2 > 1 > 0 > -1 < 1", True],
404 ["1 in [1] in [[1]]", True],
414 ["1 in [1] in [[1]]", True],
405 ["1 in [1] in [[2]]", False],
415 ["1 in [1] in [[2]]", False],
406 ["1 in [1]", True],
416 ["1 in [1]", True],
407 ["0 in [1]", False],
417 ["0 in [1]", False],
408 ["1 not in [1]", False],
418 ["1 not in [1]", False],
409 ["0 not in [1]", True],
419 ["0 not in [1]", True],
410 ["True is True", True],
420 ["True is True", True],
411 ["False is False", True],
421 ["False is False", True],
412 ["True is False", False],
422 ["True is False", False],
413 ["True is not True", False],
423 ["True is not True", False],
414 ["False is not True", True],
424 ["False is not True", True],
415 ],
425 ],
416 )
426 )
417 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
427 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
418 def test_evaluates_comparisons(code, expected, context):
428 def test_evaluates_comparisons(code, expected, context):
419 assert guarded_eval(code, context()) == expected
429 assert guarded_eval(code, context()) == expected
420
430
421
431
422 def test_guards_comparisons():
432 def test_guards_comparisons():
423 class GoodEq(int):
433 class GoodEq(int):
424 pass
434 pass
425
435
426 class BadEq(int):
436 class BadEq(int):
427 def __eq__(self, other):
437 def __eq__(self, other):
428 assert False
438 assert False
429
439
430 context = limited(bad=BadEq(1), good=GoodEq(1))
440 context = limited(bad=BadEq(1), good=GoodEq(1))
431
441
432 with pytest.raises(GuardRejection):
442 with pytest.raises(GuardRejection):
433 guarded_eval("bad == 1", context)
443 guarded_eval("bad == 1", context)
434
444
435 with pytest.raises(GuardRejection):
445 with pytest.raises(GuardRejection):
436 guarded_eval("bad != 1", context)
446 guarded_eval("bad != 1", context)
437
447
438 with pytest.raises(GuardRejection):
448 with pytest.raises(GuardRejection):
439 guarded_eval("1 == bad", context)
449 guarded_eval("1 == bad", context)
440
450
441 with pytest.raises(GuardRejection):
451 with pytest.raises(GuardRejection):
442 guarded_eval("1 != bad", context)
452 guarded_eval("1 != bad", context)
443
453
444 assert guarded_eval("good == 1", context) is True
454 assert guarded_eval("good == 1", context) is True
445 assert guarded_eval("good != 1", context) is False
455 assert guarded_eval("good != 1", context) is False
446 assert guarded_eval("1 == good", context) is True
456 assert guarded_eval("1 == good", context) is True
447 assert guarded_eval("1 != good", context) is False
457 assert guarded_eval("1 != good", context) is False
448
458
449
459
450 def test_guards_unary_operations():
460 def test_guards_unary_operations():
451 class GoodOp(int):
461 class GoodOp(int):
452 pass
462 pass
453
463
454 class BadOpInv(int):
464 class BadOpInv(int):
455 def __inv__(self, other):
465 def __inv__(self, other):
456 assert False
466 assert False
457
467
458 class BadOpInverse(int):
468 class BadOpInverse(int):
459 def __inv__(self, other):
469 def __inv__(self, other):
460 assert False
470 assert False
461
471
462 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
472 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
463
473
464 with pytest.raises(GuardRejection):
474 with pytest.raises(GuardRejection):
465 guarded_eval("~bad1", context)
475 guarded_eval("~bad1", context)
466
476
467 with pytest.raises(GuardRejection):
477 with pytest.raises(GuardRejection):
468 guarded_eval("~bad2", context)
478 guarded_eval("~bad2", context)
469
479
470
480
471 def test_guards_binary_operations():
481 def test_guards_binary_operations():
472 class GoodOp(int):
482 class GoodOp(int):
473 pass
483 pass
474
484
475 class BadOp(int):
485 class BadOp(int):
476 def __add__(self, other):
486 def __add__(self, other):
477 assert False
487 assert False
478
488
479 context = limited(good=GoodOp(1), bad=BadOp(1))
489 context = limited(good=GoodOp(1), bad=BadOp(1))
480
490
481 with pytest.raises(GuardRejection):
491 with pytest.raises(GuardRejection):
482 guarded_eval("1 + bad", context)
492 guarded_eval("1 + bad", context)
483
493
484 with pytest.raises(GuardRejection):
494 with pytest.raises(GuardRejection):
485 guarded_eval("bad + 1", context)
495 guarded_eval("bad + 1", context)
486
496
487 assert guarded_eval("good + 1", context) == 2
497 assert guarded_eval("good + 1", context) == 2
488 assert guarded_eval("1 + good", context) == 2
498 assert guarded_eval("1 + good", context) == 2
489
499
490
500
491 def test_guards_attributes():
501 def test_guards_attributes():
492 class GoodAttr(float):
502 class GoodAttr(float):
493 pass
503 pass
494
504
495 class BadAttr1(float):
505 class BadAttr1(float):
496 def __getattr__(self, key):
506 def __getattr__(self, key):
497 assert False
507 assert False
498
508
499 class BadAttr2(float):
509 class BadAttr2(float):
500 def __getattribute__(self, key):
510 def __getattribute__(self, key):
501 assert False
511 assert False
502
512
503 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
513 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
504
514
505 with pytest.raises(GuardRejection):
515 with pytest.raises(GuardRejection):
506 guarded_eval("bad1.as_integer_ratio", context)
516 guarded_eval("bad1.as_integer_ratio", context)
507
517
508 with pytest.raises(GuardRejection):
518 with pytest.raises(GuardRejection):
509 guarded_eval("bad2.as_integer_ratio", context)
519 guarded_eval("bad2.as_integer_ratio", context)
510
520
511 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
521 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
512
522
513
523
514 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
524 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
515 def test_access_builtins(context):
525 def test_access_builtins(context):
516 assert guarded_eval("round", context()) == round
526 assert guarded_eval("round", context()) == round
517
527
518
528
519 def test_access_builtins_fails():
529 def test_access_builtins_fails():
520 context = limited()
530 context = limited()
521 with pytest.raises(NameError):
531 with pytest.raises(NameError):
522 guarded_eval("this_is_not_builtin", context)
532 guarded_eval("this_is_not_builtin", context)
523
533
524
534
525 def test_rejects_forbidden():
535 def test_rejects_forbidden():
526 context = forbidden()
536 context = forbidden()
527 with pytest.raises(GuardRejection):
537 with pytest.raises(GuardRejection):
528 guarded_eval("1", context)
538 guarded_eval("1", context)
529
539
530
540
531 def test_guards_locals_and_globals():
541 def test_guards_locals_and_globals():
532 context = EvaluationContext(
542 context = EvaluationContext(
533 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
543 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
534 )
544 )
535
545
536 with pytest.raises(GuardRejection):
546 with pytest.raises(GuardRejection):
537 guarded_eval("local_a", context)
547 guarded_eval("local_a", context)
538
548
539 with pytest.raises(GuardRejection):
549 with pytest.raises(GuardRejection):
540 guarded_eval("global_b", context)
550 guarded_eval("global_b", context)
541
551
542
552
543 def test_access_locals_and_globals():
553 def test_access_locals_and_globals():
544 context = EvaluationContext(
554 context = EvaluationContext(
545 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
555 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
546 )
556 )
547 assert guarded_eval("local_a", context) == "a"
557 assert guarded_eval("local_a", context) == "a"
548 assert guarded_eval("global_b", context) == "b"
558 assert guarded_eval("global_b", context) == "b"
549
559
550
560
551 @pytest.mark.parametrize(
561 @pytest.mark.parametrize(
552 "code",
562 "code",
553 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
563 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
554 )
564 )
555 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
565 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
556 def test_rejects_side_effect_syntax(code, context):
566 def test_rejects_side_effect_syntax(code, context):
557 with pytest.raises(SyntaxError):
567 with pytest.raises(SyntaxError):
558 guarded_eval(code, context)
568 guarded_eval(code, context)
559
569
560
570
561 def test_subscript():
571 def test_subscript():
562 context = EvaluationContext(
572 context = EvaluationContext(
563 locals={}, globals={}, evaluation="limited", in_subscript=True
573 locals={}, globals={}, evaluation="limited", in_subscript=True
564 )
574 )
565 empty_slice = slice(None, None, None)
575 empty_slice = slice(None, None, None)
566 assert guarded_eval("", context) == tuple()
576 assert guarded_eval("", context) == tuple()
567 assert guarded_eval(":", context) == empty_slice
577 assert guarded_eval(":", context) == empty_slice
568 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
578 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
569 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
579 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
570
580
571
581
572 def test_unbind_method():
582 def test_unbind_method():
573 class X(list):
583 class X(list):
574 def index(self, k):
584 def index(self, k):
575 return "CUSTOM"
585 return "CUSTOM"
576
586
577 x = X()
587 x = X()
578 assert _unbind_method(x.index) is X.index
588 assert _unbind_method(x.index) is X.index
579 assert _unbind_method([].index) is list.index
589 assert _unbind_method([].index) is list.index
580 assert _unbind_method(list.index) is None
590 assert _unbind_method(list.index) is None
581
591
582
592
583 def test_assumption_instance_attr_do_not_matter():
593 def test_assumption_instance_attr_do_not_matter():
584 """This is semi-specified in Python documentation.
594 """This is semi-specified in Python documentation.
585
595
586 However, since the specification says 'not guaranteed
596 However, since the specification says 'not guaranteed
587 to work' rather than 'is forbidden to work', future
597 to work' rather than 'is forbidden to work', future
588 versions could invalidate this assumptions. This test
598 versions could invalidate this assumptions. This test
589 is meant to catch such a change if it ever comes true.
599 is meant to catch such a change if it ever comes true.
590 """
600 """
591
601
592 class T:
602 class T:
593 def __getitem__(self, k):
603 def __getitem__(self, k):
594 return "a"
604 return "a"
595
605
596 def __getattr__(self, k):
606 def __getattr__(self, k):
597 return "a"
607 return "a"
598
608
599 def f(self):
609 def f(self):
600 return "b"
610 return "b"
601
611
602 t = T()
612 t = T()
603 t.__getitem__ = f
613 t.__getitem__ = f
604 t.__getattr__ = f
614 t.__getattr__ = f
605 assert t[1] == "a"
615 assert t[1] == "a"
606 assert t[1] == "a"
616 assert t[1] == "a"
607
617
608
618
609 def test_assumption_named_tuples_share_getitem():
619 def test_assumption_named_tuples_share_getitem():
610 """Check assumption on named tuples sharing __getitem__"""
620 """Check assumption on named tuples sharing __getitem__"""
611 from typing import NamedTuple
621 from typing import NamedTuple
612
622
613 class A(NamedTuple):
623 class A(NamedTuple):
614 pass
624 pass
615
625
616 class B(NamedTuple):
626 class B(NamedTuple):
617 pass
627 pass
618
628
619 assert A.__getitem__ == B.__getitem__
629 assert A.__getitem__ == B.__getitem__
620
630
621
631
622 @dec.skip_without("numpy")
632 @dec.skip_without("numpy")
623 def test_module_access():
633 def test_module_access():
624 import numpy
634 import numpy
625
635
626 context = limited(numpy=numpy)
636 context = limited(numpy=numpy)
627 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
637 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
628
638
629 context = minimal(numpy=numpy)
639 context = minimal(numpy=numpy)
630 with pytest.raises(GuardRejection):
640 with pytest.raises(GuardRejection):
631 guarded_eval("np.linalg.norm", context)
641 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now