##// END OF EJS Templates
Implement remaining special `typing` wrappers
krassowski -
Show More
@@ -1,822 +1,898 b''
1 from inspect import isclass, signature, Signature
1 from inspect import isclass, signature, Signature
2 from typing import (
2 from typing import (
3 Annotated,
4 AnyStr,
3 Callable,
5 Callable,
4 Dict,
6 Dict,
5 Literal,
7 Literal,
6 NamedTuple,
8 NamedTuple,
7 NewType,
9 NewType,
10 Optional,
11 Protocol,
8 Set,
12 Set,
9 Sequence,
13 Sequence,
10 Tuple,
14 Tuple,
11 Type,
15 Type,
12 Protocol,
16 TypeGuard,
13 Union,
17 Union,
14 get_args,
18 get_args,
15 get_origin,
19 get_origin,
20 is_typeddict,
16 )
21 )
17 import ast
22 import ast
18 import builtins
23 import builtins
19 import collections
24 import collections
20 import operator
25 import operator
21 import sys
26 import sys
22 from functools import cached_property
27 from functools import cached_property
23 from dataclasses import dataclass, field
28 from dataclasses import dataclass, field
24 from types import MethodDescriptorType, ModuleType
29 from types import MethodDescriptorType, ModuleType
25
30
26 from IPython.utils.decorators import undoc
31 from IPython.utils.decorators import undoc
27
32
28
33
29 if sys.version_info < (3, 11):
34 if sys.version_info < (3, 11):
30 from typing_extensions import Self
35 from typing_extensions import Self, LiteralString
31 else:
36 else:
32 from typing import Self
37 from typing import Self, LiteralString
33
38
34 if sys.version_info < (3, 12):
39 if sys.version_info < (3, 12):
35 from typing_extensions import TypeAliasType
40 from typing_extensions import TypeAliasType
36 else:
41 else:
37 from typing import TypeAliasType
42 from typing import TypeAliasType
38
43
39
44
40 @undoc
45 @undoc
41 class HasGetItem(Protocol):
46 class HasGetItem(Protocol):
42 def __getitem__(self, key) -> None:
47 def __getitem__(self, key) -> None:
43 ...
48 ...
44
49
45
50
46 @undoc
51 @undoc
47 class InstancesHaveGetItem(Protocol):
52 class InstancesHaveGetItem(Protocol):
48 def __call__(self, *args, **kwargs) -> HasGetItem:
53 def __call__(self, *args, **kwargs) -> HasGetItem:
49 ...
54 ...
50
55
51
56
52 @undoc
57 @undoc
53 class HasGetAttr(Protocol):
58 class HasGetAttr(Protocol):
54 def __getattr__(self, key) -> None:
59 def __getattr__(self, key) -> None:
55 ...
60 ...
56
61
57
62
58 @undoc
63 @undoc
59 class DoesNotHaveGetAttr(Protocol):
64 class DoesNotHaveGetAttr(Protocol):
60 pass
65 pass
61
66
62
67
63 # By default `__getattr__` is not explicitly implemented on most objects
68 # By default `__getattr__` is not explicitly implemented on most objects
64 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
69 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
65
70
66
71
67 def _unbind_method(func: Callable) -> Union[Callable, None]:
72 def _unbind_method(func: Callable) -> Union[Callable, None]:
68 """Get unbound method for given bound method.
73 """Get unbound method for given bound method.
69
74
70 Returns None if cannot get unbound method, or method is already unbound.
75 Returns None if cannot get unbound method, or method is already unbound.
71 """
76 """
72 owner = getattr(func, "__self__", None)
77 owner = getattr(func, "__self__", None)
73 owner_class = type(owner)
78 owner_class = type(owner)
74 name = getattr(func, "__name__", None)
79 name = getattr(func, "__name__", None)
75 instance_dict_overrides = getattr(owner, "__dict__", None)
80 instance_dict_overrides = getattr(owner, "__dict__", None)
76 if (
81 if (
77 owner is not None
82 owner is not None
78 and name
83 and name
79 and (
84 and (
80 not instance_dict_overrides
85 not instance_dict_overrides
81 or (instance_dict_overrides and name not in instance_dict_overrides)
86 or (instance_dict_overrides and name not in instance_dict_overrides)
82 )
87 )
83 ):
88 ):
84 return getattr(owner_class, name)
89 return getattr(owner_class, name)
85 return None
90 return None
86
91
87
92
88 @undoc
93 @undoc
89 @dataclass
94 @dataclass
90 class EvaluationPolicy:
95 class EvaluationPolicy:
91 """Definition of evaluation policy."""
96 """Definition of evaluation policy."""
92
97
93 allow_locals_access: bool = False
98 allow_locals_access: bool = False
94 allow_globals_access: bool = False
99 allow_globals_access: bool = False
95 allow_item_access: bool = False
100 allow_item_access: bool = False
96 allow_attr_access: bool = False
101 allow_attr_access: bool = False
97 allow_builtins_access: bool = False
102 allow_builtins_access: bool = False
98 allow_all_operations: bool = False
103 allow_all_operations: bool = False
99 allow_any_calls: bool = False
104 allow_any_calls: bool = False
100 allowed_calls: Set[Callable] = field(default_factory=set)
105 allowed_calls: Set[Callable] = field(default_factory=set)
101
106
102 def can_get_item(self, value, item):
107 def can_get_item(self, value, item):
103 return self.allow_item_access
108 return self.allow_item_access
104
109
105 def can_get_attr(self, value, attr):
110 def can_get_attr(self, value, attr):
106 return self.allow_attr_access
111 return self.allow_attr_access
107
112
108 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
113 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
109 if self.allow_all_operations:
114 if self.allow_all_operations:
110 return True
115 return True
111
116
112 def can_call(self, func):
117 def can_call(self, func):
113 if self.allow_any_calls:
118 if self.allow_any_calls:
114 return True
119 return True
115
120
116 if func in self.allowed_calls:
121 if func in self.allowed_calls:
117 return True
122 return True
118
123
119 owner_method = _unbind_method(func)
124 owner_method = _unbind_method(func)
120
125
121 if owner_method and owner_method in self.allowed_calls:
126 if owner_method and owner_method in self.allowed_calls:
122 return True
127 return True
123
128
124
129
125 def _get_external(module_name: str, access_path: Sequence[str]):
130 def _get_external(module_name: str, access_path: Sequence[str]):
126 """Get value from external module given a dotted access path.
131 """Get value from external module given a dotted access path.
127
132
128 Raises:
133 Raises:
129 * `KeyError` if module is removed not found, and
134 * `KeyError` if module is removed not found, and
130 * `AttributeError` if acess path does not match an exported object
135 * `AttributeError` if acess path does not match an exported object
131 """
136 """
132 member_type = sys.modules[module_name]
137 member_type = sys.modules[module_name]
133 for attr in access_path:
138 for attr in access_path:
134 member_type = getattr(member_type, attr)
139 member_type = getattr(member_type, attr)
135 return member_type
140 return member_type
136
141
137
142
138 def _has_original_dunder_external(
143 def _has_original_dunder_external(
139 value,
144 value,
140 module_name: str,
145 module_name: str,
141 access_path: Sequence[str],
146 access_path: Sequence[str],
142 method_name: str,
147 method_name: str,
143 ):
148 ):
144 if module_name not in sys.modules:
149 if module_name not in sys.modules:
145 # LBYLB as it is faster
150 # LBYLB as it is faster
146 return False
151 return False
147 try:
152 try:
148 member_type = _get_external(module_name, access_path)
153 member_type = _get_external(module_name, access_path)
149 value_type = type(value)
154 value_type = type(value)
150 if type(value) == member_type:
155 if type(value) == member_type:
151 return True
156 return True
152 if method_name == "__getattribute__":
157 if method_name == "__getattribute__":
153 # we have to short-circuit here due to an unresolved issue in
158 # we have to short-circuit here due to an unresolved issue in
154 # `isinstance` implementation: https://bugs.python.org/issue32683
159 # `isinstance` implementation: https://bugs.python.org/issue32683
155 return False
160 return False
156 if isinstance(value, member_type):
161 if isinstance(value, member_type):
157 method = getattr(value_type, method_name, None)
162 method = getattr(value_type, method_name, None)
158 member_method = getattr(member_type, method_name, None)
163 member_method = getattr(member_type, method_name, None)
159 if member_method == method:
164 if member_method == method:
160 return True
165 return True
161 except (AttributeError, KeyError):
166 except (AttributeError, KeyError):
162 return False
167 return False
163
168
164
169
165 def _has_original_dunder(
170 def _has_original_dunder(
166 value, allowed_types, allowed_methods, allowed_external, method_name
171 value, allowed_types, allowed_methods, allowed_external, method_name
167 ):
172 ):
168 # note: Python ignores `__getattr__`/`__getitem__` on instances,
173 # note: Python ignores `__getattr__`/`__getitem__` on instances,
169 # we only need to check at class level
174 # we only need to check at class level
170 value_type = type(value)
175 value_type = type(value)
171
176
172 # strict type check passes β†’ no need to check method
177 # strict type check passes β†’ no need to check method
173 if value_type in allowed_types:
178 if value_type in allowed_types:
174 return True
179 return True
175
180
176 method = getattr(value_type, method_name, None)
181 method = getattr(value_type, method_name, None)
177
182
178 if method is None:
183 if method is None:
179 return None
184 return None
180
185
181 if method in allowed_methods:
186 if method in allowed_methods:
182 return True
187 return True
183
188
184 for module_name, *access_path in allowed_external:
189 for module_name, *access_path in allowed_external:
185 if _has_original_dunder_external(value, module_name, access_path, method_name):
190 if _has_original_dunder_external(value, module_name, access_path, method_name):
186 return True
191 return True
187
192
188 return False
193 return False
189
194
190
195
191 @undoc
196 @undoc
192 @dataclass
197 @dataclass
193 class SelectivePolicy(EvaluationPolicy):
198 class SelectivePolicy(EvaluationPolicy):
194 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
199 allowed_getitem: Set[InstancesHaveGetItem] = field(default_factory=set)
195 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
200 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
196
201
197 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
202 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
198 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
203 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
199
204
200 allowed_operations: Set = field(default_factory=set)
205 allowed_operations: Set = field(default_factory=set)
201 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
206 allowed_operations_external: Set[Tuple[str, ...]] = field(default_factory=set)
202
207
203 _operation_methods_cache: Dict[str, Set[Callable]] = field(
208 _operation_methods_cache: Dict[str, Set[Callable]] = field(
204 default_factory=dict, init=False
209 default_factory=dict, init=False
205 )
210 )
206
211
207 def can_get_attr(self, value, attr):
212 def can_get_attr(self, value, attr):
208 has_original_attribute = _has_original_dunder(
213 has_original_attribute = _has_original_dunder(
209 value,
214 value,
210 allowed_types=self.allowed_getattr,
215 allowed_types=self.allowed_getattr,
211 allowed_methods=self._getattribute_methods,
216 allowed_methods=self._getattribute_methods,
212 allowed_external=self.allowed_getattr_external,
217 allowed_external=self.allowed_getattr_external,
213 method_name="__getattribute__",
218 method_name="__getattribute__",
214 )
219 )
215 has_original_attr = _has_original_dunder(
220 has_original_attr = _has_original_dunder(
216 value,
221 value,
217 allowed_types=self.allowed_getattr,
222 allowed_types=self.allowed_getattr,
218 allowed_methods=self._getattr_methods,
223 allowed_methods=self._getattr_methods,
219 allowed_external=self.allowed_getattr_external,
224 allowed_external=self.allowed_getattr_external,
220 method_name="__getattr__",
225 method_name="__getattr__",
221 )
226 )
222
227
223 accept = False
228 accept = False
224
229
225 # Many objects do not have `__getattr__`, this is fine.
230 # Many objects do not have `__getattr__`, this is fine.
226 if has_original_attr is None and has_original_attribute:
231 if has_original_attr is None and has_original_attribute:
227 accept = True
232 accept = True
228 else:
233 else:
229 # Accept objects without modifications to `__getattr__` and `__getattribute__`
234 # Accept objects without modifications to `__getattr__` and `__getattribute__`
230 accept = has_original_attr and has_original_attribute
235 accept = has_original_attr and has_original_attribute
231
236
232 if accept:
237 if accept:
233 # We still need to check for overriden properties.
238 # We still need to check for overriden properties.
234
239
235 value_class = type(value)
240 value_class = type(value)
236 if not hasattr(value_class, attr):
241 if not hasattr(value_class, attr):
237 return True
242 return True
238
243
239 class_attr_val = getattr(value_class, attr)
244 class_attr_val = getattr(value_class, attr)
240 is_property = isinstance(class_attr_val, property)
245 is_property = isinstance(class_attr_val, property)
241
246
242 if not is_property:
247 if not is_property:
243 return True
248 return True
244
249
245 # Properties in allowed types are ok (although we do not include any
250 # Properties in allowed types are ok (although we do not include any
246 # properties in our default allow list currently).
251 # properties in our default allow list currently).
247 if type(value) in self.allowed_getattr:
252 if type(value) in self.allowed_getattr:
248 return True # pragma: no cover
253 return True # pragma: no cover
249
254
250 # Properties in subclasses of allowed types may be ok if not changed
255 # Properties in subclasses of allowed types may be ok if not changed
251 for module_name, *access_path in self.allowed_getattr_external:
256 for module_name, *access_path in self.allowed_getattr_external:
252 try:
257 try:
253 external_class = _get_external(module_name, access_path)
258 external_class = _get_external(module_name, access_path)
254 external_class_attr_val = getattr(external_class, attr)
259 external_class_attr_val = getattr(external_class, attr)
255 except (KeyError, AttributeError):
260 except (KeyError, AttributeError):
256 return False # pragma: no cover
261 return False # pragma: no cover
257 return class_attr_val == external_class_attr_val
262 return class_attr_val == external_class_attr_val
258
263
259 return False
264 return False
260
265
261 def can_get_item(self, value, item):
266 def can_get_item(self, value, item):
262 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
267 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
263 return _has_original_dunder(
268 return _has_original_dunder(
264 value,
269 value,
265 allowed_types=self.allowed_getitem,
270 allowed_types=self.allowed_getitem,
266 allowed_methods=self._getitem_methods,
271 allowed_methods=self._getitem_methods,
267 allowed_external=self.allowed_getitem_external,
272 allowed_external=self.allowed_getitem_external,
268 method_name="__getitem__",
273 method_name="__getitem__",
269 )
274 )
270
275
271 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
276 def can_operate(self, dunders: Tuple[str, ...], a, b=None):
272 objects = [a]
277 objects = [a]
273 if b is not None:
278 if b is not None:
274 objects.append(b)
279 objects.append(b)
275 return all(
280 return all(
276 [
281 [
277 _has_original_dunder(
282 _has_original_dunder(
278 obj,
283 obj,
279 allowed_types=self.allowed_operations,
284 allowed_types=self.allowed_operations,
280 allowed_methods=self._operator_dunder_methods(dunder),
285 allowed_methods=self._operator_dunder_methods(dunder),
281 allowed_external=self.allowed_operations_external,
286 allowed_external=self.allowed_operations_external,
282 method_name=dunder,
287 method_name=dunder,
283 )
288 )
284 for dunder in dunders
289 for dunder in dunders
285 for obj in objects
290 for obj in objects
286 ]
291 ]
287 )
292 )
288
293
289 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
294 def _operator_dunder_methods(self, dunder: str) -> Set[Callable]:
290 if dunder not in self._operation_methods_cache:
295 if dunder not in self._operation_methods_cache:
291 self._operation_methods_cache[dunder] = self._safe_get_methods(
296 self._operation_methods_cache[dunder] = self._safe_get_methods(
292 self.allowed_operations, dunder
297 self.allowed_operations, dunder
293 )
298 )
294 return self._operation_methods_cache[dunder]
299 return self._operation_methods_cache[dunder]
295
300
296 @cached_property
301 @cached_property
297 def _getitem_methods(self) -> Set[Callable]:
302 def _getitem_methods(self) -> Set[Callable]:
298 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
303 return self._safe_get_methods(self.allowed_getitem, "__getitem__")
299
304
300 @cached_property
305 @cached_property
301 def _getattr_methods(self) -> Set[Callable]:
306 def _getattr_methods(self) -> Set[Callable]:
302 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
307 return self._safe_get_methods(self.allowed_getattr, "__getattr__")
303
308
304 @cached_property
309 @cached_property
305 def _getattribute_methods(self) -> Set[Callable]:
310 def _getattribute_methods(self) -> Set[Callable]:
306 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
311 return self._safe_get_methods(self.allowed_getattr, "__getattribute__")
307
312
308 def _safe_get_methods(self, classes, name) -> Set[Callable]:
313 def _safe_get_methods(self, classes, name) -> Set[Callable]:
309 return {
314 return {
310 method
315 method
311 for class_ in classes
316 for class_ in classes
312 for method in [getattr(class_, name, None)]
317 for method in [getattr(class_, name, None)]
313 if method
318 if method
314 }
319 }
315
320
316
321
317 class _DummyNamedTuple(NamedTuple):
322 class _DummyNamedTuple(NamedTuple):
318 """Used internally to retrieve methods of named tuple instance."""
323 """Used internally to retrieve methods of named tuple instance."""
319
324
320
325
321 class EvaluationContext(NamedTuple):
326 class EvaluationContext(NamedTuple):
322 #: Local namespace
327 #: Local namespace
323 locals: dict
328 locals: dict
324 #: Global namespace
329 #: Global namespace
325 globals: dict
330 globals: dict
326 #: Evaluation policy identifier
331 #: Evaluation policy identifier
327 evaluation: Literal[
332 evaluation: Literal[
328 "forbidden", "minimal", "limited", "unsafe", "dangerous"
333 "forbidden", "minimal", "limited", "unsafe", "dangerous"
329 ] = "forbidden"
334 ] = "forbidden"
330 #: Whether the evalution of code takes place inside of a subscript.
335 #: Whether the evalution of code takes place inside of a subscript.
331 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
336 #: Useful for evaluating ``:-1, 'col'`` in ``df[:-1, 'col']``.
332 in_subscript: bool = False
337 in_subscript: bool = False
333
338
334
339
335 class _IdentitySubscript:
340 class _IdentitySubscript:
336 """Returns the key itself when item is requested via subscript."""
341 """Returns the key itself when item is requested via subscript."""
337
342
338 def __getitem__(self, key):
343 def __getitem__(self, key):
339 return key
344 return key
340
345
341
346
342 IDENTITY_SUBSCRIPT = _IdentitySubscript()
347 IDENTITY_SUBSCRIPT = _IdentitySubscript()
343 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
348 SUBSCRIPT_MARKER = "__SUBSCRIPT_SENTINEL__"
344 UNKNOWN_SIGNATURE = Signature()
349 UNKNOWN_SIGNATURE = Signature()
345 NOT_EVALUATED = object()
350 NOT_EVALUATED = object()
346
351
347
352
348 class GuardRejection(Exception):
353 class GuardRejection(Exception):
349 """Exception raised when guard rejects evaluation attempt."""
354 """Exception raised when guard rejects evaluation attempt."""
350
355
351 pass
356 pass
352
357
353
358
354 def guarded_eval(code: str, context: EvaluationContext):
359 def guarded_eval(code: str, context: EvaluationContext):
355 """Evaluate provided code in the evaluation context.
360 """Evaluate provided code in the evaluation context.
356
361
357 If evaluation policy given by context is set to ``forbidden``
362 If evaluation policy given by context is set to ``forbidden``
358 no evaluation will be performed; if it is set to ``dangerous``
363 no evaluation will be performed; if it is set to ``dangerous``
359 standard :func:`eval` will be used; finally, for any other,
364 standard :func:`eval` will be used; finally, for any other,
360 policy :func:`eval_node` will be called on parsed AST.
365 policy :func:`eval_node` will be called on parsed AST.
361 """
366 """
362 locals_ = context.locals
367 locals_ = context.locals
363
368
364 if context.evaluation == "forbidden":
369 if context.evaluation == "forbidden":
365 raise GuardRejection("Forbidden mode")
370 raise GuardRejection("Forbidden mode")
366
371
367 # note: not using `ast.literal_eval` as it does not implement
372 # note: not using `ast.literal_eval` as it does not implement
368 # getitem at all, for example it fails on simple `[0][1]`
373 # getitem at all, for example it fails on simple `[0][1]`
369
374
370 if context.in_subscript:
375 if context.in_subscript:
371 # syntatic sugar for ellipsis (:) is only available in susbcripts
376 # syntatic sugar for ellipsis (:) is only available in susbcripts
372 # so we need to trick the ast parser into thinking that we have
377 # so we need to trick the ast parser into thinking that we have
373 # a subscript, but we need to be able to later recognise that we did
378 # a subscript, but we need to be able to later recognise that we did
374 # it so we can ignore the actual __getitem__ operation
379 # it so we can ignore the actual __getitem__ operation
375 if not code:
380 if not code:
376 return tuple()
381 return tuple()
377 locals_ = locals_.copy()
382 locals_ = locals_.copy()
378 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
383 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
379 code = SUBSCRIPT_MARKER + "[" + code + "]"
384 code = SUBSCRIPT_MARKER + "[" + code + "]"
380 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
385 context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
381
386
382 if context.evaluation == "dangerous":
387 if context.evaluation == "dangerous":
383 return eval(code, context.globals, context.locals)
388 return eval(code, context.globals, context.locals)
384
389
385 expression = ast.parse(code, mode="eval")
390 expression = ast.parse(code, mode="eval")
386
391
387 return eval_node(expression, context)
392 return eval_node(expression, context)
388
393
389
394
390 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
395 BINARY_OP_DUNDERS: Dict[Type[ast.operator], Tuple[str]] = {
391 ast.Add: ("__add__",),
396 ast.Add: ("__add__",),
392 ast.Sub: ("__sub__",),
397 ast.Sub: ("__sub__",),
393 ast.Mult: ("__mul__",),
398 ast.Mult: ("__mul__",),
394 ast.Div: ("__truediv__",),
399 ast.Div: ("__truediv__",),
395 ast.FloorDiv: ("__floordiv__",),
400 ast.FloorDiv: ("__floordiv__",),
396 ast.Mod: ("__mod__",),
401 ast.Mod: ("__mod__",),
397 ast.Pow: ("__pow__",),
402 ast.Pow: ("__pow__",),
398 ast.LShift: ("__lshift__",),
403 ast.LShift: ("__lshift__",),
399 ast.RShift: ("__rshift__",),
404 ast.RShift: ("__rshift__",),
400 ast.BitOr: ("__or__",),
405 ast.BitOr: ("__or__",),
401 ast.BitXor: ("__xor__",),
406 ast.BitXor: ("__xor__",),
402 ast.BitAnd: ("__and__",),
407 ast.BitAnd: ("__and__",),
403 ast.MatMult: ("__matmul__",),
408 ast.MatMult: ("__matmul__",),
404 }
409 }
405
410
406 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
411 COMP_OP_DUNDERS: Dict[Type[ast.cmpop], Tuple[str, ...]] = {
407 ast.Eq: ("__eq__",),
412 ast.Eq: ("__eq__",),
408 ast.NotEq: ("__ne__", "__eq__"),
413 ast.NotEq: ("__ne__", "__eq__"),
409 ast.Lt: ("__lt__", "__gt__"),
414 ast.Lt: ("__lt__", "__gt__"),
410 ast.LtE: ("__le__", "__ge__"),
415 ast.LtE: ("__le__", "__ge__"),
411 ast.Gt: ("__gt__", "__lt__"),
416 ast.Gt: ("__gt__", "__lt__"),
412 ast.GtE: ("__ge__", "__le__"),
417 ast.GtE: ("__ge__", "__le__"),
413 ast.In: ("__contains__",),
418 ast.In: ("__contains__",),
414 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
419 # Note: ast.Is, ast.IsNot, ast.NotIn are handled specially
415 }
420 }
416
421
417 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
422 UNARY_OP_DUNDERS: Dict[Type[ast.unaryop], Tuple[str, ...]] = {
418 ast.USub: ("__neg__",),
423 ast.USub: ("__neg__",),
419 ast.UAdd: ("__pos__",),
424 ast.UAdd: ("__pos__",),
420 # we have to check both __inv__ and __invert__!
425 # we have to check both __inv__ and __invert__!
421 ast.Invert: ("__invert__", "__inv__"),
426 ast.Invert: ("__invert__", "__inv__"),
422 ast.Not: ("__not__",),
427 ast.Not: ("__not__",),
423 }
428 }
424
429
425
430
426 class Duck:
431 class ImpersonatingDuck:
427 """A dummy class used to create objects of other classes without calling their ``__init__``"""
432 """A dummy class used to create objects of other classes without calling their ``__init__``"""
428
433
434 # no-op: override __class__ to impersonate
435
436
437 class _Duck:
438 """A dummy class used to create objects pretending to have given attributes"""
439
440 def __init__(self, attributes: Optional[dict] = None, items: Optional[dict] = None):
441 self.attributes = attributes or {}
442 self.items = items or {}
443
444 def __getattr__(self, attr: str):
445 return self.attributes[attr]
446
447 def __hasattr__(self, attr: str):
448 return attr in self.attributes
449
450 def __dir__(self):
451 return [*dir(super), *self.attributes]
452
453 def __getitem__(self, key: str):
454 return self.items[key]
455
456 def __hasitem__(self, key: str):
457 return self.items[key]
458
459 def _ipython_key_completions_(self):
460 return self.items.keys()
461
429
462
430 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
463 def _find_dunder(node_op, dunders) -> Union[Tuple[str, ...], None]:
431 dunder = None
464 dunder = None
432 for op, candidate_dunder in dunders.items():
465 for op, candidate_dunder in dunders.items():
433 if isinstance(node_op, op):
466 if isinstance(node_op, op):
434 dunder = candidate_dunder
467 dunder = candidate_dunder
435 return dunder
468 return dunder
436
469
437
470
438 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
471 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
439 """Evaluate AST node in provided context.
472 """Evaluate AST node in provided context.
440
473
441 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
474 Applies evaluation restrictions defined in the context. Currently does not support evaluation of functions with keyword arguments.
442
475
443 Does not evaluate actions that always have side effects:
476 Does not evaluate actions that always have side effects:
444
477
445 - class definitions (``class sth: ...``)
478 - class definitions (``class sth: ...``)
446 - function definitions (``def sth: ...``)
479 - function definitions (``def sth: ...``)
447 - variable assignments (``x = 1``)
480 - variable assignments (``x = 1``)
448 - augmented assignments (``x += 1``)
481 - augmented assignments (``x += 1``)
449 - deletions (``del x``)
482 - deletions (``del x``)
450
483
451 Does not evaluate operations which do not return values:
484 Does not evaluate operations which do not return values:
452
485
453 - assertions (``assert x``)
486 - assertions (``assert x``)
454 - pass (``pass``)
487 - pass (``pass``)
455 - imports (``import x``)
488 - imports (``import x``)
456 - control flow:
489 - control flow:
457
490
458 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
491 - conditionals (``if x:``) except for ternary IfExp (``a if x else b``)
459 - loops (``for`` and ``while``)
492 - loops (``for`` and ``while``)
460 - exception handling
493 - exception handling
461
494
462 The purpose of this function is to guard against unwanted side-effects;
495 The purpose of this function is to guard against unwanted side-effects;
463 it does not give guarantees on protection from malicious code execution.
496 it does not give guarantees on protection from malicious code execution.
464 """
497 """
465 policy = EVALUATION_POLICIES[context.evaluation]
498 policy = EVALUATION_POLICIES[context.evaluation]
466 if node is None:
499 if node is None:
467 return None
500 return None
468 if isinstance(node, ast.Expression):
501 if isinstance(node, ast.Expression):
469 return eval_node(node.body, context)
502 return eval_node(node.body, context)
470 if isinstance(node, ast.BinOp):
503 if isinstance(node, ast.BinOp):
471 left = eval_node(node.left, context)
504 left = eval_node(node.left, context)
472 right = eval_node(node.right, context)
505 right = eval_node(node.right, context)
473 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
506 dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
474 if dunders:
507 if dunders:
475 if policy.can_operate(dunders, left, right):
508 if policy.can_operate(dunders, left, right):
476 return getattr(left, dunders[0])(right)
509 return getattr(left, dunders[0])(right)
477 else:
510 else:
478 raise GuardRejection(
511 raise GuardRejection(
479 f"Operation (`{dunders}`) for",
512 f"Operation (`{dunders}`) for",
480 type(left),
513 type(left),
481 f"not allowed in {context.evaluation} mode",
514 f"not allowed in {context.evaluation} mode",
482 )
515 )
483 if isinstance(node, ast.Compare):
516 if isinstance(node, ast.Compare):
484 left = eval_node(node.left, context)
517 left = eval_node(node.left, context)
485 all_true = True
518 all_true = True
486 negate = False
519 negate = False
487 for op, right in zip(node.ops, node.comparators):
520 for op, right in zip(node.ops, node.comparators):
488 right = eval_node(right, context)
521 right = eval_node(right, context)
489 dunder = None
522 dunder = None
490 dunders = _find_dunder(op, COMP_OP_DUNDERS)
523 dunders = _find_dunder(op, COMP_OP_DUNDERS)
491 if not dunders:
524 if not dunders:
492 if isinstance(op, ast.NotIn):
525 if isinstance(op, ast.NotIn):
493 dunders = COMP_OP_DUNDERS[ast.In]
526 dunders = COMP_OP_DUNDERS[ast.In]
494 negate = True
527 negate = True
495 if isinstance(op, ast.Is):
528 if isinstance(op, ast.Is):
496 dunder = "is_"
529 dunder = "is_"
497 if isinstance(op, ast.IsNot):
530 if isinstance(op, ast.IsNot):
498 dunder = "is_"
531 dunder = "is_"
499 negate = True
532 negate = True
500 if not dunder and dunders:
533 if not dunder and dunders:
501 dunder = dunders[0]
534 dunder = dunders[0]
502 if dunder:
535 if dunder:
503 a, b = (right, left) if dunder == "__contains__" else (left, right)
536 a, b = (right, left) if dunder == "__contains__" else (left, right)
504 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
537 if dunder == "is_" or dunders and policy.can_operate(dunders, a, b):
505 result = getattr(operator, dunder)(a, b)
538 result = getattr(operator, dunder)(a, b)
506 if negate:
539 if negate:
507 result = not result
540 result = not result
508 if not result:
541 if not result:
509 all_true = False
542 all_true = False
510 left = right
543 left = right
511 else:
544 else:
512 raise GuardRejection(
545 raise GuardRejection(
513 f"Comparison (`{dunder}`) for",
546 f"Comparison (`{dunder}`) for",
514 type(left),
547 type(left),
515 f"not allowed in {context.evaluation} mode",
548 f"not allowed in {context.evaluation} mode",
516 )
549 )
517 else:
550 else:
518 raise ValueError(
551 raise ValueError(
519 f"Comparison `{dunder}` not supported"
552 f"Comparison `{dunder}` not supported"
520 ) # pragma: no cover
553 ) # pragma: no cover
521 return all_true
554 return all_true
522 if isinstance(node, ast.Constant):
555 if isinstance(node, ast.Constant):
523 return node.value
556 return node.value
524 if isinstance(node, ast.Tuple):
557 if isinstance(node, ast.Tuple):
525 return tuple(eval_node(e, context) for e in node.elts)
558 return tuple(eval_node(e, context) for e in node.elts)
526 if isinstance(node, ast.List):
559 if isinstance(node, ast.List):
527 return [eval_node(e, context) for e in node.elts]
560 return [eval_node(e, context) for e in node.elts]
528 if isinstance(node, ast.Set):
561 if isinstance(node, ast.Set):
529 return {eval_node(e, context) for e in node.elts}
562 return {eval_node(e, context) for e in node.elts}
530 if isinstance(node, ast.Dict):
563 if isinstance(node, ast.Dict):
531 return dict(
564 return dict(
532 zip(
565 zip(
533 [eval_node(k, context) for k in node.keys],
566 [eval_node(k, context) for k in node.keys],
534 [eval_node(v, context) for v in node.values],
567 [eval_node(v, context) for v in node.values],
535 )
568 )
536 )
569 )
537 if isinstance(node, ast.Slice):
570 if isinstance(node, ast.Slice):
538 return slice(
571 return slice(
539 eval_node(node.lower, context),
572 eval_node(node.lower, context),
540 eval_node(node.upper, context),
573 eval_node(node.upper, context),
541 eval_node(node.step, context),
574 eval_node(node.step, context),
542 )
575 )
543 if isinstance(node, ast.UnaryOp):
576 if isinstance(node, ast.UnaryOp):
544 value = eval_node(node.operand, context)
577 value = eval_node(node.operand, context)
545 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
578 dunders = _find_dunder(node.op, UNARY_OP_DUNDERS)
546 if dunders:
579 if dunders:
547 if policy.can_operate(dunders, value):
580 if policy.can_operate(dunders, value):
548 return getattr(value, dunders[0])()
581 return getattr(value, dunders[0])()
549 else:
582 else:
550 raise GuardRejection(
583 raise GuardRejection(
551 f"Operation (`{dunders}`) for",
584 f"Operation (`{dunders}`) for",
552 type(value),
585 type(value),
553 f"not allowed in {context.evaluation} mode",
586 f"not allowed in {context.evaluation} mode",
554 )
587 )
555 if isinstance(node, ast.Subscript):
588 if isinstance(node, ast.Subscript):
556 value = eval_node(node.value, context)
589 value = eval_node(node.value, context)
557 slice_ = eval_node(node.slice, context)
590 slice_ = eval_node(node.slice, context)
558 if policy.can_get_item(value, slice_):
591 if policy.can_get_item(value, slice_):
559 return value[slice_]
592 return value[slice_]
560 raise GuardRejection(
593 raise GuardRejection(
561 "Subscript access (`__getitem__`) for",
594 "Subscript access (`__getitem__`) for",
562 type(value), # not joined to avoid calling `repr`
595 type(value), # not joined to avoid calling `repr`
563 f" not allowed in {context.evaluation} mode",
596 f" not allowed in {context.evaluation} mode",
564 )
597 )
565 if isinstance(node, ast.Name):
598 if isinstance(node, ast.Name):
566 return _eval_node_name(node.id, context)
599 return _eval_node_name(node.id, context)
567 if isinstance(node, ast.Attribute):
600 if isinstance(node, ast.Attribute):
568 value = eval_node(node.value, context)
601 value = eval_node(node.value, context)
569 if policy.can_get_attr(value, node.attr):
602 if policy.can_get_attr(value, node.attr):
570 return getattr(value, node.attr)
603 return getattr(value, node.attr)
571 raise GuardRejection(
604 raise GuardRejection(
572 "Attribute access (`__getattr__`) for",
605 "Attribute access (`__getattr__`) for",
573 type(value), # not joined to avoid calling `repr`
606 type(value), # not joined to avoid calling `repr`
574 f"not allowed in {context.evaluation} mode",
607 f"not allowed in {context.evaluation} mode",
575 )
608 )
576 if isinstance(node, ast.IfExp):
609 if isinstance(node, ast.IfExp):
577 test = eval_node(node.test, context)
610 test = eval_node(node.test, context)
578 if test:
611 if test:
579 return eval_node(node.body, context)
612 return eval_node(node.body, context)
580 else:
613 else:
581 return eval_node(node.orelse, context)
614 return eval_node(node.orelse, context)
582 if isinstance(node, ast.Call):
615 if isinstance(node, ast.Call):
583 func = eval_node(node.func, context)
616 func = eval_node(node.func, context)
584 if policy.can_call(func) and not node.keywords:
617 if policy.can_call(func) and not node.keywords:
585 args = [eval_node(arg, context) for arg in node.args]
618 args = [eval_node(arg, context) for arg in node.args]
586 return func(*args)
619 return func(*args)
587 if isclass(func):
620 if isclass(func):
588 # this code path gets entered when calling class e.g. `MyClass()`
621 # this code path gets entered when calling class e.g. `MyClass()`
589 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
622 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
590 # Should return `MyClass` if `__new__` is not overridden,
623 # Should return `MyClass` if `__new__` is not overridden,
591 # otherwise whatever `__new__` return type is.
624 # otherwise whatever `__new__` return type is.
592 overridden_return_type = _eval_return_type(func.__new__, node, context)
625 overridden_return_type = _eval_return_type(func.__new__, node, context)
593 if overridden_return_type is not NOT_EVALUATED:
626 if overridden_return_type is not NOT_EVALUATED:
594 return overridden_return_type
627 return overridden_return_type
595 return _create_duck_for_heap_type(func)
628 return _create_duck_for_heap_type(func)
596 else:
629 else:
597 return_type = _eval_return_type(func, node, context)
630 return_type = _eval_return_type(func, node, context)
598 if return_type is not NOT_EVALUATED:
631 if return_type is not NOT_EVALUATED:
599 return return_type
632 return return_type
600 raise GuardRejection(
633 raise GuardRejection(
601 "Call for",
634 "Call for",
602 func, # not joined to avoid calling `repr`
635 func, # not joined to avoid calling `repr`
603 f"not allowed in {context.evaluation} mode",
636 f"not allowed in {context.evaluation} mode",
604 )
637 )
605 raise ValueError("Unhandled node", ast.dump(node))
638 raise ValueError("Unhandled node", ast.dump(node))
606
639
607
640
608 def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
641 def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
609 """Evaluate return type of a given callable function.
642 """Evaluate return type of a given callable function.
610
643
611 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
644 Returns the built-in type, a duck or NOT_EVALUATED sentinel.
612 """
645 """
613 try:
646 try:
614 sig = signature(func)
647 sig = signature(func)
615 except ValueError:
648 except ValueError:
616 sig = UNKNOWN_SIGNATURE
649 sig = UNKNOWN_SIGNATURE
617 # if annotation was not stringized, or it was stringized
650 # if annotation was not stringized, or it was stringized
618 # but resolved by signature call we know the return type
651 # but resolved by signature call we know the return type
619 not_empty = sig.return_annotation is not Signature.empty
652 not_empty = sig.return_annotation is not Signature.empty
620 stringized = isinstance(sig.return_annotation, str)
621 if not_empty:
653 if not_empty:
622 return_type = (
654 return _resolve_annotation(sig.return_annotation, sig, func, node, context)
623 _eval_node_name(sig.return_annotation, context)
624 if stringized
625 else sig.return_annotation
626 )
627 if return_type is Self and hasattr(func, "__self__"):
628 return func.__self__
629 elif get_origin(return_type) is Literal:
630 type_args = get_args(return_type)
631 if len(type_args) == 1:
632 return type_args[0]
633 elif isinstance(return_type, NewType):
634 return _eval_or_create_duck(return_type.__supertype__, node, context)
635 elif isinstance(return_type, TypeAliasType):
636 return _eval_or_create_duck(return_type.__value__, node, context)
637 else:
638 return _eval_or_create_duck(return_type, node, context)
639 return NOT_EVALUATED
655 return NOT_EVALUATED
640
656
641
657
658 def _resolve_annotation(
659 annotation,
660 sig: Signature,
661 func: Callable,
662 node: ast.Call,
663 context: EvaluationContext,
664 ):
665 """Resolve annotation created by user with `typing` module and custom objects."""
666 annotation = (
667 _eval_node_name(annotation, context)
668 if isinstance(annotation, str)
669 else annotation
670 )
671 origin = get_origin(annotation)
672 if annotation is Self and hasattr(func, "__self__"):
673 return func.__self__
674 elif origin is Literal:
675 type_args = get_args(annotation)
676 if len(type_args) == 1:
677 return type_args[0]
678 elif annotation is LiteralString:
679 return ""
680 elif annotation is AnyStr:
681 index = None
682 for i, (key, value) in enumerate(sig.parameters.items()):
683 if value.annotation is AnyStr:
684 index = i
685 break
686 if index is not None and index < len(node.args):
687 return eval_node(node.args[index], context)
688 elif origin is TypeGuard:
689 return bool()
690 elif origin is Union:
691 attributes = [
692 attr
693 for type_arg in get_args(annotation)
694 for attr in dir(_resolve_annotation(type_arg, sig, func, node, context))
695 ]
696 return _Duck(attributes=dict.fromkeys(attributes))
697 elif is_typeddict(annotation):
698 return _Duck(
699 attributes=dict.fromkeys(dir(dict())),
700 items={
701 k: _resolve_annotation(v, sig, func, node, context)
702 for k, v in annotation.__annotations__.items()
703 },
704 )
705 elif hasattr(annotation, "_is_protocol"):
706 return _Duck(attributes=dict.fromkeys(dir(annotation)))
707 elif origin is Annotated:
708 type_arg = get_args(annotation)[0]
709 return _resolve_annotation(type_arg, sig, func, node, context)
710 elif isinstance(annotation, NewType):
711 return _eval_or_create_duck(annotation.__supertype__, node, context)
712 elif isinstance(annotation, TypeAliasType):
713 return _eval_or_create_duck(annotation.__value__, node, context)
714 else:
715 return _eval_or_create_duck(annotation, node, context)
716
717
642 def _eval_node_name(node_id: str, context: EvaluationContext):
718 def _eval_node_name(node_id: str, context: EvaluationContext):
643 policy = EVALUATION_POLICIES[context.evaluation]
719 policy = EVALUATION_POLICIES[context.evaluation]
644 if policy.allow_locals_access and node_id in context.locals:
720 if policy.allow_locals_access and node_id in context.locals:
645 return context.locals[node_id]
721 return context.locals[node_id]
646 if policy.allow_globals_access and node_id in context.globals:
722 if policy.allow_globals_access and node_id in context.globals:
647 return context.globals[node_id]
723 return context.globals[node_id]
648 if policy.allow_builtins_access and hasattr(builtins, node_id):
724 if policy.allow_builtins_access and hasattr(builtins, node_id):
649 # note: do not use __builtins__, it is implementation detail of cPython
725 # note: do not use __builtins__, it is implementation detail of cPython
650 return getattr(builtins, node_id)
726 return getattr(builtins, node_id)
651 if not policy.allow_globals_access and not policy.allow_locals_access:
727 if not policy.allow_globals_access and not policy.allow_locals_access:
652 raise GuardRejection(
728 raise GuardRejection(
653 f"Namespace access not allowed in {context.evaluation} mode"
729 f"Namespace access not allowed in {context.evaluation} mode"
654 )
730 )
655 else:
731 else:
656 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
732 raise NameError(f"{node_id} not found in locals, globals, nor builtins")
657
733
658
734
659 def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
735 def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
660 policy = EVALUATION_POLICIES[context.evaluation]
736 policy = EVALUATION_POLICIES[context.evaluation]
661 # if allow-listed builtin is on type annotation, instantiate it
737 # if allow-listed builtin is on type annotation, instantiate it
662 if policy.can_call(duck_type) and not node.keywords:
738 if policy.can_call(duck_type) and not node.keywords:
663 args = [eval_node(arg, context) for arg in node.args]
739 args = [eval_node(arg, context) for arg in node.args]
664 return duck_type(*args)
740 return duck_type(*args)
665 # if custom class is in type annotation, mock it
741 # if custom class is in type annotation, mock it
666 return _create_duck_for_heap_type(duck_type)
742 return _create_duck_for_heap_type(duck_type)
667
743
668
744
669 def _create_duck_for_heap_type(duck_type):
745 def _create_duck_for_heap_type(duck_type):
670 """Create an imitation of an object of a given type (a duck).
746 """Create an imitation of an object of a given type (a duck).
671
747
672 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
748 Returns the duck or NOT_EVALUATED sentinel if duck could not be created.
673 """
749 """
674 duck = Duck()
750 duck = ImpersonatingDuck()
675 try:
751 try:
676 # this only works for heap types, not builtins
752 # this only works for heap types, not builtins
677 duck.__class__ = duck_type
753 duck.__class__ = duck_type
678 return duck
754 return duck
679 except TypeError:
755 except TypeError:
680 pass
756 pass
681 return NOT_EVALUATED
757 return NOT_EVALUATED
682
758
683
759
684 SUPPORTED_EXTERNAL_GETITEM = {
760 SUPPORTED_EXTERNAL_GETITEM = {
685 ("pandas", "core", "indexing", "_iLocIndexer"),
761 ("pandas", "core", "indexing", "_iLocIndexer"),
686 ("pandas", "core", "indexing", "_LocIndexer"),
762 ("pandas", "core", "indexing", "_LocIndexer"),
687 ("pandas", "DataFrame"),
763 ("pandas", "DataFrame"),
688 ("pandas", "Series"),
764 ("pandas", "Series"),
689 ("numpy", "ndarray"),
765 ("numpy", "ndarray"),
690 ("numpy", "void"),
766 ("numpy", "void"),
691 }
767 }
692
768
693
769
694 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
770 BUILTIN_GETITEM: Set[InstancesHaveGetItem] = {
695 dict,
771 dict,
696 str, # type: ignore[arg-type]
772 str, # type: ignore[arg-type]
697 bytes, # type: ignore[arg-type]
773 bytes, # type: ignore[arg-type]
698 list,
774 list,
699 tuple,
775 tuple,
700 collections.defaultdict,
776 collections.defaultdict,
701 collections.deque,
777 collections.deque,
702 collections.OrderedDict,
778 collections.OrderedDict,
703 collections.ChainMap,
779 collections.ChainMap,
704 collections.UserDict,
780 collections.UserDict,
705 collections.UserList,
781 collections.UserList,
706 collections.UserString, # type: ignore[arg-type]
782 collections.UserString, # type: ignore[arg-type]
707 _DummyNamedTuple,
783 _DummyNamedTuple,
708 _IdentitySubscript,
784 _IdentitySubscript,
709 }
785 }
710
786
711
787
712 def _list_methods(cls, source=None):
788 def _list_methods(cls, source=None):
713 """For use on immutable objects or with methods returning a copy"""
789 """For use on immutable objects or with methods returning a copy"""
714 return [getattr(cls, k) for k in (source if source else dir(cls))]
790 return [getattr(cls, k) for k in (source if source else dir(cls))]
715
791
716
792
717 dict_non_mutating_methods = ("copy", "keys", "values", "items")
793 dict_non_mutating_methods = ("copy", "keys", "values", "items")
718 list_non_mutating_methods = ("copy", "index", "count")
794 list_non_mutating_methods = ("copy", "index", "count")
719 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
795 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
720
796
721
797
722 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
798 dict_keys: Type[collections.abc.KeysView] = type({}.keys())
723
799
724 NUMERICS = {int, float, complex}
800 NUMERICS = {int, float, complex}
725
801
726 ALLOWED_CALLS = {
802 ALLOWED_CALLS = {
727 bytes,
803 bytes,
728 *_list_methods(bytes),
804 *_list_methods(bytes),
729 dict,
805 dict,
730 *_list_methods(dict, dict_non_mutating_methods),
806 *_list_methods(dict, dict_non_mutating_methods),
731 dict_keys.isdisjoint,
807 dict_keys.isdisjoint,
732 list,
808 list,
733 *_list_methods(list, list_non_mutating_methods),
809 *_list_methods(list, list_non_mutating_methods),
734 set,
810 set,
735 *_list_methods(set, set_non_mutating_methods),
811 *_list_methods(set, set_non_mutating_methods),
736 frozenset,
812 frozenset,
737 *_list_methods(frozenset),
813 *_list_methods(frozenset),
738 range,
814 range,
739 str,
815 str,
740 *_list_methods(str),
816 *_list_methods(str),
741 tuple,
817 tuple,
742 *_list_methods(tuple),
818 *_list_methods(tuple),
743 *NUMERICS,
819 *NUMERICS,
744 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
820 *[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
745 collections.deque,
821 collections.deque,
746 *_list_methods(collections.deque, list_non_mutating_methods),
822 *_list_methods(collections.deque, list_non_mutating_methods),
747 collections.defaultdict,
823 collections.defaultdict,
748 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
824 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
749 collections.OrderedDict,
825 collections.OrderedDict,
750 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
826 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
751 collections.UserDict,
827 collections.UserDict,
752 *_list_methods(collections.UserDict, dict_non_mutating_methods),
828 *_list_methods(collections.UserDict, dict_non_mutating_methods),
753 collections.UserList,
829 collections.UserList,
754 *_list_methods(collections.UserList, list_non_mutating_methods),
830 *_list_methods(collections.UserList, list_non_mutating_methods),
755 collections.UserString,
831 collections.UserString,
756 *_list_methods(collections.UserString, dir(str)),
832 *_list_methods(collections.UserString, dir(str)),
757 collections.Counter,
833 collections.Counter,
758 *_list_methods(collections.Counter, dict_non_mutating_methods),
834 *_list_methods(collections.Counter, dict_non_mutating_methods),
759 collections.Counter.elements,
835 collections.Counter.elements,
760 collections.Counter.most_common,
836 collections.Counter.most_common,
761 }
837 }
762
838
763 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
839 BUILTIN_GETATTR: Set[MayHaveGetattr] = {
764 *BUILTIN_GETITEM,
840 *BUILTIN_GETITEM,
765 set,
841 set,
766 frozenset,
842 frozenset,
767 object,
843 object,
768 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
844 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
769 *NUMERICS,
845 *NUMERICS,
770 dict_keys,
846 dict_keys,
771 MethodDescriptorType,
847 MethodDescriptorType,
772 ModuleType,
848 ModuleType,
773 }
849 }
774
850
775
851
776 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
852 BUILTIN_OPERATIONS = {*BUILTIN_GETATTR}
777
853
778 EVALUATION_POLICIES = {
854 EVALUATION_POLICIES = {
779 "minimal": EvaluationPolicy(
855 "minimal": EvaluationPolicy(
780 allow_builtins_access=True,
856 allow_builtins_access=True,
781 allow_locals_access=False,
857 allow_locals_access=False,
782 allow_globals_access=False,
858 allow_globals_access=False,
783 allow_item_access=False,
859 allow_item_access=False,
784 allow_attr_access=False,
860 allow_attr_access=False,
785 allowed_calls=set(),
861 allowed_calls=set(),
786 allow_any_calls=False,
862 allow_any_calls=False,
787 allow_all_operations=False,
863 allow_all_operations=False,
788 ),
864 ),
789 "limited": SelectivePolicy(
865 "limited": SelectivePolicy(
790 allowed_getitem=BUILTIN_GETITEM,
866 allowed_getitem=BUILTIN_GETITEM,
791 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
867 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
792 allowed_getattr=BUILTIN_GETATTR,
868 allowed_getattr=BUILTIN_GETATTR,
793 allowed_getattr_external={
869 allowed_getattr_external={
794 # pandas Series/Frame implements custom `__getattr__`
870 # pandas Series/Frame implements custom `__getattr__`
795 ("pandas", "DataFrame"),
871 ("pandas", "DataFrame"),
796 ("pandas", "Series"),
872 ("pandas", "Series"),
797 },
873 },
798 allowed_operations=BUILTIN_OPERATIONS,
874 allowed_operations=BUILTIN_OPERATIONS,
799 allow_builtins_access=True,
875 allow_builtins_access=True,
800 allow_locals_access=True,
876 allow_locals_access=True,
801 allow_globals_access=True,
877 allow_globals_access=True,
802 allowed_calls=ALLOWED_CALLS,
878 allowed_calls=ALLOWED_CALLS,
803 ),
879 ),
804 "unsafe": EvaluationPolicy(
880 "unsafe": EvaluationPolicy(
805 allow_builtins_access=True,
881 allow_builtins_access=True,
806 allow_locals_access=True,
882 allow_locals_access=True,
807 allow_globals_access=True,
883 allow_globals_access=True,
808 allow_attr_access=True,
884 allow_attr_access=True,
809 allow_item_access=True,
885 allow_item_access=True,
810 allow_any_calls=True,
886 allow_any_calls=True,
811 allow_all_operations=True,
887 allow_all_operations=True,
812 ),
888 ),
813 }
889 }
814
890
815
891
816 __all__ = [
892 __all__ = [
817 "guarded_eval",
893 "guarded_eval",
818 "eval_node",
894 "eval_node",
819 "GuardRejection",
895 "GuardRejection",
820 "EvaluationContext",
896 "EvaluationContext",
821 "_unbind_method",
897 "_unbind_method",
822 ]
898 ]
@@ -1,689 +1,785 b''
1 import sys
1 import sys
2 from contextlib import contextmanager
2 from contextlib import contextmanager
3 from typing import NamedTuple, Literal, NewType
3 from typing import (
4 Annotated,
5 AnyStr,
6 NamedTuple,
7 Literal,
8 NewType,
9 Optional,
10 Protocol,
11 TypeGuard,
12 Union,
13 TypedDict,
14 )
4 from functools import partial
15 from functools import partial
5 from IPython.core.guarded_eval import (
16 from IPython.core.guarded_eval import (
6 EvaluationContext,
17 EvaluationContext,
7 GuardRejection,
18 GuardRejection,
8 guarded_eval,
19 guarded_eval,
9 _unbind_method,
20 _unbind_method,
10 )
21 )
11 from IPython.testing import decorators as dec
22 from IPython.testing import decorators as dec
12 import pytest
23 import pytest
13
24
14
25
15 if sys.version_info < (3, 11):
26 if sys.version_info < (3, 11):
16 from typing_extensions import Self
27 from typing_extensions import Self, LiteralString
17 else:
28 else:
18 from typing import Self
29 from typing import Self, LiteralString
19
30
20 if sys.version_info < (3, 12):
31 if sys.version_info < (3, 12):
21 from typing_extensions import TypeAliasType
32 from typing_extensions import TypeAliasType
22 else:
33 else:
23 from typing import TypeAliasType
34 from typing import TypeAliasType
24
35
25
36
26 def create_context(evaluation: str, **kwargs):
37 def create_context(evaluation: str, **kwargs):
27 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
38 return EvaluationContext(locals=kwargs, globals={}, evaluation=evaluation)
28
39
29
40
30 forbidden = partial(create_context, "forbidden")
41 forbidden = partial(create_context, "forbidden")
31 minimal = partial(create_context, "minimal")
42 minimal = partial(create_context, "minimal")
32 limited = partial(create_context, "limited")
43 limited = partial(create_context, "limited")
33 unsafe = partial(create_context, "unsafe")
44 unsafe = partial(create_context, "unsafe")
34 dangerous = partial(create_context, "dangerous")
45 dangerous = partial(create_context, "dangerous")
35
46
36 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
47 LIMITED_OR_HIGHER = [limited, unsafe, dangerous]
37 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
48 MINIMAL_OR_HIGHER = [minimal, *LIMITED_OR_HIGHER]
38
49
39
50
40 @contextmanager
51 @contextmanager
41 def module_not_installed(module: str):
52 def module_not_installed(module: str):
42 import sys
53 import sys
43
54
44 try:
55 try:
45 to_restore = sys.modules[module]
56 to_restore = sys.modules[module]
46 del sys.modules[module]
57 del sys.modules[module]
47 except KeyError:
58 except KeyError:
48 to_restore = None
59 to_restore = None
49 try:
60 try:
50 yield
61 yield
51 finally:
62 finally:
52 sys.modules[module] = to_restore
63 sys.modules[module] = to_restore
53
64
54
65
55 def test_external_not_installed():
66 def test_external_not_installed():
56 """
67 """
57 Because attribute check requires checking if object is not of allowed
68 Because attribute check requires checking if object is not of allowed
58 external type, this tests logic for absence of external module.
69 external type, this tests logic for absence of external module.
59 """
70 """
60
71
61 class Custom:
72 class Custom:
62 def __init__(self):
73 def __init__(self):
63 self.test = 1
74 self.test = 1
64
75
65 def __getattr__(self, key):
76 def __getattr__(self, key):
66 return key
77 return key
67
78
68 with module_not_installed("pandas"):
79 with module_not_installed("pandas"):
69 context = limited(x=Custom())
80 context = limited(x=Custom())
70 with pytest.raises(GuardRejection):
81 with pytest.raises(GuardRejection):
71 guarded_eval("x.test", context)
82 guarded_eval("x.test", context)
72
83
73
84
74 @dec.skip_without("pandas")
85 @dec.skip_without("pandas")
75 def test_external_changed_api(monkeypatch):
86 def test_external_changed_api(monkeypatch):
76 """Check that the execution rejects if external API changed paths"""
87 """Check that the execution rejects if external API changed paths"""
77 import pandas as pd
88 import pandas as pd
78
89
79 series = pd.Series([1], index=["a"])
90 series = pd.Series([1], index=["a"])
80
91
81 with monkeypatch.context() as m:
92 with monkeypatch.context() as m:
82 m.delattr(pd, "Series")
93 m.delattr(pd, "Series")
83 context = limited(data=series)
94 context = limited(data=series)
84 with pytest.raises(GuardRejection):
95 with pytest.raises(GuardRejection):
85 guarded_eval("data.iloc[0]", context)
96 guarded_eval("data.iloc[0]", context)
86
97
87
98
88 @dec.skip_without("pandas")
99 @dec.skip_without("pandas")
89 def test_pandas_series_iloc():
100 def test_pandas_series_iloc():
90 import pandas as pd
101 import pandas as pd
91
102
92 series = pd.Series([1], index=["a"])
103 series = pd.Series([1], index=["a"])
93 context = limited(data=series)
104 context = limited(data=series)
94 assert guarded_eval("data.iloc[0]", context) == 1
105 assert guarded_eval("data.iloc[0]", context) == 1
95
106
96
107
97 def test_rejects_custom_properties():
108 def test_rejects_custom_properties():
98 class BadProperty:
109 class BadProperty:
99 @property
110 @property
100 def iloc(self):
111 def iloc(self):
101 return [None]
112 return [None]
102
113
103 series = BadProperty()
114 series = BadProperty()
104 context = limited(data=series)
115 context = limited(data=series)
105
116
106 with pytest.raises(GuardRejection):
117 with pytest.raises(GuardRejection):
107 guarded_eval("data.iloc[0]", context)
118 guarded_eval("data.iloc[0]", context)
108
119
109
120
110 @dec.skip_without("pandas")
121 @dec.skip_without("pandas")
111 def test_accepts_non_overriden_properties():
122 def test_accepts_non_overriden_properties():
112 import pandas as pd
123 import pandas as pd
113
124
114 class GoodProperty(pd.Series):
125 class GoodProperty(pd.Series):
115 pass
126 pass
116
127
117 series = GoodProperty([1], index=["a"])
128 series = GoodProperty([1], index=["a"])
118 context = limited(data=series)
129 context = limited(data=series)
119
130
120 assert guarded_eval("data.iloc[0]", context) == 1
131 assert guarded_eval("data.iloc[0]", context) == 1
121
132
122
133
123 @dec.skip_without("pandas")
134 @dec.skip_without("pandas")
124 def test_pandas_series():
135 def test_pandas_series():
125 import pandas as pd
136 import pandas as pd
126
137
127 context = limited(data=pd.Series([1], index=["a"]))
138 context = limited(data=pd.Series([1], index=["a"]))
128 assert guarded_eval('data["a"]', context) == 1
139 assert guarded_eval('data["a"]', context) == 1
129 with pytest.raises(KeyError):
140 with pytest.raises(KeyError):
130 guarded_eval('data["c"]', context)
141 guarded_eval('data["c"]', context)
131
142
132
143
133 @dec.skip_without("pandas")
144 @dec.skip_without("pandas")
134 def test_pandas_bad_series():
145 def test_pandas_bad_series():
135 import pandas as pd
146 import pandas as pd
136
147
137 class BadItemSeries(pd.Series):
148 class BadItemSeries(pd.Series):
138 def __getitem__(self, key):
149 def __getitem__(self, key):
139 return "CUSTOM_ITEM"
150 return "CUSTOM_ITEM"
140
151
141 class BadAttrSeries(pd.Series):
152 class BadAttrSeries(pd.Series):
142 def __getattr__(self, key):
153 def __getattr__(self, key):
143 return "CUSTOM_ATTR"
154 return "CUSTOM_ATTR"
144
155
145 bad_series = BadItemSeries([1], index=["a"])
156 bad_series = BadItemSeries([1], index=["a"])
146 context = limited(data=bad_series)
157 context = limited(data=bad_series)
147
158
148 with pytest.raises(GuardRejection):
159 with pytest.raises(GuardRejection):
149 guarded_eval('data["a"]', context)
160 guarded_eval('data["a"]', context)
150 with pytest.raises(GuardRejection):
161 with pytest.raises(GuardRejection):
151 guarded_eval('data["c"]', context)
162 guarded_eval('data["c"]', context)
152
163
153 # note: here result is a bit unexpected because
164 # note: here result is a bit unexpected because
154 # pandas `__getattr__` calls `__getitem__`;
165 # pandas `__getattr__` calls `__getitem__`;
155 # FIXME - special case to handle it?
166 # FIXME - special case to handle it?
156 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
167 assert guarded_eval("data.a", context) == "CUSTOM_ITEM"
157
168
158 context = unsafe(data=bad_series)
169 context = unsafe(data=bad_series)
159 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
170 assert guarded_eval('data["a"]', context) == "CUSTOM_ITEM"
160
171
161 bad_attr_series = BadAttrSeries([1], index=["a"])
172 bad_attr_series = BadAttrSeries([1], index=["a"])
162 context = limited(data=bad_attr_series)
173 context = limited(data=bad_attr_series)
163 assert guarded_eval('data["a"]', context) == 1
174 assert guarded_eval('data["a"]', context) == 1
164 with pytest.raises(GuardRejection):
175 with pytest.raises(GuardRejection):
165 guarded_eval("data.a", context)
176 guarded_eval("data.a", context)
166
177
167
178
168 @dec.skip_without("pandas")
179 @dec.skip_without("pandas")
169 def test_pandas_dataframe_loc():
180 def test_pandas_dataframe_loc():
170 import pandas as pd
181 import pandas as pd
171 from pandas.testing import assert_series_equal
182 from pandas.testing import assert_series_equal
172
183
173 data = pd.DataFrame([{"a": 1}])
184 data = pd.DataFrame([{"a": 1}])
174 context = limited(data=data)
185 context = limited(data=data)
175 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
186 assert_series_equal(guarded_eval('data.loc[:, "a"]', context), data["a"])
176
187
177
188
178 def test_named_tuple():
189 def test_named_tuple():
179 class GoodNamedTuple(NamedTuple):
190 class GoodNamedTuple(NamedTuple):
180 a: str
191 a: str
181 pass
192 pass
182
193
183 class BadNamedTuple(NamedTuple):
194 class BadNamedTuple(NamedTuple):
184 a: str
195 a: str
185
196
186 def __getitem__(self, key):
197 def __getitem__(self, key):
187 return None
198 return None
188
199
189 good = GoodNamedTuple(a="x")
200 good = GoodNamedTuple(a="x")
190 bad = BadNamedTuple(a="x")
201 bad = BadNamedTuple(a="x")
191
202
192 context = limited(data=good)
203 context = limited(data=good)
193 assert guarded_eval("data[0]", context) == "x"
204 assert guarded_eval("data[0]", context) == "x"
194
205
195 context = limited(data=bad)
206 context = limited(data=bad)
196 with pytest.raises(GuardRejection):
207 with pytest.raises(GuardRejection):
197 guarded_eval("data[0]", context)
208 guarded_eval("data[0]", context)
198
209
199
210
200 def test_dict():
211 def test_dict():
201 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
212 context = limited(data={"a": 1, "b": {"x": 2}, ("x", "y"): 3})
202 assert guarded_eval('data["a"]', context) == 1
213 assert guarded_eval('data["a"]', context) == 1
203 assert guarded_eval('data["b"]', context) == {"x": 2}
214 assert guarded_eval('data["b"]', context) == {"x": 2}
204 assert guarded_eval('data["b"]["x"]', context) == 2
215 assert guarded_eval('data["b"]["x"]', context) == 2
205 assert guarded_eval('data["x", "y"]', context) == 3
216 assert guarded_eval('data["x", "y"]', context) == 3
206
217
207 assert guarded_eval("data.keys", context)
218 assert guarded_eval("data.keys", context)
208
219
209
220
210 def test_set():
221 def test_set():
211 context = limited(data={"a", "b"})
222 context = limited(data={"a", "b"})
212 assert guarded_eval("data.difference", context)
223 assert guarded_eval("data.difference", context)
213
224
214
225
215 def test_list():
226 def test_list():
216 context = limited(data=[1, 2, 3])
227 context = limited(data=[1, 2, 3])
217 assert guarded_eval("data[1]", context) == 2
228 assert guarded_eval("data[1]", context) == 2
218 assert guarded_eval("data.copy", context)
229 assert guarded_eval("data.copy", context)
219
230
220
231
221 def test_dict_literal():
232 def test_dict_literal():
222 context = limited()
233 context = limited()
223 assert guarded_eval("{}", context) == {}
234 assert guarded_eval("{}", context) == {}
224 assert guarded_eval('{"a": 1}', context) == {"a": 1}
235 assert guarded_eval('{"a": 1}', context) == {"a": 1}
225
236
226
237
227 def test_list_literal():
238 def test_list_literal():
228 context = limited()
239 context = limited()
229 assert guarded_eval("[]", context) == []
240 assert guarded_eval("[]", context) == []
230 assert guarded_eval('[1, "a"]', context) == [1, "a"]
241 assert guarded_eval('[1, "a"]', context) == [1, "a"]
231
242
232
243
233 def test_set_literal():
244 def test_set_literal():
234 context = limited()
245 context = limited()
235 assert guarded_eval("set()", context) == set()
246 assert guarded_eval("set()", context) == set()
236 assert guarded_eval('{"a"}', context) == {"a"}
247 assert guarded_eval('{"a"}', context) == {"a"}
237
248
238
249
239 def test_evaluates_if_expression():
250 def test_evaluates_if_expression():
240 context = limited()
251 context = limited()
241 assert guarded_eval("2 if True else 3", context) == 2
252 assert guarded_eval("2 if True else 3", context) == 2
242 assert guarded_eval("4 if False else 5", context) == 5
253 assert guarded_eval("4 if False else 5", context) == 5
243
254
244
255
245 def test_object():
256 def test_object():
246 obj = object()
257 obj = object()
247 context = limited(obj=obj)
258 context = limited(obj=obj)
248 assert guarded_eval("obj.__dir__", context) == obj.__dir__
259 assert guarded_eval("obj.__dir__", context) == obj.__dir__
249
260
250
261
251 @pytest.mark.parametrize(
262 @pytest.mark.parametrize(
252 "code,expected",
263 "code,expected",
253 [
264 [
254 ["int.numerator", int.numerator],
265 ["int.numerator", int.numerator],
255 ["float.is_integer", float.is_integer],
266 ["float.is_integer", float.is_integer],
256 ["complex.real", complex.real],
267 ["complex.real", complex.real],
257 ],
268 ],
258 )
269 )
259 def test_number_attributes(code, expected):
270 def test_number_attributes(code, expected):
260 assert guarded_eval(code, limited()) == expected
271 assert guarded_eval(code, limited()) == expected
261
272
262
273
263 def test_method_descriptor():
274 def test_method_descriptor():
264 context = limited()
275 context = limited()
265 assert guarded_eval("list.copy.__name__", context) == "copy"
276 assert guarded_eval("list.copy.__name__", context) == "copy"
266
277
267
278
268 class HeapType:
279 class HeapType:
269 pass
280 pass
270
281
271
282
272 class CallCreatesHeapType:
283 class CallCreatesHeapType:
273 def __call__(self) -> HeapType:
284 def __call__(self) -> HeapType:
274 return HeapType()
285 return HeapType()
275
286
276
287
277 class CallCreatesBuiltin:
288 class CallCreatesBuiltin:
278 def __call__(self) -> frozenset:
289 def __call__(self) -> frozenset:
279 return frozenset()
290 return frozenset()
280
291
281
292
282 class HasStaticMethod:
293 class HasStaticMethod:
283 @staticmethod
294 @staticmethod
284 def static_method() -> HeapType:
295 def static_method() -> HeapType:
285 return HeapType()
296 return HeapType()
286
297
287
298
288 class InitReturnsFrozenset:
299 class InitReturnsFrozenset:
289 def __new__(self) -> frozenset: # type:ignore[misc]
300 def __new__(self) -> frozenset: # type:ignore[misc]
290 return frozenset()
301 return frozenset()
291
302
292
303
293 class StringAnnotation:
304 class StringAnnotation:
294 def heap(self) -> "HeapType":
305 def heap(self) -> "HeapType":
295 return HeapType()
306 return HeapType()
296
307
297 def copy(self) -> "StringAnnotation":
308 def copy(self) -> "StringAnnotation":
298 return StringAnnotation()
309 return StringAnnotation()
299
310
300
311
301 CustomIntType = NewType("CustomIntType", int)
312 CustomIntType = NewType("CustomIntType", int)
302 CustomHeapType = NewType("CustomHeapType", HeapType)
313 CustomHeapType = NewType("CustomHeapType", HeapType)
303 IntTypeAlias = TypeAliasType("IntTypeAlias", int)
314 IntTypeAlias = TypeAliasType("IntTypeAlias", int)
304 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
315 HeapTypeAlias = TypeAliasType("HeapTypeAlias", HeapType)
305
316
306
317
318 class TestProtocol(Protocol):
319 def test_method(self) -> bool:
320 pass
321
322
323 class TestProtocolImplementer(TestProtocol):
324 def test_method(self) -> bool:
325 return True
326
327
328 class Movie(TypedDict):
329 name: str
330 year: int
331
332
307 class SpecialTyping:
333 class SpecialTyping:
308 def custom_int_type(self) -> CustomIntType:
334 def custom_int_type(self) -> CustomIntType:
309 return CustomIntType(1)
335 return CustomIntType(1)
310
336
311 def custom_heap_type(self) -> CustomHeapType:
337 def custom_heap_type(self) -> CustomHeapType:
312 return CustomHeapType(HeapType())
338 return CustomHeapType(HeapType())
313
339
314 # TODO: remove type:ignore comment once mypy
340 # TODO: remove type:ignore comment once mypy
315 # supports explicit calls to `TypeAliasType`, see:
341 # supports explicit calls to `TypeAliasType`, see:
316 # https://github.com/python/mypy/issues/16614
342 # https://github.com/python/mypy/issues/16614
317 def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type]
343 def int_type_alias(self) -> IntTypeAlias: # type:ignore[valid-type]
318 return 1
344 return 1
319
345
320 def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type]
346 def heap_type_alias(self) -> HeapTypeAlias: # type:ignore[valid-type]
321 return 1
347 return 1
322
348
323 def literal(self) -> Literal[False]:
349 def literal(self) -> Literal[False]:
324 return False
350 return False
325
351
352 def literal_string(self) -> LiteralString:
353 return "test"
354
326 def self(self) -> Self:
355 def self(self) -> Self:
327 return self
356 return self
328
357
358 def any_str(self, x: AnyStr) -> AnyStr:
359 return x
360
361 def annotated(self) -> Annotated[float, "positive number"]:
362 return 1
363
364 def annotated_self(self) -> Annotated[Self, "self with metadata"]:
365 self._metadata = "test"
366 return self
367
368 def int_type_guard(self, x) -> TypeGuard[int]:
369 return isinstance(x, int)
370
371 def optional_float(self) -> Optional[float]:
372 return 1.0
373
374 def union_str_and_int(self) -> Union[str, int]:
375 return ""
376
377 def protocol(self) -> TestProtocol:
378 return TestProtocolImplementer()
379
380 def typed_dict(self) -> Movie:
381 return {"name": "The Matrix", "year": 1999}
382
329
383
330 @pytest.mark.parametrize(
384 @pytest.mark.parametrize(
331 "data,good,expected,equality",
385 "data,code,expected,equality",
332 [
386 [
333 [[1, 2, 3], "data.index(2)", 1, True],
387 [[1, 2, 3], "data.index(2)", 1, True],
334 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
388 [{"a": 1}, "data.keys().isdisjoint({})", True, True],
335 [StringAnnotation(), "data.heap()", HeapType, False],
389 [StringAnnotation(), "data.heap()", HeapType, False],
336 [StringAnnotation(), "data.copy()", StringAnnotation, False],
390 [StringAnnotation(), "data.copy()", StringAnnotation, False],
337 # test cases for `__call__`
391 # test cases for `__call__`
338 [CallCreatesHeapType(), "data()", HeapType, False],
392 [CallCreatesHeapType(), "data()", HeapType, False],
339 [CallCreatesBuiltin(), "data()", frozenset, False],
393 [CallCreatesBuiltin(), "data()", frozenset, False],
340 # Test cases for `__init__`
394 # Test cases for `__init__`
341 [HeapType, "data()", HeapType, False],
395 [HeapType, "data()", HeapType, False],
342 [InitReturnsFrozenset, "data()", frozenset, False],
396 [InitReturnsFrozenset, "data()", frozenset, False],
343 [HeapType(), "data.__class__()", HeapType, False],
397 [HeapType(), "data.__class__()", HeapType, False],
344 # supported special cases for typing
398 # supported special cases for typing
345 [SpecialTyping(), "data.custom_int_type()", int, False],
399 [SpecialTyping(), "data.custom_int_type()", int, False],
346 [SpecialTyping(), "data.custom_heap_type()", HeapType, False],
400 [SpecialTyping(), "data.custom_heap_type()", HeapType, False],
347 [SpecialTyping(), "data.int_type_alias()", int, False],
401 [SpecialTyping(), "data.int_type_alias()", int, False],
348 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
402 [SpecialTyping(), "data.heap_type_alias()", HeapType, False],
349 [SpecialTyping(), "data.self()", SpecialTyping, False],
403 [SpecialTyping(), "data.self()", SpecialTyping, False],
350 [SpecialTyping(), "data.literal()", False, True],
404 [SpecialTyping(), "data.literal()", False, True],
405 [SpecialTyping(), "data.literal_string()", str, False],
406 [SpecialTyping(), "data.any_str('a')", str, False],
407 [SpecialTyping(), "data.any_str(b'a')", bytes, False],
408 [SpecialTyping(), "data.annotated()", float, False],
409 [SpecialTyping(), "data.annotated_self()", SpecialTyping, False],
410 [SpecialTyping(), "data.int_type_guard()", int, False],
351 # test cases for static methods
411 # test cases for static methods
352 [HasStaticMethod, "data.static_method()", HeapType, False],
412 [HasStaticMethod, "data.static_method()", HeapType, False],
353 ],
413 ],
354 )
414 )
355 def test_evaluates_calls(data, good, expected, equality):
415 def test_evaluates_calls(data, code, expected, equality):
356 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
416 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
357 value = guarded_eval(good, context)
417 value = guarded_eval(code, context)
358 if equality:
418 if equality:
359 assert value == expected
419 assert value == expected
360 else:
420 else:
361 assert isinstance(value, expected)
421 assert isinstance(value, expected)
362
422
363
423
364 @pytest.mark.parametrize(
424 @pytest.mark.parametrize(
425 "data,code,expected_attributes",
426 [
427 [SpecialTyping(), "data.optional_float()", ["is_integer"]],
428 [
429 SpecialTyping(),
430 "data.union_str_and_int()",
431 ["capitalize", "as_integer_ratio"],
432 ],
433 [SpecialTyping(), "data.protocol()", ["test_method"]],
434 [SpecialTyping(), "data.typed_dict()", ["keys", "values", "items"]],
435 ],
436 )
437 def test_mocks_attributes_of_call_results(data, code, expected_attributes):
438 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
439 result = guarded_eval(code, context)
440 for attr in expected_attributes:
441 assert hasattr(result, attr)
442 assert attr in dir(result)
443
444
445 @pytest.mark.parametrize(
446 "data,code,expected_items",
447 [
448 [SpecialTyping(), "data.typed_dict()", {"year": int, "name": str}],
449 ],
450 )
451 def test_mocks_items_of_call_results(data, code, expected_items):
452 context = limited(data=data, HeapType=HeapType, StringAnnotation=StringAnnotation)
453 result = guarded_eval(code, context)
454 ipython_keys = result._ipython_key_completions_()
455 for key, value in expected_items.items():
456 assert isinstance(result[key], value)
457 assert key in ipython_keys
458
459
460 @pytest.mark.parametrize(
365 "data,bad",
461 "data,bad",
366 [
462 [
367 [[1, 2, 3], "data.append(4)"],
463 [[1, 2, 3], "data.append(4)"],
368 [{"a": 1}, "data.update()"],
464 [{"a": 1}, "data.update()"],
369 ],
465 ],
370 )
466 )
371 def test_rejects_calls_with_side_effects(data, bad):
467 def test_rejects_calls_with_side_effects(data, bad):
372 context = limited(data=data)
468 context = limited(data=data)
373
469
374 with pytest.raises(GuardRejection):
470 with pytest.raises(GuardRejection):
375 guarded_eval(bad, context)
471 guarded_eval(bad, context)
376
472
377
473
378 @pytest.mark.parametrize(
474 @pytest.mark.parametrize(
379 "code,expected",
475 "code,expected",
380 [
476 [
381 ["(1\n+\n1)", 2],
477 ["(1\n+\n1)", 2],
382 ["list(range(10))[-1:]", [9]],
478 ["list(range(10))[-1:]", [9]],
383 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
479 ["list(range(20))[3:-2:3]", [3, 6, 9, 12, 15]],
384 ],
480 ],
385 )
481 )
386 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
482 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
387 def test_evaluates_complex_cases(code, expected, context):
483 def test_evaluates_complex_cases(code, expected, context):
388 assert guarded_eval(code, context()) == expected
484 assert guarded_eval(code, context()) == expected
389
485
390
486
391 @pytest.mark.parametrize(
487 @pytest.mark.parametrize(
392 "code,expected",
488 "code,expected",
393 [
489 [
394 ["1", 1],
490 ["1", 1],
395 ["1.0", 1.0],
491 ["1.0", 1.0],
396 ["0xdeedbeef", 0xDEEDBEEF],
492 ["0xdeedbeef", 0xDEEDBEEF],
397 ["True", True],
493 ["True", True],
398 ["None", None],
494 ["None", None],
399 ["{}", {}],
495 ["{}", {}],
400 ["[]", []],
496 ["[]", []],
401 ],
497 ],
402 )
498 )
403 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
499 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
404 def test_evaluates_literals(code, expected, context):
500 def test_evaluates_literals(code, expected, context):
405 assert guarded_eval(code, context()) == expected
501 assert guarded_eval(code, context()) == expected
406
502
407
503
408 @pytest.mark.parametrize(
504 @pytest.mark.parametrize(
409 "code,expected",
505 "code,expected",
410 [
506 [
411 ["-5", -5],
507 ["-5", -5],
412 ["+5", +5],
508 ["+5", +5],
413 ["~5", -6],
509 ["~5", -6],
414 ],
510 ],
415 )
511 )
416 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
512 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
417 def test_evaluates_unary_operations(code, expected, context):
513 def test_evaluates_unary_operations(code, expected, context):
418 assert guarded_eval(code, context()) == expected
514 assert guarded_eval(code, context()) == expected
419
515
420
516
421 @pytest.mark.parametrize(
517 @pytest.mark.parametrize(
422 "code,expected",
518 "code,expected",
423 [
519 [
424 ["1 + 1", 2],
520 ["1 + 1", 2],
425 ["3 - 1", 2],
521 ["3 - 1", 2],
426 ["2 * 3", 6],
522 ["2 * 3", 6],
427 ["5 // 2", 2],
523 ["5 // 2", 2],
428 ["5 / 2", 2.5],
524 ["5 / 2", 2.5],
429 ["5**2", 25],
525 ["5**2", 25],
430 ["2 >> 1", 1],
526 ["2 >> 1", 1],
431 ["2 << 1", 4],
527 ["2 << 1", 4],
432 ["1 | 2", 3],
528 ["1 | 2", 3],
433 ["1 & 1", 1],
529 ["1 & 1", 1],
434 ["1 & 2", 0],
530 ["1 & 2", 0],
435 ],
531 ],
436 )
532 )
437 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
533 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
438 def test_evaluates_binary_operations(code, expected, context):
534 def test_evaluates_binary_operations(code, expected, context):
439 assert guarded_eval(code, context()) == expected
535 assert guarded_eval(code, context()) == expected
440
536
441
537
442 @pytest.mark.parametrize(
538 @pytest.mark.parametrize(
443 "code,expected",
539 "code,expected",
444 [
540 [
445 ["2 > 1", True],
541 ["2 > 1", True],
446 ["2 < 1", False],
542 ["2 < 1", False],
447 ["2 <= 1", False],
543 ["2 <= 1", False],
448 ["2 <= 2", True],
544 ["2 <= 2", True],
449 ["1 >= 2", False],
545 ["1 >= 2", False],
450 ["2 >= 2", True],
546 ["2 >= 2", True],
451 ["2 == 2", True],
547 ["2 == 2", True],
452 ["1 == 2", False],
548 ["1 == 2", False],
453 ["1 != 2", True],
549 ["1 != 2", True],
454 ["1 != 1", False],
550 ["1 != 1", False],
455 ["1 < 4 < 3", False],
551 ["1 < 4 < 3", False],
456 ["(1 < 4) < 3", True],
552 ["(1 < 4) < 3", True],
457 ["4 > 3 > 2 > 1", True],
553 ["4 > 3 > 2 > 1", True],
458 ["4 > 3 > 2 > 9", False],
554 ["4 > 3 > 2 > 9", False],
459 ["1 < 2 < 3 < 4", True],
555 ["1 < 2 < 3 < 4", True],
460 ["9 < 2 < 3 < 4", False],
556 ["9 < 2 < 3 < 4", False],
461 ["1 < 2 > 1 > 0 > -1 < 1", True],
557 ["1 < 2 > 1 > 0 > -1 < 1", True],
462 ["1 in [1] in [[1]]", True],
558 ["1 in [1] in [[1]]", True],
463 ["1 in [1] in [[2]]", False],
559 ["1 in [1] in [[2]]", False],
464 ["1 in [1]", True],
560 ["1 in [1]", True],
465 ["0 in [1]", False],
561 ["0 in [1]", False],
466 ["1 not in [1]", False],
562 ["1 not in [1]", False],
467 ["0 not in [1]", True],
563 ["0 not in [1]", True],
468 ["True is True", True],
564 ["True is True", True],
469 ["False is False", True],
565 ["False is False", True],
470 ["True is False", False],
566 ["True is False", False],
471 ["True is not True", False],
567 ["True is not True", False],
472 ["False is not True", True],
568 ["False is not True", True],
473 ],
569 ],
474 )
570 )
475 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
571 @pytest.mark.parametrize("context", LIMITED_OR_HIGHER)
476 def test_evaluates_comparisons(code, expected, context):
572 def test_evaluates_comparisons(code, expected, context):
477 assert guarded_eval(code, context()) == expected
573 assert guarded_eval(code, context()) == expected
478
574
479
575
480 def test_guards_comparisons():
576 def test_guards_comparisons():
481 class GoodEq(int):
577 class GoodEq(int):
482 pass
578 pass
483
579
484 class BadEq(int):
580 class BadEq(int):
485 def __eq__(self, other):
581 def __eq__(self, other):
486 assert False
582 assert False
487
583
488 context = limited(bad=BadEq(1), good=GoodEq(1))
584 context = limited(bad=BadEq(1), good=GoodEq(1))
489
585
490 with pytest.raises(GuardRejection):
586 with pytest.raises(GuardRejection):
491 guarded_eval("bad == 1", context)
587 guarded_eval("bad == 1", context)
492
588
493 with pytest.raises(GuardRejection):
589 with pytest.raises(GuardRejection):
494 guarded_eval("bad != 1", context)
590 guarded_eval("bad != 1", context)
495
591
496 with pytest.raises(GuardRejection):
592 with pytest.raises(GuardRejection):
497 guarded_eval("1 == bad", context)
593 guarded_eval("1 == bad", context)
498
594
499 with pytest.raises(GuardRejection):
595 with pytest.raises(GuardRejection):
500 guarded_eval("1 != bad", context)
596 guarded_eval("1 != bad", context)
501
597
502 assert guarded_eval("good == 1", context) is True
598 assert guarded_eval("good == 1", context) is True
503 assert guarded_eval("good != 1", context) is False
599 assert guarded_eval("good != 1", context) is False
504 assert guarded_eval("1 == good", context) is True
600 assert guarded_eval("1 == good", context) is True
505 assert guarded_eval("1 != good", context) is False
601 assert guarded_eval("1 != good", context) is False
506
602
507
603
508 def test_guards_unary_operations():
604 def test_guards_unary_operations():
509 class GoodOp(int):
605 class GoodOp(int):
510 pass
606 pass
511
607
512 class BadOpInv(int):
608 class BadOpInv(int):
513 def __inv__(self, other):
609 def __inv__(self, other):
514 assert False
610 assert False
515
611
516 class BadOpInverse(int):
612 class BadOpInverse(int):
517 def __inv__(self, other):
613 def __inv__(self, other):
518 assert False
614 assert False
519
615
520 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
616 context = limited(good=GoodOp(1), bad1=BadOpInv(1), bad2=BadOpInverse(1))
521
617
522 with pytest.raises(GuardRejection):
618 with pytest.raises(GuardRejection):
523 guarded_eval("~bad1", context)
619 guarded_eval("~bad1", context)
524
620
525 with pytest.raises(GuardRejection):
621 with pytest.raises(GuardRejection):
526 guarded_eval("~bad2", context)
622 guarded_eval("~bad2", context)
527
623
528
624
529 def test_guards_binary_operations():
625 def test_guards_binary_operations():
530 class GoodOp(int):
626 class GoodOp(int):
531 pass
627 pass
532
628
533 class BadOp(int):
629 class BadOp(int):
534 def __add__(self, other):
630 def __add__(self, other):
535 assert False
631 assert False
536
632
537 context = limited(good=GoodOp(1), bad=BadOp(1))
633 context = limited(good=GoodOp(1), bad=BadOp(1))
538
634
539 with pytest.raises(GuardRejection):
635 with pytest.raises(GuardRejection):
540 guarded_eval("1 + bad", context)
636 guarded_eval("1 + bad", context)
541
637
542 with pytest.raises(GuardRejection):
638 with pytest.raises(GuardRejection):
543 guarded_eval("bad + 1", context)
639 guarded_eval("bad + 1", context)
544
640
545 assert guarded_eval("good + 1", context) == 2
641 assert guarded_eval("good + 1", context) == 2
546 assert guarded_eval("1 + good", context) == 2
642 assert guarded_eval("1 + good", context) == 2
547
643
548
644
549 def test_guards_attributes():
645 def test_guards_attributes():
550 class GoodAttr(float):
646 class GoodAttr(float):
551 pass
647 pass
552
648
553 class BadAttr1(float):
649 class BadAttr1(float):
554 def __getattr__(self, key):
650 def __getattr__(self, key):
555 assert False
651 assert False
556
652
557 class BadAttr2(float):
653 class BadAttr2(float):
558 def __getattribute__(self, key):
654 def __getattribute__(self, key):
559 assert False
655 assert False
560
656
561 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
657 context = limited(good=GoodAttr(0.5), bad1=BadAttr1(0.5), bad2=BadAttr2(0.5))
562
658
563 with pytest.raises(GuardRejection):
659 with pytest.raises(GuardRejection):
564 guarded_eval("bad1.as_integer_ratio", context)
660 guarded_eval("bad1.as_integer_ratio", context)
565
661
566 with pytest.raises(GuardRejection):
662 with pytest.raises(GuardRejection):
567 guarded_eval("bad2.as_integer_ratio", context)
663 guarded_eval("bad2.as_integer_ratio", context)
568
664
569 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
665 assert guarded_eval("good.as_integer_ratio()", context) == (1, 2)
570
666
571
667
572 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
668 @pytest.mark.parametrize("context", MINIMAL_OR_HIGHER)
573 def test_access_builtins(context):
669 def test_access_builtins(context):
574 assert guarded_eval("round", context()) == round
670 assert guarded_eval("round", context()) == round
575
671
576
672
577 def test_access_builtins_fails():
673 def test_access_builtins_fails():
578 context = limited()
674 context = limited()
579 with pytest.raises(NameError):
675 with pytest.raises(NameError):
580 guarded_eval("this_is_not_builtin", context)
676 guarded_eval("this_is_not_builtin", context)
581
677
582
678
583 def test_rejects_forbidden():
679 def test_rejects_forbidden():
584 context = forbidden()
680 context = forbidden()
585 with pytest.raises(GuardRejection):
681 with pytest.raises(GuardRejection):
586 guarded_eval("1", context)
682 guarded_eval("1", context)
587
683
588
684
589 def test_guards_locals_and_globals():
685 def test_guards_locals_and_globals():
590 context = EvaluationContext(
686 context = EvaluationContext(
591 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
687 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="minimal"
592 )
688 )
593
689
594 with pytest.raises(GuardRejection):
690 with pytest.raises(GuardRejection):
595 guarded_eval("local_a", context)
691 guarded_eval("local_a", context)
596
692
597 with pytest.raises(GuardRejection):
693 with pytest.raises(GuardRejection):
598 guarded_eval("global_b", context)
694 guarded_eval("global_b", context)
599
695
600
696
601 def test_access_locals_and_globals():
697 def test_access_locals_and_globals():
602 context = EvaluationContext(
698 context = EvaluationContext(
603 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
699 locals={"local_a": "a"}, globals={"global_b": "b"}, evaluation="limited"
604 )
700 )
605 assert guarded_eval("local_a", context) == "a"
701 assert guarded_eval("local_a", context) == "a"
606 assert guarded_eval("global_b", context) == "b"
702 assert guarded_eval("global_b", context) == "b"
607
703
608
704
609 @pytest.mark.parametrize(
705 @pytest.mark.parametrize(
610 "code",
706 "code",
611 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
707 ["def func(): pass", "class C: pass", "x = 1", "x += 1", "del x", "import ast"],
612 )
708 )
613 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
709 @pytest.mark.parametrize("context", [minimal(), limited(), unsafe()])
614 def test_rejects_side_effect_syntax(code, context):
710 def test_rejects_side_effect_syntax(code, context):
615 with pytest.raises(SyntaxError):
711 with pytest.raises(SyntaxError):
616 guarded_eval(code, context)
712 guarded_eval(code, context)
617
713
618
714
619 def test_subscript():
715 def test_subscript():
620 context = EvaluationContext(
716 context = EvaluationContext(
621 locals={}, globals={}, evaluation="limited", in_subscript=True
717 locals={}, globals={}, evaluation="limited", in_subscript=True
622 )
718 )
623 empty_slice = slice(None, None, None)
719 empty_slice = slice(None, None, None)
624 assert guarded_eval("", context) == tuple()
720 assert guarded_eval("", context) == tuple()
625 assert guarded_eval(":", context) == empty_slice
721 assert guarded_eval(":", context) == empty_slice
626 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
722 assert guarded_eval("1:2:3", context) == slice(1, 2, 3)
627 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
723 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
628
724
629
725
630 def test_unbind_method():
726 def test_unbind_method():
631 class X(list):
727 class X(list):
632 def index(self, k):
728 def index(self, k):
633 return "CUSTOM"
729 return "CUSTOM"
634
730
635 x = X()
731 x = X()
636 assert _unbind_method(x.index) is X.index
732 assert _unbind_method(x.index) is X.index
637 assert _unbind_method([].index) is list.index
733 assert _unbind_method([].index) is list.index
638 assert _unbind_method(list.index) is None
734 assert _unbind_method(list.index) is None
639
735
640
736
641 def test_assumption_instance_attr_do_not_matter():
737 def test_assumption_instance_attr_do_not_matter():
642 """This is semi-specified in Python documentation.
738 """This is semi-specified in Python documentation.
643
739
644 However, since the specification says 'not guaranteed
740 However, since the specification says 'not guaranteed
645 to work' rather than 'is forbidden to work', future
741 to work' rather than 'is forbidden to work', future
646 versions could invalidate this assumptions. This test
742 versions could invalidate this assumptions. This test
647 is meant to catch such a change if it ever comes true.
743 is meant to catch such a change if it ever comes true.
648 """
744 """
649
745
650 class T:
746 class T:
651 def __getitem__(self, k):
747 def __getitem__(self, k):
652 return "a"
748 return "a"
653
749
654 def __getattr__(self, k):
750 def __getattr__(self, k):
655 return "a"
751 return "a"
656
752
657 def f(self):
753 def f(self):
658 return "b"
754 return "b"
659
755
660 t = T()
756 t = T()
661 t.__getitem__ = f
757 t.__getitem__ = f
662 t.__getattr__ = f
758 t.__getattr__ = f
663 assert t[1] == "a"
759 assert t[1] == "a"
664 assert t[1] == "a"
760 assert t[1] == "a"
665
761
666
762
667 def test_assumption_named_tuples_share_getitem():
763 def test_assumption_named_tuples_share_getitem():
668 """Check assumption on named tuples sharing __getitem__"""
764 """Check assumption on named tuples sharing __getitem__"""
669 from typing import NamedTuple
765 from typing import NamedTuple
670
766
671 class A(NamedTuple):
767 class A(NamedTuple):
672 pass
768 pass
673
769
674 class B(NamedTuple):
770 class B(NamedTuple):
675 pass
771 pass
676
772
677 assert A.__getitem__ == B.__getitem__
773 assert A.__getitem__ == B.__getitem__
678
774
679
775
680 @dec.skip_without("numpy")
776 @dec.skip_without("numpy")
681 def test_module_access():
777 def test_module_access():
682 import numpy
778 import numpy
683
779
684 context = limited(numpy=numpy)
780 context = limited(numpy=numpy)
685 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
781 assert guarded_eval("numpy.linalg.norm", context) == numpy.linalg.norm
686
782
687 context = minimal(numpy=numpy)
783 context = minimal(numpy=numpy)
688 with pytest.raises(GuardRejection):
784 with pytest.raises(GuardRejection):
689 guarded_eval("np.linalg.norm", context)
785 guarded_eval("np.linalg.norm", context)
General Comments 0
You need to be logged in to leave comments. Login now