##// END OF EJS Templates
Implement guarded evaluation, replace greedy, implement:...
krassowski -
Show More
This diff has been collapsed as it changes many lines, (541 lines changed) Show them Hide them
@@ -0,0 +1,541 b''
1 from typing import Callable, Protocol, Set, Tuple, NamedTuple, Literal, Union
2 import collections
3 import sys
4 import ast
5 import types
6 from functools import cached_property
7 from dataclasses import dataclass, field
8
9
10 class HasGetItem(Protocol):
11 def __getitem__(self, key) -> None: ...
12
13
14 class InstancesHaveGetItem(Protocol):
15 def __call__(self) -> HasGetItem: ...
16
17
18 class HasGetAttr(Protocol):
19 def __getattr__(self, key) -> None: ...
20
21
22 class DoesNotHaveGetAttr(Protocol):
23 pass
24
25 # By default `__getattr__` is not explicitly implemented on most objects
26 MayHaveGetattr = Union[HasGetAttr, DoesNotHaveGetAttr]
27
28
29 def unbind_method(func: Callable) -> Union[Callable, None]:
30 """Get unbound method for given bound method.
31
32 Returns None if cannot get unbound method."""
33 owner = getattr(func, '__self__', None)
34 owner_class = type(owner)
35 name = getattr(func, '__name__', None)
36 instance_dict_overrides = getattr(owner, '__dict__', None)
37 if (
38 owner is not None
39 and
40 name
41 and
42 (
43 not instance_dict_overrides
44 or
45 (
46 instance_dict_overrides
47 and name not in instance_dict_overrides
48 )
49 )
50 ):
51 return getattr(owner_class, name)
52
53
54 @dataclass
55 class EvaluationPolicy:
56 allow_locals_access: bool = False
57 allow_globals_access: bool = False
58 allow_item_access: bool = False
59 allow_attr_access: bool = False
60 allow_builtins_access: bool = False
61 allow_any_calls: bool = False
62 allowed_calls: Set[Callable] = field(default_factory=set)
63
64 def can_get_item(self, value, item):
65 return self.allow_item_access
66
67 def can_get_attr(self, value, attr):
68 return self.allow_attr_access
69
70 def can_call(self, func):
71 if self.allow_any_calls:
72 return True
73
74 if func in self.allowed_calls:
75 return True
76
77 owner_method = unbind_method(func)
78 if owner_method and owner_method in self.allowed_calls:
79 return True
80
81 def has_original_dunder_external(value, module_name, access_path, method_name,):
82 try:
83 if module_name not in sys.modules:
84 return False
85 member_type = sys.modules[module_name]
86 for attr in access_path:
87 member_type = getattr(member_type, attr)
88 value_type = type(value)
89 if type(value) == member_type:
90 return True
91 if isinstance(value, member_type):
92 method = getattr(value_type, method_name, None)
93 member_method = getattr(member_type, method_name, None)
94 if member_method == method:
95 return True
96 except (AttributeError, KeyError):
97 return False
98
99
100 def has_original_dunder(
101 value,
102 allowed_types,
103 allowed_methods,
104 allowed_external,
105 method_name
106 ):
107 # note: Python ignores `__getattr__`/`__getitem__` on instances,
108 # we only need to check at class level
109 value_type = type(value)
110
111 # strict type check passes β†’ no need to check method
112 if value_type in allowed_types:
113 return True
114
115 method = getattr(value_type, method_name, None)
116
117 if not method:
118 return None
119
120 if method in allowed_methods:
121 return True
122
123 for module_name, *access_path in allowed_external:
124 if has_original_dunder_external(value, module_name, access_path, method_name):
125 return True
126
127 return False
128
129
130 @dataclass
131 class SelectivePolicy(EvaluationPolicy):
132 allowed_getitem: Set[HasGetItem] = field(default_factory=set)
133 allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set)
134 allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set)
135 allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set)
136
137 def can_get_attr(self, value, attr):
138 has_original_attribute = has_original_dunder(
139 value,
140 allowed_types=self.allowed_getattr,
141 allowed_methods=self._getattribute_methods,
142 allowed_external=self.allowed_getattr_external,
143 method_name='__getattribute__'
144 )
145 has_original_attr = has_original_dunder(
146 value,
147 allowed_types=self.allowed_getattr,
148 allowed_methods=self._getattr_methods,
149 allowed_external=self.allowed_getattr_external,
150 method_name='__getattr__'
151 )
152 # Many objects do not have `__getattr__`, this is fine
153 if has_original_attr is None and has_original_attribute:
154 return True
155
156 # Accept objects without modifications to `__getattr__` and `__getattribute__`
157 return has_original_attr and has_original_attribute
158
159 def get_attr(self, value, attr):
160 if self.can_get_attr(value, attr):
161 return getattr(value, attr)
162
163
164 def can_get_item(self, value, item):
165 """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified."""
166 return has_original_dunder(
167 value,
168 allowed_types=self.allowed_getitem,
169 allowed_methods=self._getitem_methods,
170 allowed_external=self.allowed_getitem_external,
171 method_name='__getitem__'
172 )
173
174 @cached_property
175 def _getitem_methods(self) -> Set[Callable]:
176 return self._safe_get_methods(
177 self.allowed_getitem,
178 '__getitem__'
179 )
180
181 @cached_property
182 def _getattr_methods(self) -> Set[Callable]:
183 return self._safe_get_methods(
184 self.allowed_getattr,
185 '__getattr__'
186 )
187
188 @cached_property
189 def _getattribute_methods(self) -> Set[Callable]:
190 return self._safe_get_methods(
191 self.allowed_getattr,
192 '__getattribute__'
193 )
194
195 def _safe_get_methods(self, classes, name) -> Set[Callable]:
196 return {
197 method
198 for class_ in classes
199 for method in [getattr(class_, name, None)]
200 if method
201 }
202
203
204 class DummyNamedTuple(NamedTuple):
205 pass
206
207
208 class EvaluationContext(NamedTuple):
209 locals_: dict
210 globals_: dict
211 evaluation: Literal['forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'] = 'forbidden'
212 in_subscript: bool = False
213
214
215 class IdentitySubscript:
216 def __getitem__(self, key):
217 return key
218
219 IDENTITY_SUBSCRIPT = IdentitySubscript()
220 SUBSCRIPT_MARKER = '__SUBSCRIPT_SENTINEL__'
221
222 class GuardRejection(ValueError):
223 pass
224
225
226 def guarded_eval(
227 code: str,
228 context: EvaluationContext
229 ):
230 locals_ = context.locals_
231
232 if context.evaluation == 'forbidden':
233 raise GuardRejection('Forbidden mode')
234
235 # note: not using `ast.literal_eval` as it does not implement
236 # getitem at all, for example it fails on simple `[0][1]`
237
238 if context.in_subscript:
239 # syntatic sugar for ellipsis (:) is only available in susbcripts
240 # so we need to trick the ast parser into thinking that we have
241 # a subscript, but we need to be able to later recognise that we did
242 # it so we can ignore the actual __getitem__ operation
243 if not code:
244 return tuple()
245 locals_ = locals_.copy()
246 locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
247 code = SUBSCRIPT_MARKER + '[' + code + ']'
248 context = EvaluationContext(**{
249 **context._asdict(),
250 **{'locals_': locals_}
251 })
252
253 if context.evaluation == 'dangerous':
254 return eval(code, context.globals_, context.locals_)
255
256 expression = ast.parse(code, mode='eval')
257
258 return eval_node(expression, context)
259
260 def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
261 """
262 Evaluate AST node in provided context.
263
264 Applies evaluation restrictions defined in the context.
265
266 Currently does not support evaluation of functions with arguments.
267
268 Does not evaluate actions which always have side effects:
269 - class definitions (`class sth: ...`)
270 - function definitions (`def sth: ...`)
271 - variable assignments (`x = 1`)
272 - augumented assignments (`x += 1`)
273 - deletions (`del x`)
274
275 Does not evaluate operations which do not return values:
276 - assertions (`assert x`)
277 - pass (`pass`)
278 - imports (`import x`)
279 - control flow
280 - conditionals (`if x:`) except for terenary IfExp (`a if x else b`)
281 - loops (`for` and `while`)
282 - exception handling
283 """
284 policy = EVALUATION_POLICIES[context.evaluation]
285 if node is None:
286 return None
287 if isinstance(node, ast.Expression):
288 return eval_node(node.body, context)
289 if isinstance(node, ast.BinOp):
290 # TODO: add guards
291 left = eval_node(node.left, context)
292 right = eval_node(node.right, context)
293 if isinstance(node.op, ast.Add):
294 return left + right
295 if isinstance(node.op, ast.Sub):
296 return left - right
297 if isinstance(node.op, ast.Mult):
298 return left * right
299 if isinstance(node.op, ast.Div):
300 return left / right
301 if isinstance(node.op, ast.FloorDiv):
302 return left // right
303 if isinstance(node.op, ast.Mod):
304 return left % right
305 if isinstance(node.op, ast.Pow):
306 return left ** right
307 if isinstance(node.op, ast.LShift):
308 return left << right
309 if isinstance(node.op, ast.RShift):
310 return left >> right
311 if isinstance(node.op, ast.BitOr):
312 return left | right
313 if isinstance(node.op, ast.BitXor):
314 return left ^ right
315 if isinstance(node.op, ast.BitAnd):
316 return left & right
317 if isinstance(node.op, ast.MatMult):
318 return left @ right
319 if isinstance(node, ast.Constant):
320 return node.value
321 if isinstance(node, ast.Index):
322 return eval_node(node.value, context)
323 if isinstance(node, ast.Tuple):
324 return tuple(
325 eval_node(e, context)
326 for e in node.elts
327 )
328 if isinstance(node, ast.List):
329 return [
330 eval_node(e, context)
331 for e in node.elts
332 ]
333 if isinstance(node, ast.Set):
334 return {
335 eval_node(e, context)
336 for e in node.elts
337 }
338 if isinstance(node, ast.Dict):
339 return dict(zip(
340 [eval_node(k, context) for k in node.keys],
341 [eval_node(v, context) for v in node.values]
342 ))
343 if isinstance(node, ast.Slice):
344 return slice(
345 eval_node(node.lower, context),
346 eval_node(node.upper, context),
347 eval_node(node.step, context)
348 )
349 if isinstance(node, ast.ExtSlice):
350 return tuple([
351 eval_node(dim, context)
352 for dim in node.dims
353 ])
354 if isinstance(node, ast.UnaryOp):
355 # TODO: add guards
356 value = eval_node(node.operand, context)
357 if isinstance(node.op, ast.USub):
358 return -value
359 if isinstance(node.op, ast.UAdd):
360 return +value
361 if isinstance(node.op, ast.Invert):
362 return ~value
363 if isinstance(node.op, ast.Not):
364 return not value
365 raise ValueError('Unhandled unary operation:', node.op)
366 if isinstance(node, ast.Subscript):
367 value = eval_node(node.value, context)
368 slice_ = eval_node(node.slice, context)
369 if policy.can_get_item(value, slice_):
370 return value[slice_]
371 raise GuardRejection(
372 'Subscript access (`__getitem__`) for',
373 type(value), # not joined to avoid calling `repr`
374 f' not allowed in {context.evaluation} mode'
375 )
376 if isinstance(node, ast.Name):
377 if policy.allow_locals_access and node.id in context.locals_:
378 return context.locals_[node.id]
379 if policy.allow_globals_access and node.id in context.globals_:
380 return context.globals_[node.id]
381 if policy.allow_builtins_access and node.id in __builtins__:
382 return __builtins__[node.id]
383 if not policy.allow_globals_access and not policy.allow_locals_access:
384 raise GuardRejection(
385 f'Namespace access not allowed in {context.evaluation} mode'
386 )
387 else:
388 raise NameError(f'{node.id} not found in locals nor globals')
389 if isinstance(node, ast.Attribute):
390 value = eval_node(node.value, context)
391 if policy.can_get_attr(value, node.attr):
392 return getattr(value, node.attr)
393 raise GuardRejection(
394 'Attribute access (`__getattr__`) for',
395 type(value), # not joined to avoid calling `repr`
396 f'not allowed in {context.evaluation} mode'
397 )
398 if isinstance(node, ast.IfExp):
399 test = eval_node(node.test, context)
400 if test:
401 return eval_node(node.body, context)
402 else:
403 return eval_node(node.orelse, context)
404 if isinstance(node, ast.Call):
405 func = eval_node(node.func, context)
406 print(node.keywords)
407 if policy.can_call(func) and not node.keywords:
408 args = [
409 eval_node(arg, context)
410 for arg in node.args
411 ]
412 return func(*args)
413 raise GuardRejection(
414 'Call for',
415 func, # not joined to avoid calling `repr`
416 f'not allowed in {context.evaluation} mode'
417 )
418 raise ValueError('Unhandled node', node)
419
420
421 SUPPORTED_EXTERNAL_GETITEM = {
422 ('pandas', 'core', 'indexing', '_iLocIndexer'),
423 ('pandas', 'core', 'indexing', '_LocIndexer'),
424 ('pandas', 'DataFrame'),
425 ('pandas', 'Series'),
426 ('numpy', 'ndarray'),
427 ('numpy', 'void')
428 }
429
430 BUILTIN_GETITEM = {
431 dict,
432 str,
433 bytes,
434 list,
435 tuple,
436 collections.defaultdict,
437 collections.deque,
438 collections.OrderedDict,
439 collections.ChainMap,
440 collections.UserDict,
441 collections.UserList,
442 collections.UserString,
443 DummyNamedTuple,
444 IdentitySubscript
445 }
446
447
448 def _list_methods(cls, source=None):
449 """For use on immutable objects or with methods returning a copy"""
450 return [
451 getattr(cls, k)
452 for k in (source if source else dir(cls))
453 ]
454
455
456 dict_non_mutating_methods = ('copy', 'keys', 'values', 'items')
457 list_non_mutating_methods = ('copy', 'index', 'count')
458 set_non_mutating_methods = set(dir(set)) & set(dir(frozenset))
459
460
461 dict_keys = type({}.keys())
462 method_descriptor = type(list.copy)
463
464 ALLOWED_CALLS = {
465 bytes,
466 *_list_methods(bytes),
467 dict,
468 *_list_methods(dict, dict_non_mutating_methods),
469 dict_keys.isdisjoint,
470 list,
471 *_list_methods(list, list_non_mutating_methods),
472 set,
473 *_list_methods(set, set_non_mutating_methods),
474 frozenset,
475 *_list_methods(frozenset),
476 range,
477 str,
478 *_list_methods(str),
479 tuple,
480 *_list_methods(tuple),
481 collections.deque,
482 *_list_methods(collections.deque, list_non_mutating_methods),
483 collections.defaultdict,
484 *_list_methods(collections.defaultdict, dict_non_mutating_methods),
485 collections.OrderedDict,
486 *_list_methods(collections.OrderedDict, dict_non_mutating_methods),
487 collections.UserDict,
488 *_list_methods(collections.UserDict, dict_non_mutating_methods),
489 collections.UserList,
490 *_list_methods(collections.UserList, list_non_mutating_methods),
491 collections.UserString,
492 *_list_methods(collections.UserString, dir(str)),
493 collections.Counter,
494 *_list_methods(collections.Counter, dict_non_mutating_methods),
495 collections.Counter.elements,
496 collections.Counter.most_common
497 }
498
499 EVALUATION_POLICIES = {
500 'minimal': EvaluationPolicy(
501 allow_builtins_access=True,
502 allow_locals_access=False,
503 allow_globals_access=False,
504 allow_item_access=False,
505 allow_attr_access=False,
506 allowed_calls=set(),
507 allow_any_calls=False
508 ),
509 'limitted': SelectivePolicy(
510 # TODO:
511 # - should reject binary and unary operations if custom methods would be dispatched
512 allowed_getitem=BUILTIN_GETITEM,
513 allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM,
514 allowed_getattr={
515 *BUILTIN_GETITEM,
516 set,
517 frozenset,
518 object,
519 type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`.
520 dict_keys,
521 method_descriptor
522 },
523 allowed_getattr_external={
524 # pandas Series/Frame implements custom `__getattr__`
525 ('pandas', 'DataFrame'),
526 ('pandas', 'Series')
527 },
528 allow_builtins_access=True,
529 allow_locals_access=True,
530 allow_globals_access=True,
531 allowed_calls=ALLOWED_CALLS
532 ),
533 'unsafe': EvaluationPolicy(
534 allow_builtins_access=True,
535 allow_locals_access=True,
536 allow_globals_access=True,
537 allow_attr_access=True,
538 allow_item_access=True,
539 allow_any_calls=True
540 )
541 } No newline at end of file
@@ -0,0 +1,286 b''
1 from typing import NamedTuple
2 from IPython.core.guarded_eval import EvaluationContext, GuardRejection, guarded_eval, unbind_method
3 from IPython.testing import decorators as dec
4 import pytest
5
6
7 def limitted(**kwargs):
8 return EvaluationContext(
9 locals_=kwargs,
10 globals_={},
11 evaluation='limitted'
12 )
13
14
15 def unsafe(**kwargs):
16 return EvaluationContext(
17 locals_=kwargs,
18 globals_={},
19 evaluation='unsafe'
20 )
21
22 @dec.skip_without('pandas')
23 def test_pandas_series_iloc():
24 import pandas as pd
25 series = pd.Series([1], index=['a'])
26 context = limitted(data=series)
27 assert guarded_eval('data.iloc[0]', context) == 1
28
29
30 @dec.skip_without('pandas')
31 def test_pandas_series():
32 import pandas as pd
33 context = limitted(data=pd.Series([1], index=['a']))
34 assert guarded_eval('data["a"]', context) == 1
35 with pytest.raises(KeyError):
36 guarded_eval('data["c"]', context)
37
38
39 @dec.skip_without('pandas')
40 def test_pandas_bad_series():
41 import pandas as pd
42 class BadItemSeries(pd.Series):
43 def __getitem__(self, key):
44 return 'CUSTOM_ITEM'
45
46 class BadAttrSeries(pd.Series):
47 def __getattr__(self, key):
48 return 'CUSTOM_ATTR'
49
50 bad_series = BadItemSeries([1], index=['a'])
51 context = limitted(data=bad_series)
52
53 with pytest.raises(GuardRejection):
54 guarded_eval('data["a"]', context)
55 with pytest.raises(GuardRejection):
56 guarded_eval('data["c"]', context)
57
58 # note: here result is a bit unexpected because
59 # pandas `__getattr__` calls `__getitem__`;
60 # FIXME - special case to handle it?
61 assert guarded_eval('data.a', context) == 'CUSTOM_ITEM'
62
63 context = unsafe(data=bad_series)
64 assert guarded_eval('data["a"]', context) == 'CUSTOM_ITEM'
65
66 bad_attr_series = BadAttrSeries([1], index=['a'])
67 context = limitted(data=bad_attr_series)
68 assert guarded_eval('data["a"]', context) == 1
69 with pytest.raises(GuardRejection):
70 guarded_eval('data.a', context)
71
72
73 @dec.skip_without('pandas')
74 def test_pandas_dataframe_loc():
75 import pandas as pd
76 from pandas.testing import assert_series_equal
77 data = pd.DataFrame([{'a': 1}])
78 context = limitted(data=data)
79 assert_series_equal(
80 guarded_eval('data.loc[:, "a"]', context),
81 data['a']
82 )
83
84
85 def test_named_tuple():
86
87 class GoodNamedTuple(NamedTuple):
88 a: str
89 pass
90
91 class BadNamedTuple(NamedTuple):
92 a: str
93 def __getitem__(self, key):
94 return None
95
96 good = GoodNamedTuple(a='x')
97 bad = BadNamedTuple(a='x')
98
99 context = limitted(data=good)
100 assert guarded_eval('data[0]', context) == 'x'
101
102 context = limitted(data=bad)
103 with pytest.raises(GuardRejection):
104 guarded_eval('data[0]', context)
105
106
107 def test_dict():
108 context = limitted(
109 data={'a': 1, 'b': {'x': 2}, ('x', 'y'): 3}
110 )
111 assert guarded_eval('data["a"]', context) == 1
112 assert guarded_eval('data["b"]', context) == {'x': 2}
113 assert guarded_eval('data["b"]["x"]', context) == 2
114 assert guarded_eval('data["x", "y"]', context) == 3
115
116 assert guarded_eval('data.keys', context)
117
118
119 def test_set():
120 context = limitted(data={'a', 'b'})
121 assert guarded_eval('data.difference', context)
122
123
124 def test_list():
125 context = limitted(data=[1, 2, 3])
126 assert guarded_eval('data[1]', context) == 2
127 assert guarded_eval('data.copy', context)
128
129
130 def test_dict_literal():
131 context = limitted()
132 assert guarded_eval('{}', context) == {}
133 assert guarded_eval('{"a": 1}', context) == {"a": 1}
134
135
136 def test_list_literal():
137 context = limitted()
138 assert guarded_eval('[]', context) == []
139 assert guarded_eval('[1, "a"]', context) == [1, "a"]
140
141
142 def test_set_literal():
143 context = limitted()
144 assert guarded_eval('set()', context) == set()
145 assert guarded_eval('{"a"}', context) == {"a"}
146
147
148 def test_if_expression():
149 context = limitted()
150 assert guarded_eval('2 if True else 3', context) == 2
151 assert guarded_eval('4 if False else 5', context) == 5
152
153
154 def test_object():
155 obj = object()
156 context = limitted(obj=obj)
157 assert guarded_eval('obj.__dir__', context) == obj.__dir__
158
159
160 @pytest.mark.parametrize(
161 "code,expected",
162 [
163 [
164 'int.numerator',
165 int.numerator
166 ],
167 [
168 'float.is_integer',
169 float.is_integer
170 ],
171 [
172 'complex.real',
173 complex.real
174 ]
175 ]
176 )
177 def test_number_attributes(code, expected):
178 assert guarded_eval(code, limitted()) == expected
179
180
181 def test_method_descriptor():
182 context = limitted()
183 assert guarded_eval('list.copy.__name__', context) == 'copy'
184
185
186 @pytest.mark.parametrize(
187 "data,good,bad,expected",
188 [
189 [
190 [1, 2, 3],
191 'data.index(2)',
192 'data.append(4)',
193 1
194 ],
195 [
196 {'a': 1},
197 'data.keys().isdisjoint({})',
198 'data.update()',
199 True
200 ]
201 ]
202 )
203 def test_calls(data, good, bad, expected):
204 context = limitted(data=data)
205 assert guarded_eval(good, context) == expected
206
207 with pytest.raises(GuardRejection):
208 guarded_eval(bad, context)
209
210
211 @pytest.mark.parametrize(
212 "code,expected",
213 [
214 [
215 '(1\n+\n1)',
216 2
217 ],
218 [
219 'list(range(10))[-1:]',
220 [9]
221 ],
222 [
223 'list(range(20))[3:-2:3]',
224 [3, 6, 9, 12, 15]
225 ]
226 ]
227 )
228 def test_literals(code, expected):
229 context = limitted()
230 assert guarded_eval(code, context) == expected
231
232
233 def test_subscript():
234 context = EvaluationContext(
235 locals_={},
236 globals_={},
237 evaluation='limitted',
238 in_subscript=True
239 )
240 empty_slice = slice(None, None, None)
241 assert guarded_eval('', context) == tuple()
242 assert guarded_eval(':', context) == empty_slice
243 assert guarded_eval('1:2:3', context) == slice(1, 2, 3)
244 assert guarded_eval(':, "a"', context) == (empty_slice, "a")
245
246
247 def test_unbind_method():
248 class X(list):
249 def index(self, k):
250 return 'CUSTOM'
251 x = X()
252 assert unbind_method(x.index) is X.index
253 assert unbind_method([].index) is list.index
254
255
256 def test_assumption_instance_attr_do_not_matter():
257 """This is semi-specified in Python documentation.
258
259 However, since the specification says 'not guaranted
260 to work' rather than 'is forbidden to work', future
261 versions could invalidate this assumptions. This test
262 is meant to catch such a change if it ever comes true.
263 """
264 class T:
265 def __getitem__(self, k):
266 return 'a'
267 def __getattr__(self, k):
268 return 'a'
269 t = T()
270 t.__getitem__ = lambda f: 'b'
271 t.__getattr__ = lambda f: 'b'
272 assert t[1] == 'a'
273 assert t[1] == 'a'
274
275
276 def test_assumption_named_tuples_share_getitem():
277 """Check assumption on named tuples sharing __getitem__"""
278 from typing import NamedTuple
279
280 class A(NamedTuple):
281 pass
282
283 class B(NamedTuple):
284 pass
285
286 assert A.__getitem__ == B.__getitem__
@@ -190,6 +190,7 b' import time'
190 190 import unicodedata
191 191 import uuid
192 192 import warnings
193 from ast import literal_eval
193 194 from contextlib import contextmanager
194 195 from dataclasses import dataclass
195 196 from functools import cached_property, partial
@@ -212,6 +213,7 b' from typing import ('
212 213 Literal,
213 214 )
214 215
216 from IPython.core.guarded_eval import guarded_eval, EvaluationContext
215 217 from IPython.core.error import TryNext
216 218 from IPython.core.inputtransformer2 import ESC_MAGIC
217 219 from IPython.core.latex_symbols import latex_symbols, reverse_latex_symbol
@@ -296,6 +298,9 b' MATCHES_LIMIT = 500'
296 298 # Completion type reported when no type can be inferred.
297 299 _UNKNOWN_TYPE = "<unknown>"
298 300
301 # sentinel value to signal lack of a match
302 not_found = object()
303
299 304 class ProvisionalCompleterWarning(FutureWarning):
300 305 """
301 306 Exception raise by an experimental feature in this module.
@@ -902,12 +907,33 b' class CompletionSplitter(object):'
902 907
903 908 class Completer(Configurable):
904 909
905 greedy = Bool(False,
906 help="""Activate greedy completion
907 PENDING DEPRECATION. this is now mostly taken care of with Jedi.
910 greedy = Bool(
911 False,
912 help="""Activate greedy completion.
913
914 .. deprecated:: 8.8
915 Use :any:`evaluation` instead.
916
917 As of IPython 8.8 proxy for ``evaluation = 'unsafe'`` when set to ``True``,
918 and for ``'forbidden'`` when set to ``False``.
919 """,
920 ).tag(config=True)
908 921
909 This will enable completion on elements of lists, results of function calls, etc.,
910 but can be unsafe because the code is actually evaluated on TAB.
922 evaluation = Enum(
923 ('forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'),
924 default_value='limitted',
925 help="""Code evaluation under completion.
926
927 Successive options allow to enable more eager evaluation for more accurate completion suggestions,
928 including for nested dictionaries, nested lists, or even results of function calls. Setting `unsafe`
929 or higher can lead to evaluation of arbitrary user code on TAB with potentially dangerous side effects.
930
931 Allowed values are:
932 - `forbidden`: no evaluation at all
933 - `minimal`: evaluation of literals and access to built-in namespaces; no item/attribute evaluation nor access to locals/globals
934 - `limitted` (default): access to all namespaces, evaluation of hard-coded methods (``keys()``, ``__getattr__``, ``__getitems__``, etc) on allow-listed objects (e.g. ``dict``, ``list``, ``tuple``, ``pandas.Series``)
935 - `unsafe`: evaluation of all methods and function calls but not of syntax with side-effects like `del x`,
936 - `dangerous`: completely arbitrary evaluation
911 937 """,
912 938 ).tag(config=True)
913 939
@@ -1029,26 +1055,14 b' class Completer(Configurable):'
1029 1055 with a __getattr__ hook is evaluated.
1030 1056
1031 1057 """
1032
1033 # Another option, seems to work great. Catches things like ''.<tab>
1034 m = re.match(r"(\S+(\.\w+)*)\.(\w*)$", text)
1035
1036 if m:
1037 expr, attr = m.group(1, 3)
1038 elif self.greedy:
1039 1058 m2 = re.match(r"(.+)\.(\w*)$", self.line_buffer)
1040 1059 if not m2:
1041 1060 return []
1042 1061 expr, attr = m2.group(1,2)
1043 else:
1044 return []
1045 1062
1046 try:
1047 obj = eval(expr, self.namespace)
1048 except:
1049 try:
1050 obj = eval(expr, self.global_namespace)
1051 except:
1063 obj = self._evaluate_expr(expr)
1064
1065 if obj is not_found:
1052 1066 return []
1053 1067
1054 1068 if self.limit_to__all__ and hasattr(obj, '__all__'):
@@ -1068,8 +1082,32 b' class Completer(Configurable):'
1068 1082 pass
1069 1083 # Build match list to return
1070 1084 n = len(attr)
1071 return [u"%s.%s" % (expr, w) for w in words if w[:n] == attr ]
1085 return ["%s.%s" % (expr, w) for w in words if w[:n] == attr ]
1086
1072 1087
1088 def _evaluate_expr(self, expr):
1089 obj = not_found
1090 done = False
1091 while not done and expr:
1092 try:
1093 obj = guarded_eval(
1094 expr,
1095 EvaluationContext(
1096 globals_=self.global_namespace,
1097 locals_=self.namespace,
1098 evaluation=self.evaluation
1099 )
1100 )
1101 done = True
1102 except Exception as e:
1103 if self.debug:
1104 print('Evaluation exception', e)
1105 # trim the expression to remove any invalid prefix
1106 # e.g. user starts `(d[`, so we get `expr = '(d'`,
1107 # where parenthesis is not closed.
1108 # TODO: make this faster by reusing parts of the computation?
1109 expr = expr[1:]
1110 return obj
1073 1111
1074 1112 def get__all__entries(obj):
1075 1113 """returns the strings in the __all__ attribute"""
@@ -1081,8 +1119,8 b' def get__all__entries(obj):'
1081 1119 return [w for w in words if isinstance(w, str)]
1082 1120
1083 1121
1084 def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], prefix: str, delims: str,
1085 extra_prefix: Optional[Tuple[str, bytes]]=None) -> Tuple[str, int, List[str]]:
1122 def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes], ...]]], prefix: str, delims: str,
1123 extra_prefix: Optional[Tuple[Union[str, bytes], ...]]=None) -> Tuple[str, int, List[str]]:
1086 1124 """Used by dict_key_matches, matching the prefix to a list of keys
1087 1125
1088 1126 Parameters
@@ -1106,25 +1144,28 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre'
1106 1144
1107 1145 """
1108 1146 prefix_tuple = extra_prefix if extra_prefix else ()
1147
1109 1148 Nprefix = len(prefix_tuple)
1149 text_serializable_types = (str, bytes, int, float, slice)
1110 1150 def filter_prefix_tuple(key):
1111 1151 # Reject too short keys
1112 1152 if len(key) <= Nprefix:
1113 1153 return False
1114 # Reject keys with non str/bytes in it
1154 # Reject keys which cannot be serialised to text
1115 1155 for k in key:
1116 if not isinstance(k, (str, bytes)):
1156 if not isinstance(k, text_serializable_types):
1117 1157 return False
1118 1158 # Reject keys that do not match the prefix
1119 1159 for k, pt in zip(key, prefix_tuple):
1120 if k != pt:
1160 if k != pt and not isinstance(pt, slice):
1121 1161 return False
1122 1162 # All checks passed!
1123 1163 return True
1124 1164
1125 filtered_keys:List[Union[str,bytes]] = []
1165 filtered_keys: List[Union[str, bytes, int, float, slice]] = []
1166
1126 1167 def _add_to_filtered_keys(key):
1127 if isinstance(key, (str, bytes)):
1168 if isinstance(key, text_serializable_types):
1128 1169 filtered_keys.append(key)
1129 1170
1130 1171 for k in keys:
@@ -1140,7 +1181,7 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre'
1140 1181 assert quote_match is not None # silence mypy
1141 1182 quote = quote_match.group()
1142 1183 try:
1143 prefix_str = eval(prefix + quote, {})
1184 prefix_str = literal_eval(prefix + quote)
1144 1185 except Exception:
1145 1186 return '', 0, []
1146 1187
@@ -1152,15 +1193,16 b' def match_dict_keys(keys: List[Union[str, bytes, Tuple[Union[str, bytes]]]], pre'
1152 1193
1153 1194 matched:List[str] = []
1154 1195 for key in filtered_keys:
1196 str_key = key if isinstance(key, (str, bytes)) else str(key)
1155 1197 try:
1156 if not key.startswith(prefix_str):
1198 if not str_key.startswith(prefix_str):
1157 1199 continue
1158 1200 except (AttributeError, TypeError, UnicodeError):
1159 1201 # Python 3+ TypeError on b'a'.startswith('a') or vice-versa
1160 1202 continue
1161 1203
1162 1204 # reformat remainder of key to begin with prefix
1163 rem = key[len(prefix_str):]
1205 rem = str_key[len(prefix_str):]
1164 1206 # force repr wrapped in '
1165 1207 rem_repr = repr(rem + '"') if isinstance(rem, str) else repr(rem + b'"')
1166 1208 rem_repr = rem_repr[1 + rem_repr.index("'"):-2]
@@ -1237,11 +1279,14 b' def position_to_cursor(text:str, offset:int)->Tuple[int, int]:'
1237 1279 return line, col
1238 1280
1239 1281
1240 def _safe_isinstance(obj, module, class_name):
1282 def _safe_isinstance(obj, module, class_name, *attrs):
1241 1283 """Checks if obj is an instance of module.class_name if loaded
1242 1284 """
1243 return (module in sys.modules and
1244 isinstance(obj, getattr(import_module(module), class_name)))
1285 if module in sys.modules:
1286 m = sys.modules[module]
1287 for attr in [class_name, *attrs]:
1288 m = getattr(m, attr)
1289 return isinstance(obj, m)
1245 1290
1246 1291
1247 1292 @context_matcher()
@@ -1394,6 +1439,37 b' def _make_signature(completion)-> str:'
1394 1439 _CompleteResult = Dict[str, MatcherResult]
1395 1440
1396 1441
1442 DICT_MATCHER_REGEX = re.compile(r"""(?x)
1443 ( # match dict-referring - or any get item object - expression
1444 .+
1445 )
1446 \[ # open bracket
1447 \s* # and optional whitespace
1448 # Capture any number of serializable objects (e.g. "a", "b", 'c')
1449 # and slices
1450 ((?:[uUbB]? # string prefix (r not handled)
1451 (?:
1452 '(?:[^']|(?<!\\)\\')*'
1453 |
1454 "(?:[^"]|(?<!\\)\\")*"
1455 |
1456 # capture integers and slices
1457 (?:[-+]?\d+)?(?::(?:[-+]?\d+)?){0,2}
1458 )
1459 \s*,\s*
1460 )*)
1461 ([uUbB]? # string prefix (r not handled)
1462 (?: # unclosed string
1463 '(?:[^']|(?<!\\)\\')*
1464 |
1465 "(?:[^"]|(?<!\\)\\")*
1466 |
1467 (?:[-+]?\d+)
1468 )
1469 )?
1470 $
1471 """)
1472
1397 1473 def _convert_matcher_v1_result_to_v2(
1398 1474 matches: Sequence[str],
1399 1475 type: str,
@@ -1413,14 +1489,14 b' def _convert_matcher_v1_result_to_v2('
1413 1489 class IPCompleter(Completer):
1414 1490 """Extension of the completer class with IPython-specific features"""
1415 1491
1416 __dict_key_regexps: Optional[Dict[bool,Pattern]] = None
1417
1418 1492 @observe('greedy')
1419 1493 def _greedy_changed(self, change):
1420 1494 """update the splitter and readline delims when greedy is changed"""
1421 1495 if change['new']:
1496 self.evaluation = 'unsafe'
1422 1497 self.splitter.delims = GREEDY_DELIMS
1423 1498 else:
1499 self.evaluation = 'limitted'
1424 1500 self.splitter.delims = DELIMS
1425 1501
1426 1502 dict_keys_only = Bool(
@@ -2149,12 +2225,17 b' class IPCompleter(Completer):'
2149 2225 return method()
2150 2226
2151 2227 # Special case some common in-memory dict-like types
2152 if isinstance(obj, dict) or\
2153 _safe_isinstance(obj, 'pandas', 'DataFrame'):
2228 if (isinstance(obj, dict) or
2229 _safe_isinstance(obj, 'pandas', 'DataFrame')):
2154 2230 try:
2155 2231 return list(obj.keys())
2156 2232 except Exception:
2157 2233 return []
2234 elif _safe_isinstance(obj, 'pandas', 'core', 'indexing', '_LocIndexer'):
2235 try:
2236 return list(obj.obj.keys())
2237 except Exception:
2238 return []
2158 2239 elif _safe_isinstance(obj, 'numpy', 'ndarray') or\
2159 2240 _safe_isinstance(obj, 'numpy', 'void'):
2160 2241 return obj.dtype.names or []
@@ -2175,65 +2256,43 b' class IPCompleter(Completer):'
2175 2256 You can use :meth:`dict_key_matcher` instead.
2176 2257 """
2177 2258
2178 if self.__dict_key_regexps is not None:
2179 regexps = self.__dict_key_regexps
2180 else:
2181 dict_key_re_fmt = r'''(?x)
2182 ( # match dict-referring expression wrt greedy setting
2183 %s
2184 )
2185 \[ # open bracket
2186 \s* # and optional whitespace
2187 # Capture any number of str-like objects (e.g. "a", "b", 'c')
2188 ((?:[uUbB]? # string prefix (r not handled)
2189 (?:
2190 '(?:[^']|(?<!\\)\\')*'
2191 |
2192 "(?:[^"]|(?<!\\)\\")*"
2193 )
2194 \s*,\s*
2195 )*)
2196 ([uUbB]? # string prefix (r not handled)
2197 (?: # unclosed string
2198 '(?:[^']|(?<!\\)\\')*
2199 |
2200 "(?:[^"]|(?<!\\)\\")*
2201 )
2202 )?
2203 $
2204 '''
2205 regexps = self.__dict_key_regexps = {
2206 False: re.compile(dict_key_re_fmt % r'''
2207 # identifiers separated by .
2208 (?!\d)\w+
2209 (?:\.(?!\d)\w+)*
2210 '''),
2211 True: re.compile(dict_key_re_fmt % '''
2212 .+
2213 ''')
2214 }
2259 # Short-circuit on closed dictionary (regular expression would
2260 # not match anyway, but would take quite a while).
2261 if self.text_until_cursor.strip().endswith(']'):
2262 return []
2215 2263
2216 match = regexps[self.greedy].search(self.text_until_cursor)
2264 match = DICT_MATCHER_REGEX.search(self.text_until_cursor)
2217 2265
2218 2266 if match is None:
2219 2267 return []
2220 2268
2221 expr, prefix0, prefix = match.groups()
2222 try:
2223 obj = eval(expr, self.namespace)
2224 except Exception:
2225 try:
2226 obj = eval(expr, self.global_namespace)
2227 except Exception:
2269 expr, prior_tuple_keys, key_prefix = match.groups()
2270
2271 obj = self._evaluate_expr(expr)
2272
2273 if obj is not_found:
2228 2274 return []
2229 2275
2230 2276 keys = self._get_keys(obj)
2231 2277 if not keys:
2232 2278 return keys
2233 2279
2234 extra_prefix = eval(prefix0) if prefix0 != '' else None
2280 tuple_prefix = guarded_eval(
2281 prior_tuple_keys,
2282 EvaluationContext(
2283 globals_=self.global_namespace,
2284 locals_=self.namespace,
2285 evaluation=self.evaluation,
2286 in_subscript=True
2287 )
2288 )
2235 2289
2236 closing_quote, token_offset, matches = match_dict_keys(keys, prefix, self.splitter.delims, extra_prefix=extra_prefix)
2290 closing_quote, token_offset, matches = match_dict_keys(
2291 keys,
2292 key_prefix,
2293 self.splitter.delims,
2294 extra_prefix=tuple_prefix
2295 )
2237 2296 if not matches:
2238 2297 return matches
2239 2298
@@ -2242,7 +2301,7 b' class IPCompleter(Completer):'
2242 2301 # - the start of the key text
2243 2302 # - the start of the completion
2244 2303 text_start = len(self.text_until_cursor) - len(text)
2245 if prefix:
2304 if key_prefix:
2246 2305 key_start = match.start(3)
2247 2306 completion_start = key_start + token_offset
2248 2307 else:
@@ -113,6 +113,17 b' def greedy_completion():'
113 113
114 114
115 115 @contextmanager
116 def evaluation_level(evaluation: str):
117 ip = get_ipython()
118 evaluation_original = ip.Completer.evaluation
119 try:
120 ip.Completer.evaluation = evaluation
121 yield
122 finally:
123 ip.Completer.evaluation = evaluation_original
124
125
126 @contextmanager
116 127 def custom_matchers(matchers):
117 128 ip = get_ipython()
118 129 try:
@@ -852,8 +863,6 b' class TestCompleter(unittest.TestCase):'
852 863 assert match_dict_keys(keys, '"', delims=delims) == ('"', 1, ["foo"])
853 864 assert match_dict_keys(keys, '"f', delims=delims) == ('"', 1, ["foo"])
854 865
855 match_dict_keys
856
857 866 def test_match_dict_keys_tuple(self):
858 867 """
859 868 Test that match_dict_keys called with extra prefix works on a couple of use case,
@@ -883,6 +892,11 b' class TestCompleter(unittest.TestCase):'
883 892 assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3')) == ("'", 1, ["foo4"])
884 893 assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3', 'foo4')) == ("'", 1, [])
885 894
895 keys = [("foo", 1111), ("foo", 2222), (3333, "bar"), (3333, 'test')]
896 assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("foo",)) == ("'", 1, ["1111", "2222"])
897 assert match_dict_keys(keys, "'", delims=delims, extra_prefix=(3333,)) == ("'", 1, ["bar", "test"])
898 assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("3333",)) == ("'", 1, [])
899
886 900 def test_dict_key_completion_string(self):
887 901 """Test dictionary key completion for string keys"""
888 902 ip = get_ipython()
@@ -1050,6 +1064,7 b' class TestCompleter(unittest.TestCase):'
1050 1064
1051 1065 ip.user_ns["C"] = C
1052 1066 ip.user_ns["get"] = lambda: d
1067 ip.user_ns["nested"] = {'x': d}
1053 1068
1054 1069 def assert_no_completion(**kwargs):
1055 1070 _, matches = complete(**kwargs)
@@ -1075,6 +1090,13 b' class TestCompleter(unittest.TestCase):'
1075 1090 assert_completion(line_buffer="(d[")
1076 1091 assert_completion(line_buffer="C.data[")
1077 1092
1093 # nested dict completion
1094 assert_completion(line_buffer="nested['x'][")
1095
1096 with evaluation_level('minimal'):
1097 with pytest.raises(AssertionError):
1098 assert_completion(line_buffer="nested['x'][")
1099
1078 1100 # greedy flag
1079 1101 def assert_completion(**kwargs):
1080 1102 _, matches = complete(**kwargs)
@@ -1162,12 +1184,21 b' class TestCompleter(unittest.TestCase):'
1162 1184 _, matches = complete(line_buffer="d['")
1163 1185 self.assertIn("my_head", matches)
1164 1186 self.assertIn("my_data", matches)
1165 # complete on a nested level
1166 with greedy_completion():
1187 def completes_on_nested():
1167 1188 ip.user_ns["d"] = numpy.zeros(2, dtype=dt)
1168 1189 _, matches = complete(line_buffer="d[1]['my_head']['")
1169 1190 self.assertTrue(any(["my_dt" in m for m in matches]))
1170 1191 self.assertTrue(any(["my_df" in m for m in matches]))
1192 # complete on a nested level
1193 with greedy_completion():
1194 completes_on_nested()
1195
1196 with evaluation_level('limitted'):
1197 completes_on_nested()
1198
1199 with evaluation_level('minimal'):
1200 with pytest.raises(AssertionError):
1201 completes_on_nested()
1171 1202
1172 1203 @dec.skip_without("pandas")
1173 1204 def test_dataframe_key_completion(self):
@@ -1180,6 +1211,17 b' class TestCompleter(unittest.TestCase):'
1180 1211 _, matches = complete(line_buffer="d['")
1181 1212 self.assertIn("hello", matches)
1182 1213 self.assertIn("world", matches)
1214 _, matches = complete(line_buffer="d.loc[:, '")
1215 self.assertIn("hello", matches)
1216 self.assertIn("world", matches)
1217 _, matches = complete(line_buffer="d.loc[1:, '")
1218 self.assertIn("hello", matches)
1219 _, matches = complete(line_buffer="d.loc[1:1, '")
1220 self.assertIn("hello", matches)
1221 _, matches = complete(line_buffer="d.loc[1:1:-1, '")
1222 self.assertIn("hello", matches)
1223 _, matches = complete(line_buffer="d.loc[::, '")
1224 self.assertIn("hello", matches)
1183 1225
1184 1226 def test_dict_key_completion_invalids(self):
1185 1227 """Smoke test cases dict key completion can't handle"""
General Comments 0
You need to be logged in to leave comments. Login now