Show More
This diff has been collapsed as it changes many lines, (541 lines changed) Show them Hide them | |||
@@ -0,0 +1,541 b'' | |||
|
1 | from typing import Callable, Protocol, Set, Tuple, NamedTuple, Literal, Union | |
|
2 | import collections | |
|
3 | import sys | |
|
4 | import ast | |
|
5 | import types | |
|
6 | from functools import cached_property | |
|
7 | from dataclasses import dataclass, field | |
|
8 | ||
|
9 | ||
|
class HasGetItem(Protocol):
    """Structural type for objects that support subscription (``obj[key]``)."""

    def __getitem__(self, key) -> None: ...
|
12 | ||
|
13 | ||
|
class InstancesHaveGetItem(Protocol):
    """Structural type for callables whose result supports subscription."""

    def __call__(self) -> HasGetItem: ...
|
16 | ||
|
17 | ||
|
class HasGetAttr(Protocol):
    """Structural type for objects that explicitly define ``__getattr__``."""

    def __getattr__(self, key) -> None: ...
|
20 | ||
|
21 | ||
|
class DoesNotHaveGetAttr(Protocol):
    """Marker type for objects that do not define ``__getattr__``."""
|
24 | ||
|
# By default `__getattr__` is not explicitly implemented on most objects,
# so the allow-list of attribute-accessible types accepts both kinds.
MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
|
27 | ||
|
28 | ||
|
def unbind_method(func: Callable) -> Union[Callable, None]:
    """Get the unbound (class-level) method for a given bound method.

    Returns None when ``func`` is not bound to an owner, has no name, is
    shadowed by an entry in the owner's instance ``__dict__``, or when the
    owner's class has no attribute under the method's ``__name__``.
    """
    owner = getattr(func, '__self__', None)
    name = getattr(func, '__name__', None)
    if owner is None or not name:
        return None

    # An entry in the instance dict means the instance overrides the class
    # attribute, so the class-level lookup would not describe `func`.
    instance_dict_overrides = getattr(owner, '__dict__', None)
    if instance_dict_overrides and name in instance_dict_overrides:
        return None

    # `__name__` may differ from the attribute name the function is stored
    # under on the class; report "cannot unbind" instead of raising.
    return getattr(type(owner), name, None)
|
52 | ||
|
53 | ||
|
@dataclass
class EvaluationPolicy:
    """Set of switches describing what guarded evaluation may do.

    Every permission defaults to denied; ``allowed_calls`` lists the
    callables that may be invoked when ``allow_any_calls`` is off.
    """

    allow_locals_access: bool = False
    allow_globals_access: bool = False
    allow_item_access: bool = False
    allow_attr_access: bool = False
    allow_builtins_access: bool = False
    allow_any_calls: bool = False
    allowed_calls: Set[Callable] = field(default_factory=set)

    def can_get_item(self, value, item):
        """Return whether ``value[item]`` may be evaluated."""
        return self.allow_item_access

    def can_get_attr(self, value, attr):
        """Return whether ``value.attr`` may be evaluated."""
        return self.allow_attr_access

    def can_call(self, func):
        """Return whether ``func`` may be called.

        A callable is accepted when any call is allowed, when it is
        allow-listed itself, or when it is a bound method whose class-level
        function is allow-listed.
        """
        if self.allow_any_calls:
            return True

        if self._is_allow_listed(func):
            return True

        owner_method = unbind_method(func)
        return owner_method is not None and self._is_allow_listed(owner_method)

    def _is_allow_listed(self, func) -> bool:
        """Check membership in ``allowed_calls``, tolerating unhashable callables."""
        try:
            return func in self.allowed_calls
        except TypeError:
            # unhashable objects can never be members of the allow-list set
            return False
|
80 | ||
|
def has_original_dunder_external(value, module_name, access_path, method_name,):
    """Check whether `value` uses the unmodified dunder of an external type.

    The type is located by walking `access_path` from the module named
    `module_name`. The module is only consulted when it is already present
    in `sys.modules`; it is never imported here.

    Returns True when `value` is exactly of that type, or is an instance of
    a subclass whose `method_name` is still the one defined by that type.
    Returns False otherwise, including when the type cannot be resolved.
    """
    try:
        if module_name not in sys.modules:
            return False
        member_type = sys.modules[module_name]
        for attr in access_path:
            member_type = getattr(member_type, attr)
        value_type = type(value)
        if value_type is member_type:
            return True
        if isinstance(value, member_type):
            method = getattr(value_type, method_name, None)
            member_method = getattr(member_type, method_name, None)
            if member_method == method:
                return True
    except (AttributeError, KeyError, TypeError):
        # TypeError: the resolved member is not a class (`isinstance` fails)
        return False
    return False
|
98 | ||
|
99 | ||
|
def has_original_dunder(
    value,
    allowed_types,
    allowed_methods,
    allowed_external,
    method_name
):
    """Check whether `value` relies on an allow-listed `method_name` dunder.

    Returns True when the exact type of `value` is in `allowed_types`, when
    its class-level `method_name` is in `allowed_methods`, or when it matches
    one of the `(module_name, *access_path)` entries of `allowed_external`.
    Returns None when the class defines no such method, and False otherwise.
    """
    # note: Python ignores `__getattr__`/`__getitem__` on instances,
    # we only need to check at class level
    value_type = type(value)

    # strict type check passes - no need to check method
    if value_type in allowed_types:
        return True

    method = getattr(value_type, method_name, None)

    if not method:
        return None

    if method in allowed_methods:
        return True

    return any(
        has_original_dunder_external(value, module_name, access_path, method_name)
        for module_name, *access_path in allowed_external
    )
|
128 | ||
|
129 | ||
|
@dataclass
class SelectivePolicy(EvaluationPolicy):
    """Policy allowing item/attribute access only on allow-listed types.

    A value qualifies when its exact type is allow-listed, or when its class
    still uses the access dunder of an allow-listed (or external) type.
    """

    allowed_getitem: Set[HasGetItem] = field(default_factory=set)
    allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
    allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
    allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)

    def can_get_attr(self, value, attr):
        """Allow attribute access when `__getattr__`/`__getattribute__` were not modified."""
        has_original_attribute = has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattribute_methods,
            allowed_external=self.allowed_getattr_external,
            method_name='__getattribute__'
        )
        has_original_attr = has_original_dunder(
            value,
            allowed_types=self.allowed_getattr,
            allowed_methods=self._getattr_methods,
            allowed_external=self.allowed_getattr_external,
            method_name='__getattr__'
        )
        # Many objects do not have `__getattr__`, this is fine
        if has_original_attr is None and has_original_attribute:
            return True

        # Accept objects without modifications to `__getattr__` and `__getattribute__`
        return bool(has_original_attr and has_original_attribute)

    def get_attr(self, value, attr):
        """Return `getattr(value, attr)` if permitted by `can_get_attr`, else None."""
        if self.can_get_attr(value, attr):
            return getattr(value, attr)
        return None

    def can_get_item(self, value, item):
        """Allow `__getitem__` of allow-listed instances as long as it was not modified."""
        return bool(has_original_dunder(
            value,
            allowed_types=self.allowed_getitem,
            allowed_methods=self._getitem_methods,
            allowed_external=self.allowed_getitem_external,
            method_name='__getitem__'
        ))

    @cached_property
    def _getitem_methods(self) -> Set[Callable]:
        """`__getitem__` implementations of the allow-listed item types."""
        return self._safe_get_methods(
            self.allowed_getitem,
            '__getitem__'
        )

    @cached_property
    def _getattr_methods(self) -> Set[Callable]:
        """`__getattr__` implementations of the allow-listed attribute types."""
        return self._safe_get_methods(
            self.allowed_getattr,
            '__getattr__'
        )

    @cached_property
    def _getattribute_methods(self) -> Set[Callable]:
        """`__getattribute__` implementations of the allow-listed attribute types."""
        return self._safe_get_methods(
            self.allowed_getattr,
            '__getattribute__'
        )

    def _safe_get_methods(self, classes, name) -> Set[Callable]:
        """Collect attribute `name` from each class, skipping classes lacking it."""
        return {
            method
            for class_ in classes
            for method in [getattr(class_, name, None)]
            if method
        }
|
202 | ||
|
203 | ||
|
class DummyNamedTuple(NamedTuple):
    """Field-less named tuple standing in for all named tuples in allow-lists."""
|
206 | ||
|
207 | ||
|
class EvaluationContext(NamedTuple):
    """Namespaces and settings under which an expression is evaluated."""

    # namespace consulted first when resolving names
    locals_: dict
    # namespace consulted after `locals_`
    globals_: dict
    # name of the evaluation mode; defaults to rejecting any evaluation
    evaluation: Literal['forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'] = 'forbidden'
    # whether the code is the content of a subscript, e.g. `1:2` from `x[1:2]`
    in_subscript: bool = False
|
213 | ||
|
214 | ||
|
class IdentitySubscript:
    """Object whose subscription returns the key itself, e.g. `x[1:2]` gives `slice(1, 2)`."""

    def __getitem__(self, key):
        return key
|
218 | ||
|
# Shared instance and the name it is bound to when evaluating subscript content.
IDENTITY_SUBSCRIPT = IdentitySubscript()
SUBSCRIPT_MARKER = '__SUBSCRIPT_SENTINEL__'
|
221 | ||
|
class GuardRejection(ValueError):
    """Raised when the evaluation policy does not permit an operation."""
|
224 | ||
|
225 | ||
|
def guarded_eval(
    code: str,
    context: EvaluationContext
):
    """Evaluate `code` under the restrictions of `context.evaluation`.

    Raises `GuardRejection` in 'forbidden' mode. In 'dangerous' mode the code
    is passed to the built-in `eval`; in other modes it is parsed and
    evaluated node by node with `eval_node`.

    When `context.in_subscript` is set, `code` is treated as the content of
    a subscript (so slices such as `1:2` are accepted) and the evaluated key
    is returned; empty code yields an empty tuple.
    """
    if context.evaluation == 'forbidden':
        raise GuardRejection('Forbidden mode')

    # note: not using `ast.literal_eval` as it does not implement
    # getitem at all, for example it fails on simple `[0][1]`

    if context.in_subscript:
        # slice syntax (`:`) is only valid inside subscripts, so we wrap the
        # code in a subscript of a marker object that returns its key as-is
        if not code:
            return tuple()
        # copy, so the marker does not leak into the caller's namespace
        locals_ = context.locals_.copy()
        locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
        code = SUBSCRIPT_MARKER + '[' + code + ']'
        context = context._replace(locals_=locals_)

    if context.evaluation == 'dangerous':
        # SECURITY: unrestricted `eval`; this mode is an explicit opt-in to
        # arbitrary code execution and must not receive untrusted input
        return eval(code, context.globals_, context.locals_)

    expression = ast.parse(code, mode='eval')

    return eval_node(expression, context)
|
259 | ||
|
# Operations denoted by binary operator nodes.
_BINARY_OPERATIONS = {
    ast.Add: lambda left, right: left + right,
    ast.Sub: lambda left, right: left - right,
    ast.Mult: lambda left, right: left * right,
    ast.Div: lambda left, right: left / right,
    ast.FloorDiv: lambda left, right: left // right,
    ast.Mod: lambda left, right: left % right,
    ast.Pow: lambda left, right: left ** right,
    ast.LShift: lambda left, right: left << right,
    ast.RShift: lambda left, right: left >> right,
    ast.BitOr: lambda left, right: left | right,
    ast.BitXor: lambda left, right: left ^ right,
    ast.BitAnd: lambda left, right: left & right,
    ast.MatMult: lambda left, right: left @ right,
}

# Operations denoted by unary operator nodes.
_UNARY_OPERATIONS = {
    ast.USub: lambda value: -value,
    ast.UAdd: lambda value: +value,
    ast.Invert: lambda value: ~value,
    ast.Not: lambda value: not value,
}

# `ast.Index` and `ast.ExtSlice` are deprecated since Python 3.9; fall back
# to an empty tuple (never matched by `isinstance`) should they be removed.
_AST_INDEX = getattr(ast, 'Index', ())
_AST_EXT_SLICE = getattr(ast, 'ExtSlice', ())


def _builtins_namespace():
    """Return built-ins as a mapping; `__builtins__` may be a module or a dict."""
    if isinstance(__builtins__, dict):
        return __builtins__
    return vars(__builtins__)


def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """
    Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context.

    Calls are supported with positional arguments only; a call with
    keyword arguments is rejected.

    Does not evaluate actions which always have side effects:
    - class definitions (`class sth: ...`)
    - function definitions (`def sth: ...`)
    - variable assignments (`x = 1`)
    - augmented assignments (`x += 1`)
    - deletions (`del x`)

    Does not evaluate operations which do not return values:
    - assertions (`assert x`)
    - pass (`pass`)
    - imports (`import x`)
    - control flow
    - conditionals (`if x:`) except for ternary IfExp (`a if x else b`)
    - loops (`for` and `while`)
    - exception handling

    Raises `GuardRejection` when the policy forbids an operation, `NameError`
    for unknown names and `ValueError` for unsupported nodes or operators.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    if node is None:
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.BinOp):
        # TODO: add guards
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        operation = _BINARY_OPERATIONS.get(type(node.op))
        if operation is None:
            raise ValueError('Unhandled binary operation:', node.op)
        return operation(left, right)
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, _AST_INDEX):
        return eval_node(node.value, context)
    if isinstance(node, ast.Tuple):
        return tuple(
            eval_node(e, context)
            for e in node.elts
        )
    if isinstance(node, ast.List):
        return [
            eval_node(e, context)
            for e in node.elts
        ]
    if isinstance(node, ast.Set):
        return {
            eval_node(e, context)
            for e in node.elts
        }
    if isinstance(node, ast.Dict):
        result = {}
        for key_node, value_node in zip(node.keys, node.values):
            if key_node is None:
                # a `None` key marks `**mapping` unpacking; merging calls
                # methods of the mapping, hence only plain dicts are merged
                # unless the policy allows arbitrary calls
                mapping = eval_node(value_node, context)
                if type(mapping) is not dict and not policy.allow_any_calls:
                    raise GuardRejection(
                        'Dict unpacking for',
                        type(mapping),  # not joined to avoid calling `repr`
                        f'not allowed in {context.evaluation} mode'
                    )
                result.update(mapping)
            else:
                key = eval_node(key_node, context)
                result[key] = eval_node(value_node, context)
        return result
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context)
        )
    if isinstance(node, _AST_EXT_SLICE):
        return tuple([
            eval_node(dim, context)
            for dim in node.dims
        ])
    if isinstance(node, ast.UnaryOp):
        # TODO: add guards
        value = eval_node(node.operand, context)
        operation = _UNARY_OPERATIONS.get(type(node.op))
        if operation is None:
            raise ValueError('Unhandled unary operation:', node.op)
        return operation(value)
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            'Subscript access (`__getitem__`) for',
            type(value),  # not joined to avoid calling `repr`
            f' not allowed in {context.evaluation} mode'
        )
    if isinstance(node, ast.Name):
        if policy.allow_locals_access and node.id in context.locals_:
            return context.locals_[node.id]
        if policy.allow_globals_access and node.id in context.globals_:
            return context.globals_[node.id]
        if policy.allow_builtins_access:
            builtins_ns = _builtins_namespace()
            if node.id in builtins_ns:
                return builtins_ns[node.id]
        if not policy.allow_globals_access and not policy.allow_locals_access:
            raise GuardRejection(
                f'Namespace access not allowed in {context.evaluation} mode'
            )
        raise NameError(f'{node.id} not found in locals nor globals')
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            'Attribute access (`__getattr__`) for',
            type(value),  # not joined to avoid calling `repr`
            f'not allowed in {context.evaluation} mode'
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        # keyword arguments are not supported, such calls are rejected
        if policy.can_call(func) and not node.keywords:
            args = [
                eval_node(arg, context)
                for arg in node.args
            ]
            return func(*args)
        raise GuardRejection(
            'Call for',
            func,  # not joined to avoid calling `repr`
            f'not allowed in {context.evaluation} mode'
        )
    raise ValueError('Unhandled node', node)
|
419 | ||
|
420 | ||
|
# External types allowed in subscripts, each given as
# `(module_name, *access_path)` so that the module need not be imported here.
SUPPORTED_EXTERNAL_GETITEM = {
    ('pandas', 'core', 'indexing', '_iLocIndexer'),
    ('pandas', 'core', 'indexing', '_LocIndexer'),
    ('pandas', 'DataFrame'),
    ('pandas', 'Series'),
    ('numpy', 'ndarray'),
    ('numpy', 'void')
}
|
429 | ||
|
# Built-in and standard library types whose `__getitem__` may be evaluated.
BUILTIN_GETITEM = {
    dict,
    str,
    bytes,
    list,
    tuple,
    collections.defaultdict,
    collections.deque,
    collections.OrderedDict,
    collections.ChainMap,
    collections.UserDict,
    collections.UserList,
    collections.UserString,
    DummyNamedTuple,
    IdentitySubscript
}
|
446 | ||
|
447 | ||
|
448 | def _list_methods(cls, source=None): | |
|
449 | """For use on immutable objects or with methods returning a copy""" | |
|
450 | return [ | |
|
451 | getattr(cls, k) | |
|
452 | for k in (source if source else dir(cls)) | |
|
453 | ] | |
|
454 | ||
|
455 | ||
|
dict_non_mutating_methods = ('copy', 'keys', 'values', 'items')
list_non_mutating_methods = ('copy', 'index', 'count')
# names shared with the immutable `frozenset`, minus `__init__`:
# `set.__init__` re-initialises (clears and refills) the set in place
set_non_mutating_methods = (set(dir(set)) & set(dir(frozenset))) - {'__init__'}


# types not exposed under a name by `builtins`
dict_keys = type({}.keys())
method_descriptor = type(list.copy)
|
# Callables which may be invoked under the selective ('limitted') policy.
ALLOWED_CALLS = {
    bytes,
    *_list_methods(bytes),
    dict,
    *_list_methods(dict, dict_non_mutating_methods),
    dict_keys.isdisjoint,
    list,
    *_list_methods(list, list_non_mutating_methods),
    set,
    *_list_methods(set, set_non_mutating_methods),
    frozenset,
    *_list_methods(frozenset),
    range,
    str,
    *_list_methods(str),
    tuple,
    *_list_methods(tuple),
    collections.deque,
    *_list_methods(collections.deque, list_non_mutating_methods),
    collections.defaultdict,
    *_list_methods(collections.defaultdict, dict_non_mutating_methods),
    collections.OrderedDict,
    *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
    collections.UserDict,
    *_list_methods(collections.UserDict, dict_non_mutating_methods),
    collections.UserList,
    *_list_methods(collections.UserList, list_non_mutating_methods),
    collections.UserString,
    *_list_methods(collections.UserString, dir(str)),
    collections.Counter,
    *_list_methods(collections.Counter, dict_non_mutating_methods),
    collections.Counter.elements,
    collections.Counter.most_common
}

# The `dir()`-derived lists above also pick up methods which mutate their
# owner: attribute setters/deleters inherited from `object` (reachable on any
# instance through its class) and in-place re-initialisers.
ALLOWED_CALLS -= {
    object.__setattr__,
    object.__delattr__,
    set.__init__,
    collections.UserString.__init__,
}
|
498 | ||
|
# Policies by evaluation mode name. The 'forbidden' and 'dangerous' modes
# have no entry here.
EVALUATION_POLICIES = {
    'minimal': EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=False,
        allow_globals_access=False,
        allow_item_access=False,
        allow_attr_access=False,
        allowed_calls=set(),
        allow_any_calls=False
    ),
    'limitted': SelectivePolicy(
        # TODO:
        # - should reject binary and unary operations if custom methods would be dispatched
        allowed_getitem=BUILTIN_GETITEM,
        allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
        allowed_getattr={
            *BUILTIN_GETITEM,
            set,
            frozenset,
            object,
            type,  # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
            dict_keys,
            method_descriptor
        },
        allowed_getattr_external={
            # pandas Series/Frame implements custom `__getattr__`
            ('pandas', 'DataFrame'),
            ('pandas', 'Series')
        },
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allowed_calls=ALLOWED_CALLS
    ),
    'unsafe': EvaluationPolicy(
        allow_builtins_access=True,
        allow_locals_access=True,
        allow_globals_access=True,
        allow_attr_access=True,
        allow_item_access=True,
        allow_any_calls=True
    )
}
@@ -0,0 +1,286 b'' | |||
|
1 | from typing import NamedTuple | |
|
2 | from IPython.core.guarded_eval import EvaluationContext, GuardRejection, guarded_eval, unbind_method | |
|
3 | from IPython.testing import decorators as dec | |
|
4 | import pytest | |
|
5 | ||
|
6 | ||
|
def limitted(**kwargs):
    """Build a 'limitted' evaluation context with `kwargs` as local variables."""
    return EvaluationContext(
        locals_=kwargs,
        globals_={},
        evaluation='limitted'
    )
|
13 | ||
|
14 | ||
|
def unsafe(**kwargs):
    """Build an 'unsafe' evaluation context with `kwargs` as local variables."""
    return EvaluationContext(
        locals_=kwargs,
        globals_={},
        evaluation='unsafe'
    )
|
21 | ||
|
@dec.skip_without('pandas')
def test_pandas_series_iloc():
    """Positional indexing through `Series.iloc` is allowed."""
    import pandas as pd
    series = pd.Series([1], index=['a'])
    context = limitted(data=series)
    assert guarded_eval('data.iloc[0]', context) == 1
|
28 | ||
|
29 | ||
|
@dec.skip_without('pandas')
def test_pandas_series():
    """Label indexing of a Series is allowed and propagates `KeyError`."""
    import pandas as pd
    context = limitted(data=pd.Series([1], index=['a']))
    assert guarded_eval('data["a"]', context) == 1
    with pytest.raises(KeyError):
        guarded_eval('data["c"]', context)
|
37 | ||
|
38 | ||
|
@dec.skip_without('pandas')
def test_pandas_bad_series():
    """Series subclasses overriding access dunders are rejected in limitted mode."""
    import pandas as pd

    class BadItemSeries(pd.Series):
        def __getitem__(self, key):
            return 'CUSTOM_ITEM'

    class BadAttrSeries(pd.Series):
        def __getattr__(self, key):
            return 'CUSTOM_ATTR'

    bad_series = BadItemSeries([1], index=['a'])
    context = limitted(data=bad_series)

    with pytest.raises(GuardRejection):
        guarded_eval('data["a"]', context)
    with pytest.raises(GuardRejection):
        guarded_eval('data["c"]', context)

    # note: here result is a bit unexpected because
    # pandas `__getattr__` calls `__getitem__`;
    # FIXME - special case to handle it?
    assert guarded_eval('data.a', context) == 'CUSTOM_ITEM'

    context = unsafe(data=bad_series)
    assert guarded_eval('data["a"]', context) == 'CUSTOM_ITEM'

    bad_attr_series = BadAttrSeries([1], index=['a'])
    context = limitted(data=bad_attr_series)
    assert guarded_eval('data["a"]', context) == 1
    with pytest.raises(GuardRejection):
        guarded_eval('data.a', context)
|
71 | ||
|
72 | ||
|
@dec.skip_without('pandas')
def test_pandas_dataframe_loc():
    """Slicing through `DataFrame.loc` is allowed."""
    import pandas as pd
    from pandas.testing import assert_series_equal
    data = pd.DataFrame([{'a': 1}])
    context = limitted(data=data)
    assert_series_equal(
        guarded_eval('data.loc[:, "a"]', context),
        data['a']
    )
|
83 | ||
|
84 | ||
|
def test_named_tuple():
    """Named tuples are subscriptable unless they override `__getitem__`."""

    class GoodNamedTuple(NamedTuple):
        a: str

    class BadNamedTuple(NamedTuple):
        a: str

        def __getitem__(self, key):
            return None

    good = GoodNamedTuple(a='x')
    bad = BadNamedTuple(a='x')

    context = limitted(data=good)
    assert guarded_eval('data[0]', context) == 'x'

    context = limitted(data=bad)
    with pytest.raises(GuardRejection):
        guarded_eval('data[0]', context)
|
105 | ||
|
106 | ||
|
def test_dict():
    """Dicts allow item access (including nested and tuple keys) and attributes."""
    context = limitted(
        data={'a': 1, 'b': {'x': 2}, ('x', 'y'): 3}
    )
    assert guarded_eval('data["a"]', context) == 1
    assert guarded_eval('data["b"]', context) == {'x': 2}
    assert guarded_eval('data["b"]["x"]', context) == 2
    assert guarded_eval('data["x", "y"]', context) == 3

    assert guarded_eval('data.keys', context)
|
117 | ||
|
118 | ||
|
def test_set():
    """Attribute access on sets is allowed."""
    context = limitted(data={'a', 'b'})
    assert guarded_eval('data.difference', context)
|
122 | ||
|
123 | ||
|
def test_list():
    """Lists allow item and attribute access."""
    context = limitted(data=[1, 2, 3])
    assert guarded_eval('data[1]', context) == 2
    assert guarded_eval('data.copy', context)
|
128 | ||
|
129 | ||
|
def test_dict_literal():
    """Dict displays evaluate to dicts."""
    context = limitted()
    assert guarded_eval('{}', context) == {}
    assert guarded_eval('{"a": 1}', context) == {"a": 1}
|
134 | ||
|
135 | ||
|
def test_list_literal():
    """List displays evaluate to lists."""
    context = limitted()
    assert guarded_eval('[]', context) == []
    assert guarded_eval('[1, "a"]', context) == [1, "a"]
|
140 | ||
|
141 | ||
|
def test_set_literal():
    """Set displays and the `set()` constructor evaluate to sets."""
    context = limitted()
    assert guarded_eval('set()', context) == set()
    assert guarded_eval('{"a"}', context) == {"a"}
|
146 | ||
|
147 | ||
|
def test_if_expression():
    """Conditional expressions evaluate the branch selected by the test."""
    context = limitted()
    assert guarded_eval('2 if True else 3', context) == 2
    assert guarded_eval('4 if False else 5', context) == 5
|
152 | ||
|
153 | ||
|
def test_object():
    """Attribute access on a plain `object` instance is allowed."""
    obj = object()
    context = limitted(obj=obj)
    assert guarded_eval('obj.__dir__', context) == obj.__dir__
|
158 | ||
|
159 | ||
|
@pytest.mark.parametrize(
    "code,expected",
    [
        [
            'int.numerator',
            int.numerator
        ],
        [
            'float.is_integer',
            float.is_integer
        ],
        [
            'complex.real',
            complex.real
        ]
    ]
)
def test_number_attributes(code, expected):
    """Attributes of the numeric types themselves are accessible."""
    assert guarded_eval(code, limitted()) == expected
|
179 | ||
|
180 | ||
|
def test_method_descriptor():
    """Attributes of method descriptors are accessible."""
    context = limitted()
    assert guarded_eval('list.copy.__name__', context) == 'copy'
|
184 | ||
|
185 | ||
|
@pytest.mark.parametrize(
    "data,good,bad,expected",
    [
        [
            [1, 2, 3],
            'data.index(2)',
            'data.append(4)',
            1
        ],
        [
            {'a': 1},
            'data.keys().isdisjoint({})',
            'data.update()',
            True
        ]
    ]
)
def test_calls(data, good, bad, expected):
    """Allow-listed calls are evaluated while mutating calls are rejected."""
    context = limitted(data=data)
    assert guarded_eval(good, context) == expected

    with pytest.raises(GuardRejection):
        guarded_eval(bad, context)
|
209 | ||
|
210 | ||
|
@pytest.mark.parametrize(
    "code,expected",
    [
        [
            '(1\n+\n1)',
            2
        ],
        [
            'list(range(10))[-1:]',
            [9]
        ],
        [
            'list(range(20))[3:-2:3]',
            [3, 6, 9, 12, 15]
        ]
    ]
)
def test_literals(code, expected):
    """Multi-line arithmetic and slicing of constructed lists are evaluated."""
    context = limitted()
    assert guarded_eval(code, context) == expected
|
231 | ||
|
232 | ||
|
def test_subscript():
    """In subscript mode the content of a subscript evaluates to its key."""
    context = EvaluationContext(
        locals_={},
        globals_={},
        evaluation='limitted',
        in_subscript=True
    )
    empty_slice = slice(None, None, None)
    assert guarded_eval('', context) == tuple()
    assert guarded_eval(':', context) == empty_slice
    assert guarded_eval('1:2:3', context) == slice(1, 2, 3)
    assert guarded_eval(':, "a"', context) == (empty_slice, "a")
|
245 | ||
|
246 | ||
|
def test_unbind_method():
    """Bound methods unbind to the function defined on the owner's class."""
    class X(list):
        def index(self, k):
            return 'CUSTOM'
    x = X()
    assert unbind_method(x.index) is X.index
    assert unbind_method([].index) is list.index
|
254 | ||
|
255 | ||
|
def test_assumption_instance_attr_do_not_matter():
    """This is semi-specified in Python documentation.

    However, since the specification says 'not guaranteed
    to work' rather than 'is forbidden to work', future
    versions could invalidate this assumption. This test
    is meant to catch such a change if it ever comes true.
    """
    class T:
        def __getitem__(self, k):
            return 'a'

        def __getattr__(self, k):
            return 'a'

    t = T()
    t.__getitem__ = lambda f: 'b'
    t.__getattr__ = lambda f: 'b'
    # both dunders must be looked up on the class, ignoring instance attributes
    assert t[1] == 'a'
    assert t.a == 'a'
|
274 | ||
|
275 | ||
|
def test_assumption_named_tuples_share_getitem():
    """Check assumption on named tuples sharing __getitem__"""

    class A(NamedTuple):
        pass

    class B(NamedTuple):
        pass

    assert A.__getitem__ == B.__getitem__
@@ -190,6 +190,7 b' import time' | |||
|
190 | 190 | import unicodedata |
|
191 | 191 | import uuid |
|
192 | 192 | import warnings |
|
193 | from ast import literal_eval | |
|
193 | 194 | from contextlib import contextmanager |
|
194 | 195 | from dataclasses import dataclass |
|
195 | 196 | from functools import cached_property, partial |
@@ -212,6 +213,7 b' from typing import (' | |||
|
212 | 213 | Literal, |
|
213 | 214 | ) |
|
214 | 215 | |
|
216 | from IPython.core.guarded_eval import guarded_eval, EvaluationContext | |
|
215 | 217 | from IPython.core.error import TryNext |
|
216 | 218 | from IPython.core.inputtransformer2 import ESC_MAGIC |
|
217 | 219 | from IPython.core.latex_symbols import latex_symbols, reverse_latex_symbol |
@@ -296,6 +298,9 b' MATCHES_LIMIT = 500' | |||
|
296 | 298 | # Completion type reported when no type can be inferred. |
|
297 | 299 | _UNKNOWN_TYPE = "<unknown>" |
|
298 | 300 | |
|
301 | # sentinel value to signal lack of a match | |
|
302 | not_found = object() | |
|
303 | ||
|
299 | 304 | class ProvisionalCompleterWarning(FutureWarning): |
|
300 | 305 | """ |
|
301 | 306 | Exception raise by an experimental feature in this module. |
@@ -902,12 +907,33 b' class CompletionSplitter(object):' | |||
|
902 | 907 | |
|
903 | 908 | class Completer(Configurable): |
|
904 | 909 | |
|
905 |
greedy = Bool( |
|
|
906 | help="""Activate greedy completion | |
|
907 | PENDING DEPRECATION. this is now mostly taken care of with Jedi. | |
|
910 | greedy = Bool( | |
|
911 | False, | |
|
912 | help="""Activate greedy completion. | |
|
913 | ||
|
914 | .. deprecated:: 8.8 | |
|
915 | Use :any:`evaluation` instead. | |
|
916 | ||
|
917 | As of IPython 8.8 proxy for ``evaluation = 'unsafe'`` when set to ``True``, | |
|
918 | and for ``'forbidden'`` when set to ``False``. | |
|
919 | """, | |
|
920 | ).tag(config=True) | |
|
908 | 921 | |
|
909 | This will enable completion on elements of lists, results of function calls, etc., | |
|
910 | but can be unsafe because the code is actually evaluated on TAB. | |
|
922 | evaluation = Enum( | |
|
923 | ('forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'), | |
|
924 | default_value='limitted', | |
|
925 | help="""Code evaluation under completion. | |
|
926 | ||
|
927 | Successive options allow to enable more eager evaluation for more accurate completion suggestions, | |
|
928 | including for nested dictionaries, nested lists, or even results of function calls. Setting `unsafe` | |
|
929 | or higher can lead to evaluation of arbitrary user code on TAB with potentially dangerous side effects. | |
|
930 | ||
|
931 | Allowed values are: | |
|
932 | - `forbidden`: no evaluation at all | |
|
933 | - `minimal`: evaluation of literals and access to built-in namespaces; no item/attribute evaluation nor access to locals/globals | |
|
934 | - `limitted` (default): access to all namespaces, evaluation of hard-coded methods (``keys()``, ``__getattr__``, ``__getitems__``, etc) on allow-listed objects (e.g. ``dict``, ``list``, ``tuple``, ``pandas.Series``) | |
|
935 | - `unsafe`: evaluation of all methods and function calls but not of syntax with side-effects like `del x`, | |
|
936 | - `dangerous`: completely arbitrary evaluation | |
|
911 | 937 | """, |
|
912 | 938 | ).tag(config=True) |
|
913 | 939 | |
@@ -1029,26 +1055,14 b' class Completer(Configurable):' | |||
|
1029 | 1055 | with a __getattr__ hook is evaluated. |
|
1030 | 1056 | |
|
1031 | 1057 | """ |
|
1032 | ||
|
1033 | # Another option, seems to work great. Catches things like ''.<tab> | |
|
1034 | m = re.match(r"(\S+(\.\w+)*)\.(\w*)$", text) | |
|
1035 | ||
|
1036 | if m: | |
|
1037 | expr, attr = m.group(1, 3) | |
|
1038 | elif self.greedy: | |
|
1039 | 1058 |
|
|
1040 | 1059 |
|
|
1041 | 1060 |
|
|
1042 | 1061 |
|
|
1043 | else: | |
|
1044 | return [] | |
|
1045 | 1062 | |
|
1046 | try: | |
|
1047 | obj = eval(expr, self.namespace) | |
|
1048 | except: | |
|
1049 | try: | |
|
1050 | obj = eval(expr, self.global_namespace) | |
|
1051 | except: | |
|
1063 | obj = self._evaluate_expr(expr) | |
|
1064 | ||
|
1065 | if obj is not_found: | |
|
1052 | 1066 |
|
|
1053 | 1067 | |
|
1054 | 1068 | if self.limit_to__all__ and hasattr(obj, '__all__'): |
@@ -1068,8 +1082,32 b' class Completer(Configurable):' | |||
|
1068 | 1082 | pass |
|
1069 | 1083 | # Build match list to return |
|
1070 | 1084 | n = len(attr) |
|
1071 |
return [ |
|
|
1085 | return ["%s.%s" % (expr, w) for w in words if w[:n] == attr ] | |
|
1086 | ||
|
1072 | 1087 | |
|
def _evaluate_expr(self, expr):
    """Evaluate ``expr`` with ``guarded_eval``, retrying on shorter suffixes.

    The expression is evaluated against ``self.global_namespace`` and
    ``self.namespace`` at the level given by ``self.evaluation``. Whenever
    evaluation raises, the first character is dropped and the remainder is
    tried again, so that an invalid prefix is skipped (e.g. the user typed
    ``(d[`` and the expression is ``'(d'`` with an unclosed parenthesis).

    Parameters
    ----------
    expr : str
        Source text of the expression to evaluate.

    Returns
    -------
    object
        Result of the first successful evaluation, or ``not_found`` when
        ``expr`` is empty or no suffix of it could be evaluated.
    """
    remaining = expr
    while remaining:
        try:
            return guarded_eval(
                remaining,
                EvaluationContext(
                    globals_=self.global_namespace,
                    locals_=self.namespace,
                    evaluation=self.evaluation,
                ),
            )
        except Exception as err:
            if self.debug:
                print('Evaluation exception', err)
            # Drop one leading character and retry on the shorter suffix.
            # TODO: make this faster by reusing parts of the computation?
            remaining = remaining[1:]
    return not_found
|
1073 | 1111 | |
|
1074 | 1112 | def get__all__entries(obj): |
|
1075 | 1113 | """returns the strings in the __all__ attribute""" |
@@ -1081,8 +1119,8 b' def get__all__entries(obj):' | |||
|
1081 | 1119 | return [w for w in words if isinstance(w, str)] |
|
1082 | 1120 | |
|
1083 | 1121 | |
|
1084 | def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], prefix: str, delims: str, | |
|
1085 | extra_prefix: Optional[Tuple[str, bytes]]=None) -> Tuple[str, int, List[str]]: | |
|
1122 | def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes], ...]]], prefix: str, delims: str, | |
|
1123 | extra_prefix: Optional[Tuple[Union[str, bytes], ...]]=None) -> Tuple[str, int, List[str]]: | |
|
1086 | 1124 | """Used by dict_key_matches, matching the prefix to a list of keys |
|
1087 | 1125 | |
|
1088 | 1126 | Parameters |
@@ -1106,25 +1144,28 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||
|
1106 | 1144 | |
|
1107 | 1145 | """ |
|
1108 | 1146 | prefix_tuple = extra_prefix if extra_prefix else () |
|
1147 | ||
|
1109 | 1148 | Nprefix = len(prefix_tuple) |
|
1149 | text_serializable_types = (str, bytes, int, float, slice) | |
|
1110 | 1150 | def filter_prefix_tuple(key): |
|
1111 | 1151 | # Reject too short keys |
|
1112 | 1152 | if len(key) <= Nprefix: |
|
1113 | 1153 | return False |
|
1114 |
# Reject keys w |
|
|
1154 | # Reject keys which cannot be serialised to text | |
|
1115 | 1155 | for k in key: |
|
1116 |
if not isinstance(k, |
|
|
1156 | if not isinstance(k, text_serializable_types): | |
|
1117 | 1157 | return False |
|
1118 | 1158 | # Reject keys that do not match the prefix |
|
1119 | 1159 | for k, pt in zip(key, prefix_tuple): |
|
1120 | if k != pt: | |
|
1160 | if k != pt and not isinstance(pt, slice): | |
|
1121 | 1161 | return False |
|
1122 | 1162 | # All checks passed! |
|
1123 | 1163 | return True |
|
1124 | 1164 | |
|
1125 | filtered_keys:List[Union[str,bytes]] = [] | |
|
1165 | filtered_keys: List[Union[str, bytes, int, float, slice]] = [] | |
|
1166 | ||
|
1126 | 1167 | def _add_to_filtered_keys(key): |
|
1127 |
if isinstance(key, |
|
|
1168 | if isinstance(key, text_serializable_types): | |
|
1128 | 1169 | filtered_keys.append(key) |
|
1129 | 1170 | |
|
1130 | 1171 | for k in keys: |
@@ -1140,7 +1181,7 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||
|
1140 | 1181 | assert quote_match is not None # silence mypy |
|
1141 | 1182 | quote = quote_match.group() |
|
1142 | 1183 | try: |
|
1143 |
prefix_str = eval(prefix + quote |
|
|
1184 | prefix_str = literal_eval(prefix + quote) | |
|
1144 | 1185 | except Exception: |
|
1145 | 1186 | return '', 0, [] |
|
1146 | 1187 | |
@@ -1152,15 +1193,16 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre' | |||
|
1152 | 1193 | |
|
1153 | 1194 | matched:List[str] = [] |
|
1154 | 1195 | for key in filtered_keys: |
|
1196 | str_key = key if isinstance(key, (str, bytes)) else str(key) | |
|
1155 | 1197 | try: |
|
1156 | if not key.startswith(prefix_str): | |
|
1198 | if not str_key.startswith(prefix_str): | |
|
1157 | 1199 | continue |
|
1158 | 1200 | except (AttributeError, TypeError, UnicodeError): |
|
1159 | 1201 | # Python 3+ TypeError on b'a'.startswith('a') or vice-versa |
|
1160 | 1202 | continue |
|
1161 | 1203 | |
|
1162 | 1204 | # reformat remainder of key to begin with prefix |
|
1163 | rem = key[len(prefix_str):] | |
|
1205 | rem = str_key[len(prefix_str):] | |
|
1164 | 1206 | # force repr wrapped in ' |
|
1165 | 1207 | rem_repr = repr(rem + '"') if isinstance(rem, str) else repr(rem + b'"') |
|
1166 | 1208 | rem_repr = rem_repr[1 + rem_repr.index("'"):-2] |
@@ -1237,11 +1279,14 b' def position_to_cursor(text:str, offset:int)->Tuple[int, int]:' | |||
|
1237 | 1279 | return line, col |
|
1238 | 1280 | |
|
1239 | 1281 | |
|
1240 | def _safe_isinstance(obj, module, class_name): | |
|
1282 | def _safe_isinstance(obj, module, class_name, *attrs): | |
|
1241 | 1283 | """Checks if obj is an instance of module.class_name if loaded |
|
1242 | 1284 | """ |
|
1243 |
|
|
|
1244 | isinstance(obj, getattr(import_module(module), class_name))) | |
|
1285 | if module in sys.modules: | |
|
1286 | m = sys.modules[module] | |
|
1287 | for attr in [class_name, *attrs]: | |
|
1288 | m = getattr(m, attr) | |
|
1289 | return isinstance(obj, m) | |
|
1245 | 1290 | |
|
1246 | 1291 | |
|
1247 | 1292 | @context_matcher() |
@@ -1394,6 +1439,37 b' def _make_signature(completion)-> str:' | |||
|
1394 | 1439 | _CompleteResult = Dict[str, MatcherResult] |
|
1395 | 1440 | |
|
1396 | 1441 | |
|
1442 | DICT_MATCHER_REGEX = re.compile(r"""(?x) | |
|
1443 | ( # match dict-referring - or any get item object - expression | |
|
1444 | .+ | |
|
1445 | ) | |
|
1446 | \[ # open bracket | |
|
1447 | \s* # and optional whitespace | |
|
1448 | # Capture any number of serializable objects (e.g. "a", "b", 'c') | |
|
1449 | # and slices | |
|
1450 | ((?:[uUbB]? # string prefix (r not handled) | |
|
1451 | (?: | |
|
1452 | '(?:[^']|(?<!\\)\\')*' | |
|
1453 | | | |
|
1454 | "(?:[^"]|(?<!\\)\\")*" | |
|
1455 | | | |
|
1456 | # capture integers and slices | |
|
1457 | (?:[-+]?\d+)?(?::(?:[-+]?\d+)?){0,2} | |
|
1458 | ) | |
|
1459 | \s*,\s* | |
|
1460 | )*) | |
|
1461 | ([uUbB]? # string prefix (r not handled) | |
|
1462 | (?: # unclosed string | |
|
1463 | '(?:[^']|(?<!\\)\\')* | |
|
1464 | | | |
|
1465 | "(?:[^"]|(?<!\\)\\")* | |
|
1466 | | | |
|
1467 | (?:[-+]?\d+) | |
|
1468 | ) | |
|
1469 | )? | |
|
1470 | $ | |
|
1471 | """) | |
|
1472 | ||
|
1397 | 1473 | def _convert_matcher_v1_result_to_v2( |
|
1398 | 1474 | matches: Sequence[str], |
|
1399 | 1475 | type: str, |
@@ -1413,14 +1489,14 b' def _convert_matcher_v1_result_to_v2(' | |||
|
1413 | 1489 | class IPCompleter(Completer): |
|
1414 | 1490 | """Extension of the completer class with IPython-specific features""" |
|
1415 | 1491 | |
|
1416 | __dict_key_regexps: Optional[Dict[bool,Pattern]] = None | |
|
1417 | ||
|
@observe('greedy')
def _greedy_changed(self, change):
    """Update the evaluation level and splitter delims when greedy changes.

    Parameters
    ----------
    change : dict
        Change notification; only ``change['new']`` (the new value of
        ``greedy``) is read.
    """
    if not change['new']:
        # Greedy switched off: restore the regular level and delimiters.
        self.evaluation = 'limitted'
        self.splitter.delims = DELIMS
        return
    self.evaluation = 'unsafe'
    self.splitter.delims = GREEDY_DELIMS
|
1425 | 1501 | |
|
1426 | 1502 | dict_keys_only = Bool( |
@@ -2149,12 +2225,17 b' class IPCompleter(Completer):' | |||
|
2149 | 2225 | return method() |
|
2150 | 2226 | |
|
2151 | 2227 | # Special case some common in-memory dict-like types |
|
2152 |
if isinstance(obj, dict) or |
|
|
2153 | _safe_isinstance(obj, 'pandas', 'DataFrame'): | |
|
2228 | if (isinstance(obj, dict) or | |
|
2229 | _safe_isinstance(obj, 'pandas', 'DataFrame')): | |
|
2154 | 2230 | try: |
|
2155 | 2231 | return list(obj.keys()) |
|
2156 | 2232 | except Exception: |
|
2157 | 2233 | return [] |
|
2234 | elif _safe_isinstance(obj, 'pandas', 'core', 'indexing', '_LocIndexer'): | |
|
2235 | try: | |
|
2236 | return list(obj.obj.keys()) | |
|
2237 | except Exception: | |
|
2238 | return [] | |
|
2158 | 2239 | elif _safe_isinstance(obj, 'numpy', 'ndarray') or\ |
|
2159 | 2240 | _safe_isinstance(obj, 'numpy', 'void'): |
|
2160 | 2241 | return obj.dtype.names or [] |
@@ -2175,65 +2256,43 b' class IPCompleter(Completer):' | |||
|
2175 | 2256 | You can use :meth:`dict_key_matcher` instead. |
|
2176 | 2257 | """ |
|
2177 | 2258 | |
|
2178 | if self.__dict_key_regexps is not None: | |
|
2179 | regexps = self.__dict_key_regexps | |
|
2180 | else: | |
|
2181 | dict_key_re_fmt = r'''(?x) | |
|
2182 | ( # match dict-referring expression wrt greedy setting | |
|
2183 | %s | |
|
2184 | ) | |
|
2185 | \[ # open bracket | |
|
2186 | \s* # and optional whitespace | |
|
2187 | # Capture any number of str-like objects (e.g. "a", "b", 'c') | |
|
2188 | ((?:[uUbB]? # string prefix (r not handled) | |
|
2189 | (?: | |
|
2190 | '(?:[^']|(?<!\\)\\')*' | |
|
2191 | | | |
|
2192 | "(?:[^"]|(?<!\\)\\")*" | |
|
2193 | ) | |
|
2194 | \s*,\s* | |
|
2195 | )*) | |
|
2196 | ([uUbB]? # string prefix (r not handled) | |
|
2197 | (?: # unclosed string | |
|
2198 | '(?:[^']|(?<!\\)\\')* | |
|
2199 | | | |
|
2200 | "(?:[^"]|(?<!\\)\\")* | |
|
2201 | ) | |
|
2202 | )? | |
|
2203 | $ | |
|
2204 | ''' | |
|
2205 | regexps = self.__dict_key_regexps = { | |
|
2206 | False: re.compile(dict_key_re_fmt % r''' | |
|
2207 | # identifiers separated by . | |
|
2208 | (?!\d)\w+ | |
|
2209 | (?:\.(?!\d)\w+)* | |
|
2210 | '''), | |
|
2211 | True: re.compile(dict_key_re_fmt % ''' | |
|
2212 | .+ | |
|
2213 | ''') | |
|
2214 | } | |
|
2259 | # Short-circuit on closed dictionary (regular expression would | |
|
2260 | # not match anyway, but would take quite a while). | |
|
2261 | if self.text_until_cursor.strip().endswith(']'): | |
|
2262 | return [] | |
|
2215 | 2263 | |
|
2216 |
match = |
|
|
2264 | match = DICT_MATCHER_REGEX.search(self.text_until_cursor) | |
|
2217 | 2265 | |
|
2218 | 2266 | if match is None: |
|
2219 | 2267 | return [] |
|
2220 | 2268 | |
|
2221 |
expr, pr |
|
|
2222 | try: | |
|
2223 | obj = eval(expr, self.namespace) | |
|
2224 | except Exception: | |
|
2225 | try: | |
|
2226 | obj = eval(expr, self.global_namespace) | |
|
2227 | except Exception: | |
|
2269 | expr, prior_tuple_keys, key_prefix = match.groups() | |
|
2270 | ||
|
2271 | obj = self._evaluate_expr(expr) | |
|
2272 | ||
|
2273 | if obj is not_found: | |
|
2228 | 2274 |
|
|
2229 | 2275 | |
|
2230 | 2276 | keys = self._get_keys(obj) |
|
2231 | 2277 | if not keys: |
|
2232 | 2278 | return keys |
|
2233 | 2279 | |
|
2234 | extra_prefix = eval(prefix0) if prefix0 != '' else None | |
|
2280 | tuple_prefix = guarded_eval( | |
|
2281 | prior_tuple_keys, | |
|
2282 | EvaluationContext( | |
|
2283 | globals_=self.global_namespace, | |
|
2284 | locals_=self.namespace, | |
|
2285 | evaluation=self.evaluation, | |
|
2286 | in_subscript=True | |
|
2287 | ) | |
|
2288 | ) | |
|
2235 | 2289 | |
|
2236 |
closing_quote, token_offset, matches = match_dict_keys( |
|
|
2290 | closing_quote, token_offset, matches = match_dict_keys( | |
|
2291 | keys, | |
|
2292 | key_prefix, | |
|
2293 | self.splitter.delims, | |
|
2294 | extra_prefix=tuple_prefix | |
|
2295 | ) | |
|
2237 | 2296 | if not matches: |
|
2238 | 2297 | return matches |
|
2239 | 2298 | |
@@ -2242,7 +2301,7 b' class IPCompleter(Completer):' | |||
|
2242 | 2301 | # - the start of the key text |
|
2243 | 2302 | # - the start of the completion |
|
2244 | 2303 | text_start = len(self.text_until_cursor) - len(text) |
|
2245 | if prefix: | |
|
2304 | if key_prefix: | |
|
2246 | 2305 | key_start = match.start(3) |
|
2247 | 2306 | completion_start = key_start + token_offset |
|
2248 | 2307 | else: |
@@ -113,6 +113,17 b' def greedy_completion():' | |||
|
113 | 113 | |
|
114 | 114 | |
|
@contextmanager
def evaluation_level(evaluation: str):
    """Temporarily set ``Completer.evaluation`` on the current IPython shell.

    The value that was in place on entry is put back on exit, including
    when the body of the ``with`` block raises.

    Parameters
    ----------
    evaluation : str
        Value assigned to ``Completer.evaluation`` for the duration of the
        ``with`` block.
    """
    shell = get_ipython()
    saved_level = shell.Completer.evaluation
    try:
        shell.Completer.evaluation = evaluation
        yield
    finally:
        shell.Completer.evaluation = saved_level
|
124 | ||
|
125 | ||
|
126 | @contextmanager | |
|
116 | 127 | def custom_matchers(matchers): |
|
117 | 128 | ip = get_ipython() |
|
118 | 129 | try: |
@@ -852,8 +863,6 b' class TestCompleter(unittest.TestCase):' | |||
|
852 | 863 | assert match_dict_keys(keys, '"', delims=delims) == ('"', 1, ["foo"]) |
|
853 | 864 | assert match_dict_keys(keys, '"f', delims=delims) == ('"', 1, ["foo"]) |
|
854 | 865 | |
|
855 | match_dict_keys | |
|
856 | ||
|
857 | 866 | def test_match_dict_keys_tuple(self): |
|
858 | 867 | """ |
|
859 | 868 | Test that match_dict_keys called with extra prefix works on a couple of use case, |
@@ -883,6 +892,11 b' class TestCompleter(unittest.TestCase):' | |||
|
883 | 892 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3')) == ("'", 1, ["foo4"]) |
|
884 | 893 | assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3', 'foo4')) == ("'", 1, []) |
|
885 | 894 | |
|
895 | keys = [("foo", 1111), ("foo", 2222), (3333, "bar"), (3333, 'test')] | |
|
896 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("foo",)) == ("'", 1, ["1111", "2222"]) | |
|
897 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=(3333,)) == ("'", 1, ["bar", "test"]) | |
|
898 | assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("3333",)) == ("'", 1, []) | |
|
899 | ||
|
886 | 900 | def test_dict_key_completion_string(self): |
|
887 | 901 | """Test dictionary key completion for string keys""" |
|
888 | 902 | ip = get_ipython() |
@@ -1050,6 +1064,7 b' class TestCompleter(unittest.TestCase):' | |||
|
1050 | 1064 | |
|
1051 | 1065 | ip.user_ns["C"] = C |
|
1052 | 1066 | ip.user_ns["get"] = lambda: d |
|
1067 | ip.user_ns["nested"] = {'x': d} | |
|
1053 | 1068 | |
|
1054 | 1069 | def assert_no_completion(**kwargs): |
|
1055 | 1070 | _, matches = complete(**kwargs) |
@@ -1075,6 +1090,13 b' class TestCompleter(unittest.TestCase):' | |||
|
1075 | 1090 | assert_completion(line_buffer="(d[") |
|
1076 | 1091 | assert_completion(line_buffer="C.data[") |
|
1077 | 1092 | |
|
1093 | # nested dict completion | |
|
1094 | assert_completion(line_buffer="nested['x'][") | |
|
1095 | ||
|
1096 | with evaluation_level('minimal'): | |
|
1097 | with pytest.raises(AssertionError): | |
|
1098 | assert_completion(line_buffer="nested['x'][") | |
|
1099 | ||
|
1078 | 1100 | # greedy flag |
|
1079 | 1101 | def assert_completion(**kwargs): |
|
1080 | 1102 | _, matches = complete(**kwargs) |
@@ -1162,12 +1184,21 b' class TestCompleter(unittest.TestCase):' | |||
|
1162 | 1184 | _, matches = complete(line_buffer="d['") |
|
1163 | 1185 | self.assertIn("my_head", matches) |
|
1164 | 1186 | self.assertIn("my_data", matches) |
|
1165 |
|
|
|
1166 | with greedy_completion(): | |
|
1187 | def completes_on_nested(): | |
|
1167 | 1188 | ip.user_ns["d"] = numpy.zeros(2, dtype=dt) |
|
1168 | 1189 | _, matches = complete(line_buffer="d[1]['my_head']['") |
|
1169 | 1190 | self.assertTrue(any(["my_dt" in m for m in matches])) |
|
1170 | 1191 | self.assertTrue(any(["my_df" in m for m in matches])) |
|
1192 | # complete on a nested level | |
|
1193 | with greedy_completion(): | |
|
1194 | completes_on_nested() | |
|
1195 | ||
|
1196 | with evaluation_level('limitted'): | |
|
1197 | completes_on_nested() | |
|
1198 | ||
|
1199 | with evaluation_level('minimal'): | |
|
1200 | with pytest.raises(AssertionError): | |
|
1201 | completes_on_nested() | |
|
1171 | 1202 | |
|
1172 | 1203 | @dec.skip_without("pandas") |
|
1173 | 1204 | def test_dataframe_key_completion(self): |
@@ -1180,6 +1211,17 b' class TestCompleter(unittest.TestCase):' | |||
|
1180 | 1211 | _, matches = complete(line_buffer="d['") |
|
1181 | 1212 | self.assertIn("hello", matches) |
|
1182 | 1213 | self.assertIn("world", matches) |
|
1214 | _, matches = complete(line_buffer="d.loc[:, '") | |
|
1215 | self.assertIn("hello", matches) | |
|
1216 | self.assertIn("world", matches) | |
|
1217 | _, matches = complete(line_buffer="d.loc[1:, '") | |
|
1218 | self.assertIn("hello", matches) | |
|
1219 | _, matches = complete(line_buffer="d.loc[1:1, '") | |
|
1220 | self.assertIn("hello", matches) | |
|
1221 | _, matches = complete(line_buffer="d.loc[1:1:-1, '") | |
|
1222 | self.assertIn("hello", matches) | |
|
1223 | _, matches = complete(line_buffer="d.loc[::, '") | |
|
1224 | self.assertIn("hello", matches) | |
|
1183 | 1225 | |
|
1184 | 1226 | def test_dict_key_completion_invalids(self): |
|
1185 | 1227 | """Smoke test cases dict key completion can't handle""" |
General Comments 0
You need to be logged in to leave comments.
Login now