Show More
This diff has been collapsed as it changes many lines, (541 lines changed) Show them Hide them | |||||
@@ -0,0 +1,541 b'' | |||||
|
1 | from typing import Callable, Protocol, Set, Tuple, NamedTuple, Literal, Union | |||
|
2 | import collections | |||
|
3 | import sys | |||
|
4 | import ast | |||
|
5 | import types | |||
|
6 | from functools import cached_property | |||
|
7 | from dataclasses import dataclass, field | |||
|
8 | ||||
|
9 | ||||
|
10 | class HasGetItem(Protocol): | |||
|
11 | def __getitem__(self, key) -> None: ... | |||
|
12 | ||||
|
13 | ||||
|
14 | class InstancesHaveGetItem(Protocol): | |||
|
15 | def __call__(self) -> HasGetItem: ... | |||
|
16 | ||||
|
17 | ||||
|
18 | class HasGetAttr(Protocol): | |||
|
19 | def __getattr__(self, key) -> None: ... | |||
|
20 | ||||
|
21 | ||||
|
22 | class DoesNotHaveGetAttr(Protocol): | |||
|
23 | pass | |||
|
24 | ||||
|
25 | # By default `__getattr__` is not explicitly implemented on most objects | |||
|
26 | MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr] | |||
|
27 | ||||
|
28 | ||||
|
29 | def unbind_method(func: Callable) -> Union[Callable, None]: | |||
|
30 | """Get unbound method for given bound method. | |||
|
31 | ||||
|
32 | Returns None if cannot get unbound method.""" | |||
|
33 | owner = getattr(func, '__self__', None) | |||
|
34 | owner_class = type(owner) | |||
|
35 | name = getattr(func, '__name__', None) | |||
|
36 | instance_dict_overrides = getattr(owner, '__dict__', None) | |||
|
37 | if ( | |||
|
38 | owner is not None | |||
|
39 | and | |||
|
40 | name | |||
|
41 | and | |||
|
42 | ( | |||
|
43 | not instance_dict_overrides | |||
|
44 | or | |||
|
45 | ( | |||
|
46 | instance_dict_overrides | |||
|
47 | and name not in instance_dict_overrides | |||
|
48 | ) | |||
|
49 | ) | |||
|
50 | ): | |||
|
51 | return getattr(owner_class, name) | |||
|
52 | ||||
|
53 | ||||
|
54 | @dataclass | |||
|
55 | class EvaluationPolicy: | |||
|
56 | allow_locals_access: bool = False | |||
|
57 | allow_globals_access: bool = False | |||
|
58 | allow_item_access: bool = False | |||
|
59 | allow_attr_access: bool = False | |||
|
60 | allow_builtins_access: bool = False | |||
|
61 | allow_any_calls: bool = False | |||
|
62 | allowed_calls: Set[Callable] = field(default_factory=set) | |||
|
63 | ||||
|
64 | def can_get_item(self, value, item): | |||
|
65 | return self.allow_item_access | |||
|
66 | ||||
|
67 | def can_get_attr(self, value, attr): | |||
|
68 | return self.allow_attr_access | |||
|
69 | ||||
|
70 | def can_call(self, func): | |||
|
71 | if self.allow_any_calls: | |||
|
72 | return True | |||
|
73 | ||||
|
74 | if func in self.allowed_calls: | |||
|
75 | return True | |||
|
76 | ||||
|
77 | owner_method = unbind_method(func) | |||
|
78 | if owner_method and owner_method in self.allowed_calls: | |||
|
79 | return True | |||
|
80 | ||||
|
81 | def has_original_dunder_external(value, module_name, access_path, method_name,): | |||
|
82 | try: | |||
|
83 | if module_name not in sys.modules: | |||
|
84 | return False | |||
|
85 | member_type = sys.modules[module_name] | |||
|
86 | for attr in access_path: | |||
|
87 | member_type = getattr(member_type, attr) | |||
|
88 | value_type = type(value) | |||
|
89 | if type(value) == member_type: | |||
|
90 | return True | |||
|
91 | if isinstance(value, member_type): | |||
|
92 | method = getattr(value_type, method_name, None) | |||
|
93 | member_method = getattr(member_type, method_name, None) | |||
|
94 | if member_method == method: | |||
|
95 | return True | |||
|
96 | except (AttributeError, KeyError): | |||
|
97 | return False | |||
|
98 | ||||
|
99 | ||||
|
100 | def has_original_dunder( | |||
|
101 | value, | |||
|
102 | allowed_types, | |||
|
103 | allowed_methods, | |||
|
104 | allowed_external, | |||
|
105 | method_name | |||
|
106 | ): | |||
|
107 | # note: Python ignores `__getattr__`/`__getitem__` on instances, | |||
|
108 | # we only need to check at class level | |||
|
109 | value_type = type(value) | |||
|
110 | ||||
|
111 | # strict type check passes β no need to check method | |||
|
112 | if value_type in allowed_types: | |||
|
113 | return True | |||
|
114 | ||||
|
115 | method = getattr(value_type, method_name, None) | |||
|
116 | ||||
|
117 | if not method: | |||
|
118 | return None | |||
|
119 | ||||
|
120 | if method in allowed_methods: | |||
|
121 | return True | |||
|
122 | ||||
|
123 | for module_name, *access_path in allowed_external: | |||
|
124 | if has_original_dunder_external(value, module_name, access_path, method_name): | |||
|
125 | return True | |||
|
126 | ||||
|
127 | return False | |||
|
128 | ||||
|
129 | ||||
|
130 | @dataclass | |||
|
131 | class SelectivePolicy(EvaluationPolicy): | |||
|
132 | allowed_getitem: Set[HasGetItem] = field(default_factory=set) | |||
|
133 | allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set) | |||
|
134 | allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set) | |||
|
135 | allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set) | |||
|
136 | ||||
|
137 | def can_get_attr(self, value, attr): | |||
|
138 | has_original_attribute = has_original_dunder( | |||
|
139 | value, | |||
|
140 | allowed_types=self.allowed_getattr, | |||
|
141 | allowed_methods=self._getattribute_methods, | |||
|
142 | allowed_external=self.allowed_getattr_external, | |||
|
143 | method_name='__getattribute__' | |||
|
144 | ) | |||
|
145 | has_original_attr = has_original_dunder( | |||
|
146 | value, | |||
|
147 | allowed_types=self.allowed_getattr, | |||
|
148 | allowed_methods=self._getattr_methods, | |||
|
149 | allowed_external=self.allowed_getattr_external, | |||
|
150 | method_name='__getattr__' | |||
|
151 | ) | |||
|
152 | # Many objects do not have `__getattr__`, this is fine | |||
|
153 | if has_original_attr is None and has_original_attribute: | |||
|
154 | return True | |||
|
155 | ||||
|
156 | # Accept objects without modifications to `__getattr__` and `__getattribute__` | |||
|
157 | return has_original_attr and has_original_attribute | |||
|
158 | ||||
|
159 | def get_attr(self, value, attr): | |||
|
160 | if self.can_get_attr(value, attr): | |||
|
161 | return getattr(value, attr) | |||
|
162 | ||||
|
163 | ||||
|
164 | def can_get_item(self, value, item): | |||
|
165 | """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified.""" | |||
|
166 | return has_original_dunder( | |||
|
167 | value, | |||
|
168 | allowed_types=self.allowed_getitem, | |||
|
169 | allowed_methods=self._getitem_methods, | |||
|
170 | allowed_external=self.allowed_getitem_external, | |||
|
171 | method_name='__getitem__' | |||
|
172 | ) | |||
|
173 | ||||
|
174 | @cached_property | |||
|
175 | def _getitem_methods(self) -> Set[Callable]: | |||
|
176 | return self._safe_get_methods( | |||
|
177 | self.allowed_getitem, | |||
|
178 | '__getitem__' | |||
|
179 | ) | |||
|
180 | ||||
|
181 | @cached_property | |||
|
182 | def _getattr_methods(self) -> Set[Callable]: | |||
|
183 | return self._safe_get_methods( | |||
|
184 | self.allowed_getattr, | |||
|
185 | '__getattr__' | |||
|
186 | ) | |||
|
187 | ||||
|
188 | @cached_property | |||
|
189 | def _getattribute_methods(self) -> Set[Callable]: | |||
|
190 | return self._safe_get_methods( | |||
|
191 | self.allowed_getattr, | |||
|
192 | '__getattribute__' | |||
|
193 | ) | |||
|
194 | ||||
|
195 | def _safe_get_methods(self, classes, name) -> Set[Callable]: | |||
|
196 | return { | |||
|
197 | method | |||
|
198 | for class_ in classes | |||
|
199 | for method in [getattr(class_, name, None)] | |||
|
200 | if method | |||
|
201 | } | |||
|
202 | ||||
|
203 | ||||
|
204 | class DummyNamedTuple(NamedTuple): | |||
|
205 | pass | |||
|
206 | ||||
|
207 | ||||
|
208 | class EvaluationContext(NamedTuple): | |||
|
209 | locals_: dict | |||
|
210 | globals_: dict | |||
|
211 | evaluation: Literal['forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'] = 'forbidden' | |||
|
212 | in_subscript: bool = False | |||
|
213 | ||||
|
214 | ||||
|
215 | class IdentitySubscript: | |||
|
216 | def __getitem__(self, key): | |||
|
217 | return key | |||
|
218 | ||||
|
219 | IDENTITY_SUBSCRIPT = IdentitySubscript() | |||
|
220 | SUBSCRIPT_MARKER = '__SUBSCRIPT_SENTINEL__' | |||
|
221 | ||||
|
222 | class GuardRejection(ValueError): | |||
|
223 | pass | |||
|
224 | ||||
|
225 | ||||
|
226 | def guarded_eval( | |||
|
227 | code: str, | |||
|
228 | context: EvaluationContext | |||
|
229 | ): | |||
|
230 | locals_ = context.locals_ | |||
|
231 | ||||
|
232 | if context.evaluation == 'forbidden': | |||
|
233 | raise GuardRejection('Forbidden mode') | |||
|
234 | ||||
|
235 | # note: not using `ast.literal_eval` as it does not implement | |||
|
236 | # getitem at all, for example it fails on simple `[0][1]` | |||
|
237 | ||||
|
238 | if context.in_subscript: | |||
|
239 | # syntatic sugar for ellipsis (:) is only available in susbcripts | |||
|
240 | # so we need to trick the ast parser into thinking that we have | |||
|
241 | # a subscript, but we need to be able to later recognise that we did | |||
|
242 | # it so we can ignore the actual __getitem__ operation | |||
|
243 | if not code: | |||
|
244 | return tuple() | |||
|
245 | locals_ = locals_.copy() | |||
|
246 | locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT | |||
|
247 | code = SUBSCRIPT_MARKER + '[' + code + ']' | |||
|
248 | context = EvaluationContext(**{ | |||
|
249 | **context._asdict(), | |||
|
250 | **{'locals_': locals_} | |||
|
251 | }) | |||
|
252 | ||||
|
253 | if context.evaluation == 'dangerous': | |||
|
254 | return eval(code, context.globals_, context.locals_) | |||
|
255 | ||||
|
256 | expression = ast.parse(code, mode='eval') | |||
|
257 | ||||
|
258 | return eval_node(expression, context) | |||
|
259 | ||||
|
260 | def eval_node(node: Union[ast.AST, None], context: EvaluationContext): | |||
|
261 | """ | |||
|
262 | Evaluate AST node in provided context. | |||
|
263 | ||||
|
264 | Applies evaluation restrictions defined in the context. | |||
|
265 | ||||
|
266 | Currently does not support evaluation of functions with arguments. | |||
|
267 | ||||
|
268 | Does not evaluate actions which always have side effects: | |||
|
269 | - class definitions (`class sth: ...`) | |||
|
270 | - function definitions (`def sth: ...`) | |||
|
271 | - variable assignments (`x = 1`) | |||
|
272 | - augumented assignments (`x += 1`) | |||
|
273 | - deletions (`del x`) | |||
|
274 | ||||
|
275 | Does not evaluate operations which do not return values: | |||
|
276 | - assertions (`assert x`) | |||
|
277 | - pass (`pass`) | |||
|
278 | - imports (`import x`) | |||
|
279 | - control flow | |||
|
280 | - conditionals (`if x:`) except for terenary IfExp (`a if x else b`) | |||
|
281 | - loops (`for` and `while`) | |||
|
282 | - exception handling | |||
|
283 | """ | |||
|
284 | policy = EVALUATION_POLICIES[context.evaluation] | |||
|
285 | if node is None: | |||
|
286 | return None | |||
|
287 | if isinstance(node, ast.Expression): | |||
|
288 | return eval_node(node.body, context) | |||
|
289 | if isinstance(node, ast.BinOp): | |||
|
290 | # TODO: add guards | |||
|
291 | left = eval_node(node.left, context) | |||
|
292 | right = eval_node(node.right, context) | |||
|
293 | if isinstance(node.op, ast.Add): | |||
|
294 | return left + right | |||
|
295 | if isinstance(node.op, ast.Sub): | |||
|
296 | return left - right | |||
|
297 | if isinstance(node.op, ast.Mult): | |||
|
298 | return left * right | |||
|
299 | if isinstance(node.op, ast.Div): | |||
|
300 | return left / right | |||
|
301 | if isinstance(node.op, ast.FloorDiv): | |||
|
302 | return left // right | |||
|
303 | if isinstance(node.op, ast.Mod): | |||
|
304 | return left % right | |||
|
305 | if isinstance(node.op, ast.Pow): | |||
|
306 | return left ** right | |||
|
307 | if isinstance(node.op, ast.LShift): | |||
|
308 | return left << right | |||
|
309 | if isinstance(node.op, ast.RShift): | |||
|
310 | return left >> right | |||
|
311 | if isinstance(node.op, ast.BitOr): | |||
|
312 | return left | right | |||
|
313 | if isinstance(node.op, ast.BitXor): | |||
|
314 | return left ^ right | |||
|
315 | if isinstance(node.op, ast.BitAnd): | |||
|
316 | return left & right | |||
|
317 | if isinstance(node.op, ast.MatMult): | |||
|
318 | return left @ right | |||
|
319 | if isinstance(node, ast.Constant): | |||
|
320 | return node.value | |||
|
321 | if isinstance(node, ast.Index): | |||
|
322 | return eval_node(node.value, context) | |||
|
323 | if isinstance(node, ast.Tuple): | |||
|
324 | return tuple( | |||
|
325 | eval_node(e, context) | |||
|
326 | for e in node.elts | |||
|
327 | ) | |||
|
328 | if isinstance(node, ast.List): | |||
|
329 | return [ | |||
|
330 | eval_node(e, context) | |||
|
331 | for e in node.elts | |||
|
332 | ] | |||
|
333 | if isinstance(node, ast.Set): | |||
|
334 | return { | |||
|
335 | eval_node(e, context) | |||
|
336 | for e in node.elts | |||
|
337 | } | |||
|
338 | if isinstance(node, ast.Dict): | |||
|
339 | return dict(zip( | |||
|
340 | [eval_node(k, context) for k in node.keys], | |||
|
341 | [eval_node(v, context) for v in node.values] | |||
|
342 | )) | |||
|
343 | if isinstance(node, ast.Slice): | |||
|
344 | return slice( | |||
|
345 | eval_node(node.lower, context), | |||
|
346 | eval_node(node.upper, context), | |||
|
347 | eval_node(node.step, context) | |||
|
348 | ) | |||
|
349 | if isinstance(node, ast.ExtSlice): | |||
|
350 | return tuple([ | |||
|
351 | eval_node(dim, context) | |||
|
352 | for dim in node.dims | |||
|
353 | ]) | |||
|
354 | if isinstance(node, ast.UnaryOp): | |||
|
355 | # TODO: add guards | |||
|
356 | value = eval_node(node.operand, context) | |||
|
357 | if isinstance(node.op, ast.USub): | |||
|
358 | return -value | |||
|
359 | if isinstance(node.op, ast.UAdd): | |||
|
360 | return +value | |||
|
361 | if isinstance(node.op, ast.Invert): | |||
|
362 | return ~value | |||
|
363 | if isinstance(node.op, ast.Not): | |||
|
364 | return not value | |||
|
365 | raise ValueError('Unhandled unary operation:', node.op) | |||
|
366 | if isinstance(node, ast.Subscript): | |||
|
367 | value = eval_node(node.value, context) | |||
|
368 | slice_ = eval_node(node.slice, context) | |||
|
369 | if policy.can_get_item(value, slice_): | |||
|
370 | return value[slice_] | |||
|
371 | raise GuardRejection( | |||
|
372 | 'Subscript access (`__getitem__`) for', | |||
|
373 | type(value), # not joined to avoid calling `repr` | |||
|
374 | f' not allowed in {context.evaluation} mode' | |||
|
375 | ) | |||
|
376 | if isinstance(node, ast.Name): | |||
|
377 | if policy.allow_locals_access and node.id in context.locals_: | |||
|
378 | return context.locals_[node.id] | |||
|
379 | if policy.allow_globals_access and node.id in context.globals_: | |||
|
380 | return context.globals_[node.id] | |||
|
381 | if policy.allow_builtins_access and node.id in __builtins__: | |||
|
382 | return __builtins__[node.id] | |||
|
383 | if not policy.allow_globals_access and not policy.allow_locals_access: | |||
|
384 | raise GuardRejection( | |||
|
385 | f'Namespace access not allowed in {context.evaluation} mode' | |||
|
386 | ) | |||
|
387 | else: | |||
|
388 | raise NameError(f'{node.id} not found in locals nor globals') | |||
|
389 | if isinstance(node, ast.Attribute): | |||
|
390 | value = eval_node(node.value, context) | |||
|
391 | if policy.can_get_attr(value, node.attr): | |||
|
392 | return getattr(value, node.attr) | |||
|
393 | raise GuardRejection( | |||
|
394 | 'Attribute access (`__getattr__`) for', | |||
|
395 | type(value), # not joined to avoid calling `repr` | |||
|
396 | f'not allowed in {context.evaluation} mode' | |||
|
397 | ) | |||
|
398 | if isinstance(node, ast.IfExp): | |||
|
399 | test = eval_node(node.test, context) | |||
|
400 | if test: | |||
|
401 | return eval_node(node.body, context) | |||
|
402 | else: | |||
|
403 | return eval_node(node.orelse, context) | |||
|
404 | if isinstance(node, ast.Call): | |||
|
405 | func = eval_node(node.func, context) | |||
|
406 | print(node.keywords) | |||
|
407 | if policy.can_call(func) and not node.keywords: | |||
|
408 | args = [ | |||
|
409 | eval_node(arg, context) | |||
|
410 | for arg in node.args | |||
|
411 | ] | |||
|
412 | return func(*args) | |||
|
413 | raise GuardRejection( | |||
|
414 | 'Call for', | |||
|
415 | func, # not joined to avoid calling `repr` | |||
|
416 | f'not allowed in {context.evaluation} mode' | |||
|
417 | ) | |||
|
418 | raise ValueError('Unhandled node', node) | |||
|
419 | ||||
|
420 | ||||
|
421 | SUPPORTED_EXTERNAL_GETITEM = { | |||
|
422 | ('pandas', 'core', 'indexing', '_iLocIndexer'), | |||
|
423 | ('pandas', 'core', 'indexing', '_LocIndexer'), | |||
|
424 | ('pandas', 'DataFrame'), | |||
|
425 | ('pandas', 'Series'), | |||
|
426 | ('numpy', 'ndarray'), | |||
|
427 | ('numpy', 'void') | |||
|
428 | } | |||
|
429 | ||||
|
430 | BUILTIN_GETITEM = { | |||
|
431 | dict, | |||
|
432 | str, | |||
|
433 | bytes, | |||
|
434 | list, | |||
|
435 | tuple, | |||
|
436 | collections.defaultdict, | |||
|
437 | collections.deque, | |||
|
438 | collections.OrderedDict, | |||
|
439 | collections.ChainMap, | |||
|
440 | collections.UserDict, | |||
|
441 | collections.UserList, | |||
|
442 | collections.UserString, | |||
|
443 | DummyNamedTuple, | |||
|
444 | IdentitySubscript | |||
|
445 | } | |||
|
446 | ||||
|
447 | ||||
|
448 | def _list_methods(cls, source=None): | |||
|
449 | """For use on immutable objects or with methods returning a copy""" | |||
|
450 | return [ | |||
|
451 | getattr(cls, k) | |||
|
452 | for k in (source if source else dir(cls)) | |||
|
453 | ] | |||
|
454 | ||||
|
455 | ||||
|
456 | dict_non_mutating_methods = ('copy', 'keys', 'values', 'items') | |||
|
457 | list_non_mutating_methods = ('copy', 'index', 'count') | |||
|
458 | set_non_mutating_methods = set(dir(set)) & set(dir(frozenset)) | |||
|
459 | ||||
|
460 | ||||
|
461 | dict_keys = type({}.keys()) | |||
|
462 | method_descriptor = type(list.copy) | |||
|
463 | ||||
|
464 | ALLOWED_CALLS = { | |||
|
465 | bytes, | |||
|
466 | *_list_methods(bytes), | |||
|
467 | dict, | |||
|
468 | *_list_methods(dict, dict_non_mutating_methods), | |||
|
469 | dict_keys.isdisjoint, | |||
|
470 | list, | |||
|
471 | *_list_methods(list, list_non_mutating_methods), | |||
|
472 | set, | |||
|
473 | *_list_methods(set, set_non_mutating_methods), | |||
|
474 | frozenset, | |||
|
475 | *_list_methods(frozenset), | |||
|
476 | range, | |||
|
477 | str, | |||
|
478 | *_list_methods(str), | |||
|
479 | tuple, | |||
|
480 | *_list_methods(tuple), | |||
|
481 | collections.deque, | |||
|
482 | *_list_methods(collections.deque, list_non_mutating_methods), | |||
|
483 | collections.defaultdict, | |||
|
484 | *_list_methods(collections.defaultdict, dict_non_mutating_methods), | |||
|
485 | collections.OrderedDict, | |||
|
486 | *_list_methods(collections.OrderedDict, dict_non_mutating_methods), | |||
|
487 | collections.UserDict, | |||
|
488 | *_list_methods(collections.UserDict, dict_non_mutating_methods), | |||
|
489 | collections.UserList, | |||
|
490 | *_list_methods(collections.UserList, list_non_mutating_methods), | |||
|
491 | collections.UserString, | |||
|
492 | *_list_methods(collections.UserString, dir(str)), | |||
|
493 | collections.Counter, | |||
|
494 | *_list_methods(collections.Counter, dict_non_mutating_methods), | |||
|
495 | collections.Counter.elements, | |||
|
496 | collections.Counter.most_common | |||
|
497 | } | |||
|
498 | ||||
|
499 | EVALUATION_POLICIES = { | |||
|
500 | 'minimal': EvaluationPolicy( | |||
|
501 | allow_builtins_access=True, | |||
|
502 | allow_locals_access=False, | |||
|
503 | allow_globals_access=False, | |||
|
504 | allow_item_access=False, | |||
|
505 | allow_attr_access=False, | |||
|
506 | allowed_calls=set(), | |||
|
507 | allow_any_calls=False | |||
|
508 | ), | |||
|
509 | 'limitted': SelectivePolicy( | |||
|
510 | # TODO: | |||
|
511 | # - should reject binary and unary operations if custom methods would be dispatched | |||
|
512 | allowed_getitem=BUILTIN_GETITEM, | |||
|
513 | allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM, | |||
|
514 | allowed_getattr={ | |||
|
515 | *BUILTIN_GETITEM, | |||
|
516 | set, | |||
|
517 | frozenset, | |||
|
518 | object, | |||
|
519 | type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`. | |||
|
520 | dict_keys, | |||
|
521 | method_descriptor | |||
|
522 | }, | |||
|
523 | allowed_getattr_external={ | |||
|
524 | # pandas Series/Frame implements custom `__getattr__` | |||
|
525 | ('pandas', 'DataFrame'), | |||
|
526 | ('pandas', 'Series') | |||
|
527 | }, | |||
|
528 | allow_builtins_access=True, | |||
|
529 | allow_locals_access=True, | |||
|
530 | allow_globals_access=True, | |||
|
531 | allowed_calls=ALLOWED_CALLS | |||
|
532 | ), | |||
|
533 | 'unsafe': EvaluationPolicy( | |||
|
534 | allow_builtins_access=True, | |||
|
535 | allow_locals_access=True, | |||
|
536 | allow_globals_access=True, | |||
|
537 | allow_attr_access=True, | |||
|
538 | allow_item_access=True, | |||
|
539 | allow_any_calls=True | |||
|
540 | ) | |||
|
541 | } No newline at end of file |
@@ -0,0 +1,286 b'' | |||||
|
1 | from typing import NamedTuple | |||
|
2 | from IPython.core.guarded_eval import EvaluationContext, GuardRejection, guarded_eval, unbind_method | |||
|
3 | from IPython.testing import decorators as dec | |||
|
4 | import pytest | |||
|
5 | ||||
|
6 | ||||
|
7 | def limitted(**kwargs): | |||
|
8 | return EvaluationContext( | |||
|
9 | locals_=kwargs, | |||
|
10 | globals_={}, | |||
|
11 | evaluation='limitted' | |||
|
12 | ) | |||
|
13 | ||||
|
14 | ||||
|
15 | def unsafe(**kwargs): | |||
|
16 | return EvaluationContext( | |||
|
17 | locals_=kwargs, | |||
|
18 | globals_={}, | |||
|
19 | evaluation='unsafe' | |||
|
20 | ) | |||
|
21 | ||||
|
22 | @dec.skip_without('pandas') | |||
|
23 | def test_pandas_series_iloc(): | |||
|
24 | import pandas as pd | |||
|
25 | series = pd.Series([1], index=['a']) | |||
|
26 | context = limitted(data=series) | |||
|
27 | assert guarded_eval('data.iloc[0]', context) == 1 | |||
|
28 | ||||
|
29 | ||||
|
30 | @dec.skip_without('pandas') | |||
|
31 | def test_pandas_series(): | |||
|
32 | import pandas as pd | |||
|
33 | context = limitted(data=pd.Series([1], index=['a'])) | |||
|
34 | assert guarded_eval('data["a"]', context) == 1 | |||
|
35 | with pytest.raises(KeyError): | |||
|
36 | guarded_eval('data["c"]', context) | |||
|
37 | ||||
|
38 | ||||
|
39 | @dec.skip_without('pandas') | |||
|
40 | def test_pandas_bad_series(): | |||
|
41 | import pandas as pd | |||
|
42 | class BadItemSeries(pd.Series): | |||
|
43 | def __getitem__(self, key): | |||
|
44 | return 'CUSTOM_ITEM' | |||
|
45 | ||||
|
46 | class BadAttrSeries(pd.Series): | |||
|
47 | def __getattr__(self, key): | |||
|
48 | return 'CUSTOM_ATTR' | |||
|
49 | ||||
|
50 | bad_series = BadItemSeries([1], index=['a']) | |||
|
51 | context = limitted(data=bad_series) | |||
|
52 | ||||
|
53 | with pytest.raises(GuardRejection): | |||
|
54 | guarded_eval('data["a"]', context) | |||
|
55 | with pytest.raises(GuardRejection): | |||
|
56 | guarded_eval('data["c"]', context) | |||
|
57 | ||||
|
58 | # note: here result is a bit unexpected because | |||
|
59 | # pandas `__getattr__` calls `__getitem__`; | |||
|
60 | # FIXME - special case to handle it? | |||
|
61 | assert guarded_eval('data.a', context) == 'CUSTOM_ITEM' | |||
|
62 | ||||
|
63 | context = unsafe(data=bad_series) | |||
|
64 | assert guarded_eval('data["a"]', context) == 'CUSTOM_ITEM' | |||
|
65 | ||||
|
66 | bad_attr_series = BadAttrSeries([1], index=['a']) | |||
|
67 | context = limitted(data=bad_attr_series) | |||
|
68 | assert guarded_eval('data["a"]', context) == 1 | |||
|
69 | with pytest.raises(GuardRejection): | |||
|
70 | guarded_eval('data.a', context) | |||
|
71 | ||||
|
72 | ||||
|
73 | @dec.skip_without('pandas') | |||
|
74 | def test_pandas_dataframe_loc(): | |||
|
75 | import pandas as pd | |||
|
76 | from pandas.testing import assert_series_equal | |||
|
77 | data = pd.DataFrame([{'a': 1}]) | |||
|
78 | context = limitted(data=data) | |||
|
79 | assert_series_equal( | |||
|
80 | guarded_eval('data.loc[:, "a"]', context), | |||
|
81 | data['a'] | |||
|
82 | ) | |||
|
83 | ||||
|
84 | ||||
|
85 | def test_named_tuple(): | |||
|
86 | ||||
|
87 | class GoodNamedTuple(NamedTuple): | |||
|
88 | a: str | |||
|
89 | pass | |||
|
90 | ||||
|
91 | class BadNamedTuple(NamedTuple): | |||
|
92 | a: str | |||
|
93 | def __getitem__(self, key): | |||
|
94 | return None | |||
|
95 | ||||
|
96 | good = GoodNamedTuple(a='x') | |||
|
97 | bad = BadNamedTuple(a='x') | |||
|
98 | ||||
|
99 | context = limitted(data=good) | |||
|
100 | assert guarded_eval('data[0]', context) == 'x' | |||
|
101 | ||||
|
102 | context = limitted(data=bad) | |||
|
103 | with pytest.raises(GuardRejection): | |||
|
104 | guarded_eval('data[0]', context) | |||
|
105 | ||||
|
106 | ||||
|
107 | def test_dict(): | |||
|
108 | context = limitted( | |||
|
109 | data={'a': 1, 'b': {'x': 2}, ('x', 'y'): 3} | |||
|
110 | ) | |||
|
111 | assert guarded_eval('data["a"]', context) == 1 | |||
|
112 | assert guarded_eval('data["b"]', context) == {'x': 2} | |||
|
113 | assert guarded_eval('data["b"]["x"]', context) == 2 | |||
|
114 | assert guarded_eval('data["x", "y"]', context) == 3 | |||
|
115 | ||||
|
116 | assert guarded_eval('data.keys', context) | |||
|
117 | ||||
|
118 | ||||
|
119 | def test_set(): | |||
|
120 | context = limitted(data={'a', 'b'}) | |||
|
121 | assert guarded_eval('data.difference', context) | |||
|
122 | ||||
|
123 | ||||
|
124 | def test_list(): | |||
|
125 | context = limitted(data=[1, 2, 3]) | |||
|
126 | assert guarded_eval('data[1]', context) == 2 | |||
|
127 | assert guarded_eval('data.copy', context) | |||
|
128 | ||||
|
129 | ||||
|
130 | def test_dict_literal(): | |||
|
131 | context = limitted() | |||
|
132 | assert guarded_eval('{}', context) == {} | |||
|
133 | assert guarded_eval('{"a": 1}', context) == {"a": 1} | |||
|
134 | ||||
|
135 | ||||
|
136 | def test_list_literal(): | |||
|
137 | context = limitted() | |||
|
138 | assert guarded_eval('[]', context) == [] | |||
|
139 | assert guarded_eval('[1, "a"]', context) == [1, "a"] | |||
|
140 | ||||
|
141 | ||||
|
142 | def test_set_literal(): | |||
|
143 | context = limitted() | |||
|
144 | assert guarded_eval('set()', context) == set() | |||
|
145 | assert guarded_eval('{"a"}', context) == {"a"} | |||
|
146 | ||||
|
147 | ||||
|
148 | def test_if_expression(): | |||
|
149 | context = limitted() | |||
|
150 | assert guarded_eval('2 if True else 3', context) == 2 | |||
|
151 | assert guarded_eval('4 if False else 5', context) == 5 | |||
|
152 | ||||
|
153 | ||||
|
154 | def test_object(): | |||
|
155 | obj = object() | |||
|
156 | context = limitted(obj=obj) | |||
|
157 | assert guarded_eval('obj.__dir__', context) == obj.__dir__ | |||
|
158 | ||||
|
159 | ||||
|
160 | @pytest.mark.parametrize( | |||
|
161 | "code,expected", | |||
|
162 | [ | |||
|
163 | [ | |||
|
164 | 'int.numerator', | |||
|
165 | int.numerator | |||
|
166 | ], | |||
|
167 | [ | |||
|
168 | 'float.is_integer', | |||
|
169 | float.is_integer | |||
|
170 | ], | |||
|
171 | [ | |||
|
172 | 'complex.real', | |||
|
173 | complex.real | |||
|
174 | ] | |||
|
175 | ] | |||
|
176 | ) | |||
|
177 | def test_number_attributes(code, expected): | |||
|
178 | assert guarded_eval(code, limitted()) == expected | |||
|
179 | ||||
|
180 | ||||
|
181 | def test_method_descriptor(): | |||
|
182 | context = limitted() | |||
|
183 | assert guarded_eval('list.copy.__name__', context) == 'copy' | |||
|
184 | ||||
|
185 | ||||
|
186 | @pytest.mark.parametrize( | |||
|
187 | "data,good,bad,expected", | |||
|
188 | [ | |||
|
189 | [ | |||
|
190 | [1, 2, 3], | |||
|
191 | 'data.index(2)', | |||
|
192 | 'data.append(4)', | |||
|
193 | 1 | |||
|
194 | ], | |||
|
195 | [ | |||
|
196 | {'a': 1}, | |||
|
197 | 'data.keys().isdisjoint({})', | |||
|
198 | 'data.update()', | |||
|
199 | True | |||
|
200 | ] | |||
|
201 | ] | |||
|
202 | ) | |||
|
203 | def test_calls(data, good, bad, expected): | |||
|
204 | context = limitted(data=data) | |||
|
205 | assert guarded_eval(good, context) == expected | |||
|
206 | ||||
|
207 | with pytest.raises(GuardRejection): | |||
|
208 | guarded_eval(bad, context) | |||
|
209 | ||||
|
210 | ||||
|
211 | @pytest.mark.parametrize( | |||
|
212 | "code,expected", | |||
|
213 | [ | |||
|
214 | [ | |||
|
215 | '(1\n+\n1)', | |||
|
216 | 2 | |||
|
217 | ], | |||
|
218 | [ | |||
|
219 | 'list(range(10))[-1:]', | |||
|
220 | [9] | |||
|
221 | ], | |||
|
222 | [ | |||
|
223 | 'list(range(20))[3:-2:3]', | |||
|
224 | [3, 6, 9, 12, 15] | |||
|
225 | ] | |||
|
226 | ] | |||
|
227 | ) | |||
|
228 | def test_literals(code, expected): | |||
|
229 | context = limitted() | |||
|
230 | assert guarded_eval(code, context) == expected | |||
|
231 | ||||
|
232 | ||||
|
233 | def test_subscript(): | |||
|
234 | context = EvaluationContext( | |||
|
235 | locals_={}, | |||
|
236 | globals_={}, | |||
|
237 | evaluation='limitted', | |||
|
238 | in_subscript=True | |||
|
239 | ) | |||
|
240 | empty_slice = slice(None, None, None) | |||
|
241 | assert guarded_eval('', context) == tuple() | |||
|
242 | assert guarded_eval(':', context) == empty_slice | |||
|
243 | assert guarded_eval('1:2:3', context) == slice(1, 2, 3) | |||
|
244 | assert guarded_eval(':, "a"', context) == (empty_slice, "a") | |||
|
245 | ||||
|
246 | ||||
|
247 | def test_unbind_method(): | |||
|
248 | class X(list): | |||
|
249 | def index(self, k): | |||
|
250 | return 'CUSTOM' | |||
|
251 | x = X() | |||
|
252 | assert unbind_method(x.index) is X.index | |||
|
253 | assert unbind_method([].index) is list.index | |||
|
254 | ||||
|
255 | ||||
|
256 | def test_assumption_instance_attr_do_not_matter(): | |||
|
257 | """This is semi-specified in Python documentation. | |||
|
258 | ||||
|
259 | However, since the specification says 'not guaranted | |||
|
260 | to work' rather than 'is forbidden to work', future | |||
|
261 | versions could invalidate this assumptions. This test | |||
|
262 | is meant to catch such a change if it ever comes true. | |||
|
263 | """ | |||
|
264 | class T: | |||
|
265 | def __getitem__(self, k): | |||
|
266 | return 'a' | |||
|
267 | def __getattr__(self, k): | |||
|
268 | return 'a' | |||
|
269 | t = T() | |||
|
270 | t.__getitem__ = lambda f: 'b' | |||
|
271 | t.__getattr__ = lambda f: 'b' | |||
|
272 | assert t[1] == 'a' | |||
|
273 | assert t[1] == 'a' | |||
|
274 | ||||
|
275 | ||||
|
276 | def test_assumption_named_tuples_share_getitem(): | |||
|
277 | """Check assumption on named tuples sharing __getitem__""" | |||
|
278 | from typing import NamedTuple | |||
|
279 | ||||
|
280 | class A(NamedTuple): | |||
|
281 | pass | |||
|
282 | ||||
|
283 | class B(NamedTuple): | |||
|
284 | pass | |||
|
285 | ||||
|
286 | assert A.__getitem__ == B.__getitem__ |
@@ -190,6 +190,7 b' import time' | |||||
190 | import unicodedata |
|
190 | import unicodedata | |
191 | import uuid |
|
191 | import uuid | |
192 | import warnings |
|
192 | import warnings | |
|
193 | from ast import literal_eval | |||
193 | from contextlib import contextmanager |
|
194 | from contextlib import contextmanager | |
194 | from dataclasses import dataclass |
|
195 | from dataclasses import dataclass | |
195 | from functools import cached_property, partial |
|
196 | from functools import cached_property, partial | |
@@ -212,6 +213,7 b' from typing import (' | |||||
212 | Literal, |
|
213 | Literal, | |
213 | ) |
|
214 | ) | |
214 |
|
215 | |||
|
216 | from IPython.core.guarded_eval import guarded_eval, EvaluationContext | |||
215 | from IPython.core.error import TryNext |
|
217 | from IPython.core.error import TryNext | |
216 | from IPython.core.inputtransformer2 import ESC_MAGIC |
|
218 | from IPython.core.inputtransformer2 import ESC_MAGIC | |
217 | from IPython.core.latex_symbols import latex_symbols, reverse_latex_symbol |
|
219 | from IPython.core.latex_symbols import latex_symbols, reverse_latex_symbol | |
@@ -296,6 +298,9 b' MATCHES_LIMIT = 500' | |||||
296 | # Completion type reported when no type can be inferred. |
|
298 | # Completion type reported when no type can be inferred. | |
297 | _UNKNOWN_TYPE = "<unknown>" |
|
299 | _UNKNOWN_TYPE = "<unknown>" | |
298 |
|
300 | |||
|
301 | # sentinel value to signal lack of a match | |||
|
302 | not_found = object() | |||
|
303 | ||||
299 | class ProvisionalCompleterWarning(FutureWarning): |
|
304 | class ProvisionalCompleterWarning(FutureWarning): | |
300 | """ |
|
305 | """ | |
301 | Exception raise by an experimental feature in this module. |
|
306 | Exception raise by an experimental feature in this module. | |
@@ -902,12 +907,33 b' class CompletionSplitter(object):' | |||||
902 |
|
907 | |||
903 | class Completer(Configurable): |
|
908 | class Completer(Configurable): | |
904 |
|
909 | |||
905 |
greedy = Bool( |
|
910 | greedy = Bool( | |
906 | help="""Activate greedy completion |
|
911 | False, | |
907 | PENDING DEPRECATION. this is now mostly taken care of with Jedi. |
|
912 | help="""Activate greedy completion. | |
|
913 | ||||
|
914 | .. deprecated:: 8.8 | |||
|
915 | Use :any:`evaluation` instead. | |||
|
916 | ||||
|
917 | As of IPython 8.8 proxy for ``evaluation = 'unsafe'`` when set to ``True``, | |||
|
918 | and for ``'forbidden'`` when set to ``False``. | |||
|
919 | """, | |||
|
920 | ).tag(config=True) | |||
908 |
|
921 | |||
909 | This will enable completion on elements of lists, results of function calls, etc., |
|
922 | evaluation = Enum( | |
910 | but can be unsafe because the code is actually evaluated on TAB. |
|
923 | ('forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'), | |
|
924 | default_value='limitted', | |||
|
925 | help="""Code evaluation under completion. | |||
|
926 | ||||
|
927 | Successive options allow to enable more eager evaluation for more accurate completion suggestions, | |||
|
928 | including for nested dictionaries, nested lists, or even results of function calls. Setting `unsafe` | |||
|
929 | or higher can lead to evaluation of arbitrary user code on TAB with potentially dangerous side effects. | |||
|
930 | ||||
|
931 | Allowed values are: | |||
|
932 | - `forbidden`: no evaluation at all | |||
|
933 | - `minimal`: evaluation of literals and access to built-in namespaces; no item/attribute evaluation nor access to locals/globals | |||
|
934 | - `limitted` (default): access to all namespaces, evaluation of hard-coded methods (``keys()``, ``__getattr__``, ``__getitems__``, etc) on allow-listed objects (e.g. ``dict``, ``list``, ``tuple``, ``pandas.Series``) | |||
|
935 | - `unsafe`: evaluation of all methods and function calls but not of syntax with side-effects like `del x`, | |||
|
936 | - `dangerous`: completely arbitrary evaluation | |||
911 | """, |
|
937 | """, | |
912 | ).tag(config=True) |
|
938 | ).tag(config=True) | |
913 |
|
939 | |||
@@ -1029,26 +1055,14 b' class Completer(Configurable):' | |||||
1029 | with a __getattr__ hook is evaluated. |
|
1055 | with a __getattr__ hook is evaluated. | |
1030 |
|
1056 | |||
1031 | """ |
|
1057 | """ | |
1032 |
|
||||
1033 | # Another option, seems to work great. Catches things like ''.<tab> |
|
|||
1034 | m = re.match(r"(\S+(\.\w+)*)\.(\w*)$", text) |
|
|||
1035 |
|
||||
1036 | if m: |
|
|||
1037 | expr, attr = m.group(1, 3) |
|
|||
1038 | elif self.greedy: |
|
|||
1039 |
|
|
1058 | m2 = re.match(r"(.+)\.(\w*)$", self.line_buffer) | |
1040 |
|
|
1059 | if not m2: | |
1041 |
|
|
1060 | return [] | |
1042 |
|
|
1061 | expr, attr = m2.group(1,2) | |
1043 | else: |
|
|||
1044 | return [] |
|
|||
1045 |
|
1062 | |||
1046 | try: |
|
1063 | obj = self._evaluate_expr(expr) | |
1047 | obj = eval(expr, self.namespace) |
|
1064 | ||
1048 | except: |
|
1065 | if obj is not_found: | |
1049 | try: |
|
|||
1050 | obj = eval(expr, self.global_namespace) |
|
|||
1051 | except: |
|
|||
1052 |
|
|
1066 | return [] | |
1053 |
|
1067 | |||
1054 | if self.limit_to__all__ and hasattr(obj, '__all__'): |
|
1068 | if self.limit_to__all__ and hasattr(obj, '__all__'): | |
@@ -1068,8 +1082,32 b' class Completer(Configurable):' | |||||
1068 | pass |
|
1082 | pass | |
1069 | # Build match list to return |
|
1083 | # Build match list to return | |
1070 | n = len(attr) |
|
1084 | n = len(attr) | |
1071 |
return [ |
|
1085 | return ["%s.%s" % (expr, w) for w in words if w[:n] == attr ] | |
|
1086 | ||||
1072 |
|
1087 | |||
|
1088 | def _evaluate_expr(self, expr): | |||
|
1089 | obj = not_found | |||
|
1090 | done = False | |||
|
1091 | while not done and expr: | |||
|
1092 | try: | |||
|
1093 | obj = guarded_eval( | |||
|
1094 | expr, | |||
|
1095 | EvaluationContext( | |||
|
1096 | globals_=self.global_namespace, | |||
|
1097 | locals_=self.namespace, | |||
|
1098 | evaluation=self.evaluation | |||
|
1099 | ) | |||
|
1100 | ) | |||
|
1101 | done = True | |||
|
1102 | except Exception as e: | |||
|
1103 | if self.debug: | |||
|
1104 | print('Evaluation exception', e) | |||
|
1105 | # trim the expression to remove any invalid prefix | |||
|
1106 | # e.g. user starts `(d[`, so we get `expr = '(d'`, | |||
|
1107 | # where parenthesis is not closed. | |||
|
1108 | # TODO: make this faster by reusing parts of the computation? | |||
|
1109 | expr = expr[1:] | |||
|
1110 | return obj | |||
1073 |
|
1111 | |||
1074 | def get__all__entries(obj): |
|
1112 | def get__all__entries(obj): | |
1075 | """returns the strings in the __all__ attribute""" |
|
1113 | """returns the strings in the __all__ attribute""" | |
@@ -1081,8 +1119,8 b' def get__all__entries(obj):' | |||||
1081 | return [w for w in words if isinstance(w, str)] |
|
1119 | return [w for w in words if isinstance(w, str)] | |
1082 |
|
1120 | |||
1083 |
|
1121 | |||
1084 | def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], prefix: str, delims: str, |
|
1122 | def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes], ...]]], prefix: str, delims: str, | |
1085 | extra_prefix: Optional[Tuple[str, bytes]]=None) -> Tuple[str, int, List[str]]: |
|
1123 | extra_prefix: Optional[Tuple[Union[str, bytes], ...]]=None) -> Tuple[str, int, List[str]]: | |
1086 | """Used by dict_key_matches, matching the prefix to a list of keys |
|
1124 | """Used by dict_key_matches, matching the prefix to a list of keys | |
1087 |
|
1125 | |||
1088 | Parameters |
|
1126 | Parameters | |
@@ -1106,25 +1144,28 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||||
1106 |
|
1144 | |||
1107 | """ |
|
1145 | """ | |
1108 | prefix_tuple = extra_prefix if extra_prefix else () |
|
1146 | prefix_tuple = extra_prefix if extra_prefix else () | |
|
1147 | ||||
1109 | Nprefix = len(prefix_tuple) |
|
1148 | Nprefix = len(prefix_tuple) | |
|
1149 | text_serializable_types = (str, bytes, int, float, slice) | |||
1110 | def filter_prefix_tuple(key): |
|
1150 | def filter_prefix_tuple(key): | |
1111 | # Reject too short keys |
|
1151 | # Reject too short keys | |
1112 | if len(key) <= Nprefix: |
|
1152 | if len(key) <= Nprefix: | |
1113 | return False |
|
1153 | return False | |
1114 |
# Reject keys w |
|
1154 | # Reject keys which cannot be serialised to text | |
1115 | for k in key: |
|
1155 | for k in key: | |
1116 |
if not isinstance(k, |
|
1156 | if not isinstance(k, text_serializable_types): | |
1117 | return False |
|
1157 | return False | |
1118 | # Reject keys that do not match the prefix |
|
1158 | # Reject keys that do not match the prefix | |
1119 | for k, pt in zip(key, prefix_tuple): |
|
1159 | for k, pt in zip(key, prefix_tuple): | |
1120 | if k != pt: |
|
1160 | if k != pt and not isinstance(pt, slice): | |
1121 | return False |
|
1161 | return False | |
1122 | # All checks passed! |
|
1162 | # All checks passed! | |
1123 | return True |
|
1163 | return True | |
1124 |
|
1164 | |||
1125 | filtered_keys:List[Union[str,bytes]] = [] |
|
1165 | filtered_keys: List[Union[str, bytes, int, float, slice]] = [] | |
|
1166 | ||||
1126 | def _add_to_filtered_keys(key): |
|
1167 | def _add_to_filtered_keys(key): | |
1127 |
if isinstance(key, |
|
1168 | if isinstance(key, text_serializable_types): | |
1128 | filtered_keys.append(key) |
|
1169 | filtered_keys.append(key) | |
1129 |
|
1170 | |||
1130 | for k in keys: |
|
1171 | for k in keys: | |
@@ -1140,7 +1181,7 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||||
1140 | assert quote_match is not None # silence mypy |
|
1181 | assert quote_match is not None # silence mypy | |
1141 | quote = quote_match.group() |
|
1182 | quote = quote_match.group() | |
1142 | try: |
|
1183 | try: | |
1143 |
prefix_str = eval(prefix + quote |
|
1184 | prefix_str = literal_eval(prefix + quote) | |
1144 | except Exception: |
|
1185 | except Exception: | |
1145 | return '', 0, [] |
|
1186 | return '', 0, [] | |
1146 |
|
1187 | |||
@@ -1152,15 +1193,16 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||||
1152 |
|
1193 | |||
1153 | matched:List[str] = [] |
|
1194 | matched: List[str] = [] | |
1154 | for key in filtered_keys: |
|
1195 | for key in filtered_keys: | |
|
1196 | str_key = key if isinstance(key, (str, bytes)) else str(key) | |||
1155 | try: |
|
1197 | try: | |
1156 | if not key.startswith(prefix_str): |
|
1198 | if not str_key.startswith(prefix_str): | |
1157 | continue |
|
1199 | continue | |
1158 | except (AttributeError, TypeError, UnicodeError): |
|
1200 | except (AttributeError, TypeError, UnicodeError): | |
1159 | # Python 3+ TypeError on b'a'.startswith('a') or vice-versa |
|
1201 | # Python 3+ TypeError on b'a'.startswith('a') or vice-versa | |
1160 | continue |
|
1202 | continue | |
1161 |
|
1203 | |||
1162 | # reformat remainder of key to begin with prefix |
|
1204 | # reformat remainder of key to begin with prefix | |
1163 | rem = key[len(prefix_str):] |
|
1205 | rem = str_key[len(prefix_str):] | |
1164 | # force repr wrapped in ' |
|
1206 | # force repr wrapped in ' | |
1165 | rem_repr = repr(rem + '"') if isinstance(rem, str) else repr(rem + b'"') |
|
1207 | rem_repr = repr(rem + '"') if isinstance(rem, str) else repr(rem + b'"') | |
1166 | rem_repr = rem_repr[1 + rem_repr.index("'"):-2] |
|
1208 | rem_repr = rem_repr[1 + rem_repr.index("'"):-2] | |
@@ -1237,11 +1279,14 b' def position_to_cursor(text:str, offset:int)->Tuple[int, int]:' | |||||
1237 | return line, col |
|
1279 | return line, col | |
1238 |
|
1280 | |||
1239 |
|
1281 | |||
1240 | def _safe_isinstance(obj, module, class_name): |
|
1282 | def _safe_isinstance(obj, module, class_name, *attrs): | |
1241 | """Checks if obj is an instance of module.class_name if loaded |
|
1283 | """Checks if obj is an instance of module.class_name if loaded | |
1242 | """ |
|
1284 | """ | |
1243 |
|
|
1285 | if module in sys.modules: | |
1244 | isinstance(obj, getattr(import_module(module), class_name))) |
|
1286 | m = sys.modules[module] | |
|
1287 | for attr in [class_name, *attrs]: | |||
|
1288 | m = getattr(m, attr) | |||
|
1289 | return isinstance(obj, m) | |||
1245 |
|
1290 | |||
1246 |
|
1291 | |||
1247 | @context_matcher() |
|
1292 | @context_matcher() | |
@@ -1394,6 +1439,37 b' def _make_signature(completion)-> str:' | |||||
1394 | _CompleteResult = Dict[str, MatcherResult] |
|
1439 | _CompleteResult = Dict[str, MatcherResult] | |
1395 |
|
1440 | |||
1396 |
|
1441 | |||
|
1442 | DICT_MATCHER_REGEX = re.compile(r"""(?x) | |||
|
1443 | ( # match dict-referring - or any get item object - expression | |||
|
1444 | .+ | |||
|
1445 | ) | |||
|
1446 | \[ # open bracket | |||
|
1447 | \s* # and optional whitespace | |||
|
1448 | # Capture any number of serializable objects (e.g. "a", "b", 'c') | |||
|
1449 | # and slices | |||
|
1450 | ((?:[uUbB]? # string prefix (r not handled) | |||
|
1451 | (?: | |||
|
1452 | '(?:[^']|(?<!\\)\\')*' | |||
|
1453 | | | |||
|
1454 | "(?:[^"]|(?<!\\)\\")*" | |||
|
1455 | | | |||
|
1456 | # capture integers and slices | |||
|
1457 | (?:[-+]?\d+)?(?::(?:[-+]?\d+)?){0,2} | |||
|
1458 | ) | |||
|
1459 | \s*,\s* | |||
|
1460 | )*) | |||
|
1461 | ([uUbB]? # string prefix (r not handled) | |||
|
1462 | (?: # unclosed string | |||
|
1463 | '(?:[^']|(?<!\\)\\')* | |||
|
1464 | | | |||
|
1465 | "(?:[^"]|(?<!\\)\\")* | |||
|
1466 | | | |||
|
1467 | (?:[-+]?\d+) | |||
|
1468 | ) | |||
|
1469 | )? | |||
|
1470 | $ | |||
|
1471 | """) | |||
|
1472 | ||||
1397 | def _convert_matcher_v1_result_to_v2( |
|
1473 | def _convert_matcher_v1_result_to_v2( | |
1398 | matches: Sequence[str], |
|
1474 | matches: Sequence[str], | |
1399 | type: str, |
|
1475 | type: str, | |
@@ -1413,14 +1489,14 b' def _convert_matcher_v1_result_to_v2(' | |||||
1413 | class IPCompleter(Completer): |
|
1489 | class IPCompleter(Completer): | |
1414 | """Extension of the completer class with IPython-specific features""" |
|
1490 | """Extension of the completer class with IPython-specific features""" | |
1415 |
|
1491 | |||
1416 | __dict_key_regexps: Optional[Dict[bool,Pattern]] = None |
|
|||
1417 |
|
||||
1418 | @observe('greedy') |
|
1492 | @observe('greedy') | |
1419 | def _greedy_changed(self, change): |
|
1493 | def _greedy_changed(self, change): | |
1420 | """update the splitter and readline delims when greedy is changed""" |
|
1494 | """update the splitter and readline delims when greedy is changed""" | |
1421 | if change['new']: |
|
1495 | if change['new']: | |
|
1496 | self.evaluation = 'unsafe' | |||
1422 | self.splitter.delims = GREEDY_DELIMS |
|
1497 | self.splitter.delims = GREEDY_DELIMS | |
1423 | else: |
|
1498 | else: | |
|
1499 | self.evaluation = 'limitted' | |||
1424 | self.splitter.delims = DELIMS |
|
1500 | self.splitter.delims = DELIMS | |
1425 |
|
1501 | |||
1426 | dict_keys_only = Bool( |
|
1502 | dict_keys_only = Bool( | |
@@ -2149,12 +2225,17 b' class IPCompleter(Completer):' | |||||
2149 | return method() |
|
2225 | return method() | |
2150 |
|
2226 | |||
2151 | # Special case some common in-memory dict-like types |
|
2227 | # Special case some common in-memory dict-like types | |
2152 |
if isinstance(obj, dict) or |
|
2228 | if (isinstance(obj, dict) or | |
2153 | _safe_isinstance(obj, 'pandas', 'DataFrame'): |
|
2229 | _safe_isinstance(obj, 'pandas', 'DataFrame')): | |
2154 | try: |
|
2230 | try: | |
2155 | return list(obj.keys()) |
|
2231 | return list(obj.keys()) | |
2156 | except Exception: |
|
2232 | except Exception: | |
2157 | return [] |
|
2233 | return [] | |
|
2234 | elif _safe_isinstance(obj, 'pandas', 'core', 'indexing', '_LocIndexer'): | |||
|
2235 | try: | |||
|
2236 | return list(obj.obj.keys()) | |||
|
2237 | except Exception: | |||
|
2238 | return [] | |||
2158 | elif _safe_isinstance(obj, 'numpy', 'ndarray') or\ |
|
2239 | elif _safe_isinstance(obj, 'numpy', 'ndarray') or\ | |
2159 | _safe_isinstance(obj, 'numpy', 'void'): |
|
2240 | _safe_isinstance(obj, 'numpy', 'void'): | |
2160 | return obj.dtype.names or [] |
|
2241 | return obj.dtype.names or [] | |
@@ -2175,65 +2256,43 b' class IPCompleter(Completer):' | |||||
2175 | You can use :meth:`dict_key_matcher` instead. |
|
2256 | You can use :meth:`dict_key_matcher` instead. | |
2176 | """ |
|
2257 | """ | |
2177 |
|
2258 | |||
2178 | if self.__dict_key_regexps is not None: |
|
2259 | # Short-circuit on closed dictionary (regular expression would | |
2179 | regexps = self.__dict_key_regexps |
|
2260 | # not match anyway, but would take quite a while). | |
2180 | else: |
|
2261 | if self.text_until_cursor.strip().endswith(']'): | |
2181 | dict_key_re_fmt = r'''(?x) |
|
2262 | return [] | |
2182 | ( # match dict-referring expression wrt greedy setting |
|
|||
2183 | %s |
|
|||
2184 | ) |
|
|||
2185 | \[ # open bracket |
|
|||
2186 | \s* # and optional whitespace |
|
|||
2187 | # Capture any number of str-like objects (e.g. "a", "b", 'c') |
|
|||
2188 | ((?:[uUbB]? # string prefix (r not handled) |
|
|||
2189 | (?: |
|
|||
2190 | '(?:[^']|(?<!\\)\\')*' |
|
|||
2191 | | |
|
|||
2192 | "(?:[^"]|(?<!\\)\\")*" |
|
|||
2193 | ) |
|
|||
2194 | \s*,\s* |
|
|||
2195 | )*) |
|
|||
2196 | ([uUbB]? # string prefix (r not handled) |
|
|||
2197 | (?: # unclosed string |
|
|||
2198 | '(?:[^']|(?<!\\)\\')* |
|
|||
2199 | | |
|
|||
2200 | "(?:[^"]|(?<!\\)\\")* |
|
|||
2201 | ) |
|
|||
2202 | )? |
|
|||
2203 | $ |
|
|||
2204 | ''' |
|
|||
2205 | regexps = self.__dict_key_regexps = { |
|
|||
2206 | False: re.compile(dict_key_re_fmt % r''' |
|
|||
2207 | # identifiers separated by . |
|
|||
2208 | (?!\d)\w+ |
|
|||
2209 | (?:\.(?!\d)\w+)* |
|
|||
2210 | '''), |
|
|||
2211 | True: re.compile(dict_key_re_fmt % ''' |
|
|||
2212 | .+ |
|
|||
2213 | ''') |
|
|||
2214 | } |
|
|||
2215 |
|
2263 | |||
2216 |
match = |
|
2264 | match = DICT_MATCHER_REGEX.search(self.text_until_cursor) | |
2217 |
|
2265 | |||
2218 | if match is None: |
|
2266 | if match is None: | |
2219 | return [] |
|
2267 | return [] | |
2220 |
|
2268 | |||
2221 |
expr, pr |
|
2269 | expr, prior_tuple_keys, key_prefix = match.groups() | |
2222 | try: |
|
2270 | ||
2223 | obj = eval(expr, self.namespace) |
|
2271 | obj = self._evaluate_expr(expr) | |
2224 | except Exception: |
|
2272 | ||
2225 | try: |
|
2273 | if obj is not_found: | |
2226 | obj = eval(expr, self.global_namespace) |
|
|||
2227 | except Exception: |
|
|||
2228 |
|
|
2274 | return [] | |
2229 |
|
2275 | |||
2230 | keys = self._get_keys(obj) |
|
2276 | keys = self._get_keys(obj) | |
2231 | if not keys: |
|
2277 | if not keys: | |
2232 | return keys |
|
2278 | return keys | |
2233 |
|
2279 | |||
2234 | extra_prefix = eval(prefix0) if prefix0 != '' else None |
|
2280 | tuple_prefix = guarded_eval( | |
|
2281 | prior_tuple_keys, | |||
|
2282 | EvaluationContext( | |||
|
2283 | globals_=self.global_namespace, | |||
|
2284 | locals_=self.namespace, | |||
|
2285 | evaluation=self.evaluation, | |||
|
2286 | in_subscript=True | |||
|
2287 | ) | |||
|
2288 | ) | |||
2235 |
|
2289 | |||
2236 |
closing_quote, token_offset, matches = match_dict_keys( |
|
2290 | closing_quote, token_offset, matches = match_dict_keys( | |
|
2291 | keys, | |||
|
2292 | key_prefix, | |||
|
2293 | self.splitter.delims, | |||
|
2294 | extra_prefix=tuple_prefix | |||
|
2295 | ) | |||
2237 | if not matches: |
|
2296 | if not matches: | |
2238 | return matches |
|
2297 | return matches | |
2239 |
|
2298 | |||
@@ -2242,7 +2301,7 b' class IPCompleter(Completer):' | |||||
2242 | # - the start of the key text |
|
2301 | # - the start of the key text | |
2243 | # - the start of the completion |
|
2302 | # - the start of the completion | |
2244 | text_start = len(self.text_until_cursor) - len(text) |
|
2303 | text_start = len(self.text_until_cursor) - len(text) | |
2245 | if prefix: |
|
2304 | if key_prefix: | |
2246 | key_start = match.start(3) |
|
2305 | key_start = match.start(3) | |
2247 | completion_start = key_start + token_offset |
|
2306 | completion_start = key_start + token_offset | |
2248 | else: |
|
2307 | else: |
@@ -113,6 +113,17 b' def greedy_completion():' | |||||
113 |
|
113 | |||
114 |
|
114 | |||
115 | @contextmanager |
|
115 | @contextmanager | |
|
116 | def evaluation_level(evaluation: str): | |||
|
117 | ip = get_ipython() | |||
|
118 | evaluation_original = ip.Completer.evaluation | |||
|
119 | try: | |||
|
120 | ip.Completer.evaluation = evaluation | |||
|
121 | yield | |||
|
122 | finally: | |||
|
123 | ip.Completer.evaluation = evaluation_original | |||
|
124 | ||||
|
125 | ||||
|
126 | @contextmanager | |||
116 | def custom_matchers(matchers): |
|
127 | def custom_matchers(matchers): | |
117 | ip = get_ipython() |
|
128 | ip = get_ipython() | |
118 | try: |
|
129 | try: | |
@@ -852,8 +863,6 b' class TestCompleter(unittest.TestCase):' | |||||
852 | assert match_dict_keys(keys, '"', delims=delims) == ('"', 1, ["foo"]) |
|
863 | assert match_dict_keys(keys, '"', delims=delims) == ('"', 1, ["foo"]) | |
853 | assert match_dict_keys(keys, '"f', delims=delims) == ('"', 1, ["foo"]) |
|
864 | assert match_dict_keys(keys, '"f', delims=delims) == ('"', 1, ["foo"]) | |
854 |
|
865 | |||
855 | match_dict_keys |
|
|||
856 |
|
||||
857 | def test_match_dict_keys_tuple(self): |
|
866 | def test_match_dict_keys_tuple(self): | |
858 | """ |
|
867 | """ | |
859 | Test that match_dict_keys called with extra prefix works on a couple of use case, |
|
868 | Test that match_dict_keys called with extra prefix works on a couple of use case, | |
@@ -883,6 +892,11 b' class TestCompleter(unittest.TestCase):' | |||||
883 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3')) == ("'", 1, ["foo4"]) |
|
892 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3')) == ("'", 1, ["foo4"]) | |
884 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3', 'foo4')) == ("'", 1, []) |
|
893 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3', 'foo4')) == ("'", 1, []) | |
885 |
|
894 | |||
|
895 | keys = [("foo", 1111), ("foo", 2222), (3333, "bar"), (3333, 'test')] | |||
|
896 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("foo",)) == ("'", 1, ["1111", "2222"]) | |||
|
897 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=(3333,)) == ("'", 1, ["bar", "test"]) | |||
|
898 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("3333",)) == ("'", 1, []) | |||
|
899 | ||||
886 | def test_dict_key_completion_string(self): |
|
900 | def test_dict_key_completion_string(self): | |
887 | """Test dictionary key completion for string keys""" |
|
901 | """Test dictionary key completion for string keys""" | |
888 | ip = get_ipython() |
|
902 | ip = get_ipython() | |
@@ -1050,6 +1064,7 b' class TestCompleter(unittest.TestCase):' | |||||
1050 |
|
1064 | |||
1051 | ip.user_ns["C"] = C |
|
1065 | ip.user_ns["C"] = C | |
1052 | ip.user_ns["get"] = lambda: d |
|
1066 | ip.user_ns["get"] = lambda: d | |
|
1067 | ip.user_ns["nested"] = {'x': d} | |||
1053 |
|
1068 | |||
1054 | def assert_no_completion(**kwargs): |
|
1069 | def assert_no_completion(**kwargs): | |
1055 | _, matches = complete(**kwargs) |
|
1070 | _, matches = complete(**kwargs) | |
@@ -1075,6 +1090,13 b' class TestCompleter(unittest.TestCase):' | |||||
1075 | assert_completion(line_buffer="(d[") |
|
1090 | assert_completion(line_buffer="(d[") | |
1076 | assert_completion(line_buffer="C.data[") |
|
1091 | assert_completion(line_buffer="C.data[") | |
1077 |
|
1092 | |||
|
1093 | # nested dict completion | |||
|
1094 | assert_completion(line_buffer="nested['x'][") | |||
|
1095 | ||||
|
1096 | with evaluation_level('minimal'): | |||
|
1097 | with pytest.raises(AssertionError): | |||
|
1098 | assert_completion(line_buffer="nested['x'][") | |||
|
1099 | ||||
1078 | # greedy flag |
|
1100 | # greedy flag | |
1079 | def assert_completion(**kwargs): |
|
1101 | def assert_completion(**kwargs): | |
1080 | _, matches = complete(**kwargs) |
|
1102 | _, matches = complete(**kwargs) | |
@@ -1162,12 +1184,21 b' class TestCompleter(unittest.TestCase):' | |||||
1162 | _, matches = complete(line_buffer="d['") |
|
1184 | _, matches = complete(line_buffer="d['") | |
1163 | self.assertIn("my_head", matches) |
|
1185 | self.assertIn("my_head", matches) | |
1164 | self.assertIn("my_data", matches) |
|
1186 | self.assertIn("my_data", matches) | |
1165 |
|
|
1187 | def completes_on_nested(): | |
1166 | with greedy_completion(): |
|
|||
1167 | ip.user_ns["d"] = numpy.zeros(2, dtype=dt) |
|
1188 | ip.user_ns["d"] = numpy.zeros(2, dtype=dt) | |
1168 | _, matches = complete(line_buffer="d[1]['my_head']['") |
|
1189 | _, matches = complete(line_buffer="d[1]['my_head']['") | |
1169 | self.assertTrue(any(["my_dt" in m for m in matches])) |
|
1190 | self.assertTrue(any(["my_dt" in m for m in matches])) | |
1170 | self.assertTrue(any(["my_df" in m for m in matches])) |
|
1191 | self.assertTrue(any(["my_df" in m for m in matches])) | |
|
1192 | # complete on a nested level | |||
|
1193 | with greedy_completion(): | |||
|
1194 | completes_on_nested() | |||
|
1195 | ||||
|
1196 | with evaluation_level('limitted'): | |||
|
1197 | completes_on_nested() | |||
|
1198 | ||||
|
1199 | with evaluation_level('minimal'): | |||
|
1200 | with pytest.raises(AssertionError): | |||
|
1201 | completes_on_nested() | |||
1171 |
|
1202 | |||
1172 | @dec.skip_without("pandas") |
|
1203 | @dec.skip_without("pandas") | |
1173 | def test_dataframe_key_completion(self): |
|
1204 | def test_dataframe_key_completion(self): | |
@@ -1180,6 +1211,17 b' class TestCompleter(unittest.TestCase):' | |||||
1180 | _, matches = complete(line_buffer="d['") |
|
1211 | _, matches = complete(line_buffer="d['") | |
1181 | self.assertIn("hello", matches) |
|
1212 | self.assertIn("hello", matches) | |
1182 | self.assertIn("world", matches) |
|
1213 | self.assertIn("world", matches) | |
|
1214 | _, matches = complete(line_buffer="d.loc[:, '") | |||
|
1215 | self.assertIn("hello", matches) | |||
|
1216 | self.assertIn("world", matches) | |||
|
1217 | _, matches = complete(line_buffer="d.loc[1:, '") | |||
|
1218 | self.assertIn("hello", matches) | |||
|
1219 | _, matches = complete(line_buffer="d.loc[1:1, '") | |||
|
1220 | self.assertIn("hello", matches) | |||
|
1221 | _, matches = complete(line_buffer="d.loc[1:1:-1, '") | |||
|
1222 | self.assertIn("hello", matches) | |||
|
1223 | _, matches = complete(line_buffer="d.loc[::, '") | |||
|
1224 | self.assertIn("hello", matches) | |||
1183 |
|
1225 | |||
1184 | def test_dict_key_completion_invalids(self): |
|
1226 | def test_dict_key_completion_invalids(self): | |
1185 | """Smoke test cases dict key completion can't handle""" |
|
1227 | """Smoke test cases dict key completion can't handle""" |
General Comments 0
You need to be logged in to leave comments.
Login now