diff --git a/IPython/core/completer.py b/IPython/core/completer.py index 2dff9ef..a497f12 100644 --- a/IPython/core/completer.py +++ b/IPython/core/completer.py @@ -190,6 +190,7 @@ import time import unicodedata import uuid import warnings +from ast import literal_eval from contextlib import contextmanager from dataclasses import dataclass from functools import cached_property, partial @@ -212,6 +213,7 @@ from typing import ( Literal, ) +from IPython.core.guarded_eval import guarded_eval, EvaluationContext from IPython.core.error import TryNext from IPython.core.inputtransformer2 import ESC_MAGIC from IPython.core.latex_symbols import latex_symbols, reverse_latex_symbol @@ -296,6 +298,9 @@ MATCHES_LIMIT = 500 # Completion type reported when no type can be inferred. _UNKNOWN_TYPE = "" +# sentinel value to signal lack of a match +not_found = object() + class ProvisionalCompleterWarning(FutureWarning): """ Exception raise by an experimental feature in this module. @@ -902,12 +907,33 @@ class CompletionSplitter(object): class Completer(Configurable): - greedy = Bool(False, - help="""Activate greedy completion - PENDING DEPRECATION. this is now mostly taken care of with Jedi. + greedy = Bool( + False, + help="""Activate greedy completion. + + .. deprecated:: 8.8 + Use :any:`evaluation` instead. + + As of IPython 8.8 proxy for ``evaluation = 'unsafe'`` when set to ``True``, + and for ``'forbidden'`` when set to ``False``. + """, + ).tag(config=True) - This will enable completion on elements of lists, results of function calls, etc., - but can be unsafe because the code is actually evaluated on TAB. + evaluation = Enum( + ('forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'), + default_value='limitted', + help="""Code evaluation under completion. + + Successive options allow to enable more eager evaluation for more accurate completion suggestions, + including for nested dictionaries, nested lists, or even results of function calls. 
def _evaluate_expr(self, expr):
    """Evaluate ``expr`` through ``guarded_eval``, tolerating an invalid prefix.

    The expression is evaluated with the completer's global and local
    namespaces at the evaluation level configured in ``self.evaluation``.
    Whenever evaluation raises, the first character of ``expr`` is dropped
    and evaluation is retried, until it succeeds or nothing is left.

    Parameters
    ----------
    expr : str
        Source text preceding the attribute being completed.

    Returns
    -------
    object
        The evaluated value, or the module-level ``not_found`` sentinel when
        no suffix of ``expr`` could be evaluated.
    """
    # The context depends only on completer state, never on ``expr``, so it
    # is built once instead of on every retry.
    context = EvaluationContext(
        globals_=self.global_namespace,
        locals_=self.namespace,
        evaluation=self.evaluation,
    )
    while expr:
        try:
            return guarded_eval(expr, context)
        except Exception as e:
            # Deliberately broad: any failure (syntax error, guard rejection,
            # error raised by the evaluated code) just means "try a shorter
            # expression".
            if self.debug:
                print('Evaluation exception', e)
            # trim the expression to remove any invalid prefix
            # e.g. user starts `(d[`, so we get `expr = '(d'`,
            # where parenthesis is not closed.
            # TODO: make this faster by reusing parts of the computation?
            expr = expr[1:]
    return not_found
def _safe_isinstance(obj, module, class_name, *attrs):
    """Check whether ``obj`` is an instance of ``module.class_name[.attrs...]``.

    The module is never imported by this function: it is only looked up in
    ``sys.modules``, so the check is cheap and free of import side effects.

    Parameters
    ----------
    obj : object
        Object to test.
    module : str
        Name of the module as it would appear in ``sys.modules``.
    class_name : str
        First attribute to resolve on the module.
    *attrs : str
        Further attributes to resolve, in order, on the previous result.

    Returns
    -------
    bool
        ``False`` when the module is not loaded; otherwise the result of
        ``isinstance`` against the resolved attribute.

    Raises
    ------
    AttributeError
        If the module is loaded but the attribute path does not exist.
    """
    target = sys.modules.get(module)
    if target is None:
        # Not imported (or explicitly blocked with a ``None`` placeholder):
        # ``obj`` cannot be an instance of a class from it.
        return False
    for attr in (class_name, *attrs):
        target = getattr(target, attr)
    return isinstance(obj, target)
def unbind_method(func: Callable) -> Union[Callable, None]:
    """Get unbound method for given bound method.

    The lookup is done by name on the type of ``func.__self__``.

    Returns None if cannot get unbound method, that is when:

    - ``func`` has no ``__self__`` (or it is ``None``) or no ``__name__``,
    - the owner's instance ``__dict__`` contains an entry with that name,
      in which case the class attribute is not what would be called,
    - the owner's class has no attribute with that name.
    """
    owner = getattr(func, '__self__', None)
    name = getattr(func, '__name__', None)
    if owner is None or not name:
        return None

    instance_dict_overrides = getattr(owner, '__dict__', None)
    if instance_dict_overrides and name in instance_dict_overrides:
        return None

    # Default to None: the name of the underlying function does not have to
    # be an attribute of the class (e.g. a lambda bound to an instance).
    return getattr(type(owner), name, None)
module_name not in sys.modules: + return False + member_type = sys.modules[module_name] + for attr in access_path: + member_type = getattr(member_type, attr) + value_type = type(value) + if type(value) == member_type: + return True + if isinstance(value, member_type): + method = getattr(value_type, method_name, None) + member_method = getattr(member_type, method_name, None) + if member_method == method: + return True + except (AttributeError, KeyError): + return False + + +def has_original_dunder( + value, + allowed_types, + allowed_methods, + allowed_external, + method_name +): + # note: Python ignores `__getattr__`/`__getitem__` on instances, + # we only need to check at class level + value_type = type(value) + + # strict type check passes → no need to check method + if value_type in allowed_types: + return True + + method = getattr(value_type, method_name, None) + + if not method: + return None + + if method in allowed_methods: + return True + + for module_name, *access_path in allowed_external: + if has_original_dunder_external(value, module_name, access_path, method_name): + return True + + return False + + +@dataclass +class SelectivePolicy(EvaluationPolicy): + allowed_getitem: Set[HasGetItem] = field(default_factory=set) + allowed_getitem_external: Set[Tuple[str, ...]] = field(default_factory=set) + allowed_getattr: Set[MayHaveGetattr] = field(default_factory=set) + allowed_getattr_external: Set[Tuple[str, ...]] = field(default_factory=set) + + def can_get_attr(self, value, attr): + has_original_attribute = has_original_dunder( + value, + allowed_types=self.allowed_getattr, + allowed_methods=self._getattribute_methods, + allowed_external=self.allowed_getattr_external, + method_name='__getattribute__' + ) + has_original_attr = has_original_dunder( + value, + allowed_types=self.allowed_getattr, + allowed_methods=self._getattr_methods, + allowed_external=self.allowed_getattr_external, + method_name='__getattr__' + ) + # Many objects do not have `__getattr__`, 
this is fine + if has_original_attr is None and has_original_attribute: + return True + + # Accept objects without modifications to `__getattr__` and `__getattribute__` + return has_original_attr and has_original_attribute + + def get_attr(self, value, attr): + if self.can_get_attr(value, attr): + return getattr(value, attr) + + + def can_get_item(self, value, item): + """Allow accessing `__getiitem__` of allow-listed instances unless it was not modified.""" + return has_original_dunder( + value, + allowed_types=self.allowed_getitem, + allowed_methods=self._getitem_methods, + allowed_external=self.allowed_getitem_external, + method_name='__getitem__' + ) + + @cached_property + def _getitem_methods(self) -> Set[Callable]: + return self._safe_get_methods( + self.allowed_getitem, + '__getitem__' + ) + + @cached_property + def _getattr_methods(self) -> Set[Callable]: + return self._safe_get_methods( + self.allowed_getattr, + '__getattr__' + ) + + @cached_property + def _getattribute_methods(self) -> Set[Callable]: + return self._safe_get_methods( + self.allowed_getattr, + '__getattribute__' + ) + + def _safe_get_methods(self, classes, name) -> Set[Callable]: + return { + method + for class_ in classes + for method in [getattr(class_, name, None)] + if method + } + + +class DummyNamedTuple(NamedTuple): + pass + + +class EvaluationContext(NamedTuple): + locals_: dict + globals_: dict + evaluation: Literal['forbidden', 'minimal', 'limitted', 'unsafe', 'dangerous'] = 'forbidden' + in_subscript: bool = False + + +class IdentitySubscript: + def __getitem__(self, key): + return key + +IDENTITY_SUBSCRIPT = IdentitySubscript() +SUBSCRIPT_MARKER = '__SUBSCRIPT_SENTINEL__' + +class GuardRejection(ValueError): + pass + + +def guarded_eval( + code: str, + context: EvaluationContext +): + locals_ = context.locals_ + + if context.evaluation == 'forbidden': + raise GuardRejection('Forbidden mode') + + # note: not using `ast.literal_eval` as it does not implement + # getitem at all, 
# Binary operators understood by `eval_node`, keyed by AST operator class.
_BINARY_OPERATIONS = {
    ast.Add: lambda left, right: left + right,
    ast.Sub: lambda left, right: left - right,
    ast.Mult: lambda left, right: left * right,
    ast.Div: lambda left, right: left / right,
    ast.FloorDiv: lambda left, right: left // right,
    ast.Mod: lambda left, right: left % right,
    ast.Pow: lambda left, right: left ** right,
    ast.LShift: lambda left, right: left << right,
    ast.RShift: lambda left, right: left >> right,
    ast.BitOr: lambda left, right: left | right,
    ast.BitXor: lambda left, right: left ^ right,
    ast.BitAnd: lambda left, right: left & right,
    ast.MatMult: lambda left, right: left @ right,
}

# Unary operators understood by `eval_node`, keyed by AST operator class.
_UNARY_OPERATIONS = {
    ast.USub: lambda value: -value,
    ast.UAdd: lambda value: +value,
    ast.Invert: lambda value: ~value,
    ast.Not: lambda value: not value,
}


def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
    """
    Evaluate AST node in provided context.

    Applies evaluation restrictions defined in the context: the policy is
    looked up in ``EVALUATION_POLICIES`` under ``context.evaluation`` and
    consulted before name lookups, item access, attribute access and calls.

    Calls are evaluated with positional arguments only; a call using keyword
    arguments is rejected with ``GuardRejection``.

    Does not evaluate actions which always have side effects:
    - class definitions (`class sth: ...`)
    - function definitions (`def sth: ...`)
    - variable assignments (`x = 1`)
    - augmented assignments (`x += 1`)
    - deletions (`del x`)

    Does not evaluate operations which do not return values:
    - assertions (`assert x`)
    - pass (`pass`)
    - imports (`import x`)
    - control flow
      - conditionals (`if x:`) except for ternary IfExp (`a if x else b`)
      - loops (`for` and `while`)
      - exception handling

    Returns the value of the node; ``None`` for a ``None`` node (e.g. an
    omitted slice bound).

    Raises ``GuardRejection`` when the policy forbids an operation,
    ``NameError`` for a name missing from the permitted namespaces and
    ``ValueError`` for node or operator types that are not handled.
    """
    policy = EVALUATION_POLICIES[context.evaluation]
    if node is None:
        return None
    if isinstance(node, ast.Expression):
        return eval_node(node.body, context)
    if isinstance(node, ast.BinOp):
        # TODO: add guards
        left = eval_node(node.left, context)
        right = eval_node(node.right, context)
        operation = _BINARY_OPERATIONS.get(type(node.op))
        if operation is not None:
            return operation(left, right)
        # unknown operator: falls through to the final "Unhandled node" error
    if isinstance(node, ast.Constant):
        return node.value
    # NOTE: `ast.Index` and `ast.ExtSlice` (below) are only produced by
    # parsers of Python < 3.9 and are deprecated in newer versions.
    if isinstance(node, ast.Index):
        return eval_node(node.value, context)
    if isinstance(node, ast.Tuple):
        return tuple(eval_node(e, context) for e in node.elts)
    if isinstance(node, ast.List):
        return [eval_node(e, context) for e in node.elts]
    if isinstance(node, ast.Set):
        return {eval_node(e, context) for e in node.elts}
    if isinstance(node, ast.Dict):
        return dict(zip(
            [eval_node(k, context) for k in node.keys],
            [eval_node(v, context) for v in node.values]
        ))
    if isinstance(node, ast.Slice):
        return slice(
            eval_node(node.lower, context),
            eval_node(node.upper, context),
            eval_node(node.step, context)
        )
    if isinstance(node, ast.ExtSlice):
        return tuple([eval_node(dim, context) for dim in node.dims])
    if isinstance(node, ast.UnaryOp):
        # TODO: add guards
        value = eval_node(node.operand, context)
        operation = _UNARY_OPERATIONS.get(type(node.op))
        if operation is not None:
            return operation(value)
        raise ValueError('Unhandled unary operation:', node.op)
    if isinstance(node, ast.Subscript):
        value = eval_node(node.value, context)
        slice_ = eval_node(node.slice, context)
        if policy.can_get_item(value, slice_):
            return value[slice_]
        raise GuardRejection(
            'Subscript access (`__getitem__`) for',
            type(value),  # not joined to avoid calling `repr`
            f' not allowed in {context.evaluation} mode'
        )
    if isinstance(node, ast.Name):
        if policy.allow_locals_access and node.id in context.locals_:
            return context.locals_[node.id]
        if policy.allow_globals_access and node.id in context.globals_:
            return context.globals_[node.id]
        if policy.allow_builtins_access:
            # `__builtins__` is an implementation detail: it is either the
            # `builtins` module or that module's `__dict__`; support both.
            builtins_ns = (
                __builtins__
                if isinstance(__builtins__, dict)
                else vars(__builtins__)
            )
            if node.id in builtins_ns:
                return builtins_ns[node.id]
        if not policy.allow_globals_access and not policy.allow_locals_access:
            raise GuardRejection(
                f'Namespace access not allowed in {context.evaluation} mode'
            )
        else:
            raise NameError(f'{node.id} not found in locals nor globals')
    if isinstance(node, ast.Attribute):
        value = eval_node(node.value, context)
        if policy.can_get_attr(value, node.attr):
            return getattr(value, node.attr)
        raise GuardRejection(
            'Attribute access (`__getattr__`) for',
            type(value),  # not joined to avoid calling `repr`
            f'not allowed in {context.evaluation} mode'
        )
    if isinstance(node, ast.IfExp):
        test = eval_node(node.test, context)
        if test:
            return eval_node(node.body, context)
        else:
            return eval_node(node.orelse, context)
    if isinstance(node, ast.Call):
        func = eval_node(node.func, context)
        # keyword arguments are not supported, such calls are rejected
        if policy.can_call(func) and not node.keywords:
            args = [eval_node(arg, context) for arg in node.args]
            return func(*args)
        raise GuardRejection(
            'Call for',
            func,  # not joined to avoid calling `repr`
            f'not allowed in {context.evaluation} mode'
        )
    raise ValueError('Unhandled node', node)
list_non_mutating_methods), + collections.defaultdict, + *_list_methods(collections.defaultdict, dict_non_mutating_methods), + collections.OrderedDict, + *_list_methods(collections.OrderedDict, dict_non_mutating_methods), + collections.UserDict, + *_list_methods(collections.UserDict, dict_non_mutating_methods), + collections.UserList, + *_list_methods(collections.UserList, list_non_mutating_methods), + collections.UserString, + *_list_methods(collections.UserString, dir(str)), + collections.Counter, + *_list_methods(collections.Counter, dict_non_mutating_methods), + collections.Counter.elements, + collections.Counter.most_common +} + +EVALUATION_POLICIES = { + 'minimal': EvaluationPolicy( + allow_builtins_access=True, + allow_locals_access=False, + allow_globals_access=False, + allow_item_access=False, + allow_attr_access=False, + allowed_calls=set(), + allow_any_calls=False + ), + 'limitted': SelectivePolicy( + # TODO: + # - should reject binary and unary operations if custom methods would be dispatched + allowed_getitem=BUILTIN_GETITEM, + allowed_getitem_external=SUPPORTED_EXTERNAL_GETITEM, + allowed_getattr={ + *BUILTIN_GETITEM, + set, + frozenset, + object, + type, # `type` handles a lot of generic cases, e.g. numbers as in `int.real`. 
+ dict_keys, + method_descriptor + }, + allowed_getattr_external={ + # pandas Series/Frame implements custom `__getattr__` + ('pandas', 'DataFrame'), + ('pandas', 'Series') + }, + allow_builtins_access=True, + allow_locals_access=True, + allow_globals_access=True, + allowed_calls=ALLOWED_CALLS + ), + 'unsafe': EvaluationPolicy( + allow_builtins_access=True, + allow_locals_access=True, + allow_globals_access=True, + allow_attr_access=True, + allow_item_access=True, + allow_any_calls=True + ) +} \ No newline at end of file diff --git a/IPython/core/tests/test_completer.py b/IPython/core/tests/test_completer.py index 98ec814..7a99a26 100644 --- a/IPython/core/tests/test_completer.py +++ b/IPython/core/tests/test_completer.py @@ -113,6 +113,17 @@ def greedy_completion(): @contextmanager +def evaluation_level(evaluation: str): + ip = get_ipython() + evaluation_original = ip.Completer.evaluation + try: + ip.Completer.evaluation = evaluation + yield + finally: + ip.Completer.evaluation = evaluation_original + + +@contextmanager def custom_matchers(matchers): ip = get_ipython() try: @@ -522,10 +533,10 @@ class TestCompleter(unittest.TestCase): def test_greedy_completions(self): """ - Test the capability of the Greedy completer. + Test the capability of the Greedy completer. Most of the test here does not really show off the greedy completer, for proof - each of the text below now pass with Jedi. The greedy completer is capable of more. + each of the text below now pass with Jedi. The greedy completer is capable of more. See the :any:`test_dict_key_completion_contexts` @@ -852,15 +863,13 @@ class TestCompleter(unittest.TestCase): assert match_dict_keys(keys, '"', delims=delims) == ('"', 1, ["foo"]) assert match_dict_keys(keys, '"f', delims=delims) == ('"', 1, ["foo"]) - match_dict_keys - def test_match_dict_keys_tuple(self): """ Test that match_dict_keys called with extra prefix works on a couple of use case, does return what expected, and does not crash. 
""" delims = " \t\n`!@#$^&*()=+[{]}\\|;:'\",<>?" - + keys = [("foo", "bar"), ("foo", "oof"), ("foo", b"bar"), ('other', 'test')] # Completion on first key == "foo" @@ -883,6 +892,11 @@ class TestCompleter(unittest.TestCase): assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3')) == ("'", 1, ["foo4"]) assert match_dict_keys(keys, "'foo", delims=delims, extra_prefix=('foo1', 'foo2', 'foo3', 'foo4')) == ("'", 1, []) + keys = [("foo", 1111), ("foo", 2222), (3333, "bar"), (3333, 'test')] + assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("foo",)) == ("'", 1, ["1111", "2222"]) + assert match_dict_keys(keys, "'", delims=delims, extra_prefix=(3333,)) == ("'", 1, ["bar", "test"]) + assert match_dict_keys(keys, "'", delims=delims, extra_prefix=("3333",)) == ("'", 1, []) + def test_dict_key_completion_string(self): """Test dictionary key completion for string keys""" ip = get_ipython() @@ -1050,6 +1064,7 @@ class TestCompleter(unittest.TestCase): ip.user_ns["C"] = C ip.user_ns["get"] = lambda: d + ip.user_ns["nested"] = {'x': d} def assert_no_completion(**kwargs): _, matches = complete(**kwargs) @@ -1075,6 +1090,13 @@ class TestCompleter(unittest.TestCase): assert_completion(line_buffer="(d[") assert_completion(line_buffer="C.data[") + # nested dict completion + assert_completion(line_buffer="nested['x'][") + + with evaluation_level('minimal'): + with pytest.raises(AssertionError): + assert_completion(line_buffer="nested['x'][") + # greedy flag def assert_completion(**kwargs): _, matches = complete(**kwargs) @@ -1162,12 +1184,21 @@ class TestCompleter(unittest.TestCase): _, matches = complete(line_buffer="d['") self.assertIn("my_head", matches) self.assertIn("my_data", matches) - # complete on a nested level - with greedy_completion(): + def completes_on_nested(): ip.user_ns["d"] = numpy.zeros(2, dtype=dt) _, matches = complete(line_buffer="d[1]['my_head']['") self.assertTrue(any(["my_dt" in m for m in matches])) 
self.assertTrue(any(["my_df" in m for m in matches])) + # complete on a nested level + with greedy_completion(): + completes_on_nested() + + with evaluation_level('limitted'): + completes_on_nested() + + with evaluation_level('minimal'): + with pytest.raises(AssertionError): + completes_on_nested() @dec.skip_without("pandas") def test_dataframe_key_completion(self): @@ -1180,6 +1211,17 @@ class TestCompleter(unittest.TestCase): _, matches = complete(line_buffer="d['") self.assertIn("hello", matches) self.assertIn("world", matches) + _, matches = complete(line_buffer="d.loc[:, '") + self.assertIn("hello", matches) + self.assertIn("world", matches) + _, matches = complete(line_buffer="d.loc[1:, '") + self.assertIn("hello", matches) + _, matches = complete(line_buffer="d.loc[1:1, '") + self.assertIn("hello", matches) + _, matches = complete(line_buffer="d.loc[1:1:-1, '") + self.assertIn("hello", matches) + _, matches = complete(line_buffer="d.loc[::, '") + self.assertIn("hello", matches) def test_dict_key_completion_invalids(self): """Smoke test cases dict key completion can't handle""" diff --git a/IPython/core/tests/test_guarded_eval.py b/IPython/core/tests/test_guarded_eval.py new file mode 100644 index 0000000..5c89a68 --- /dev/null +++ b/IPython/core/tests/test_guarded_eval.py @@ -0,0 +1,286 @@ +from typing import NamedTuple +from IPython.core.guarded_eval import EvaluationContext, GuardRejection, guarded_eval, unbind_method +from IPython.testing import decorators as dec +import pytest + + +def limitted(**kwargs): + return EvaluationContext( + locals_=kwargs, + globals_={}, + evaluation='limitted' + ) + + +def unsafe(**kwargs): + return EvaluationContext( + locals_=kwargs, + globals_={}, + evaluation='unsafe' + ) + +@dec.skip_without('pandas') +def test_pandas_series_iloc(): + import pandas as pd + series = pd.Series([1], index=['a']) + context = limitted(data=series) + assert guarded_eval('data.iloc[0]', context) == 1 + + +@dec.skip_without('pandas') +def 
test_pandas_series(): + import pandas as pd + context = limitted(data=pd.Series([1], index=['a'])) + assert guarded_eval('data["a"]', context) == 1 + with pytest.raises(KeyError): + guarded_eval('data["c"]', context) + + +@dec.skip_without('pandas') +def test_pandas_bad_series(): + import pandas as pd + class BadItemSeries(pd.Series): + def __getitem__(self, key): + return 'CUSTOM_ITEM' + + class BadAttrSeries(pd.Series): + def __getattr__(self, key): + return 'CUSTOM_ATTR' + + bad_series = BadItemSeries([1], index=['a']) + context = limitted(data=bad_series) + + with pytest.raises(GuardRejection): + guarded_eval('data["a"]', context) + with pytest.raises(GuardRejection): + guarded_eval('data["c"]', context) + + # note: here result is a bit unexpected because + # pandas `__getattr__` calls `__getitem__`; + # FIXME - special case to handle it? + assert guarded_eval('data.a', context) == 'CUSTOM_ITEM' + + context = unsafe(data=bad_series) + assert guarded_eval('data["a"]', context) == 'CUSTOM_ITEM' + + bad_attr_series = BadAttrSeries([1], index=['a']) + context = limitted(data=bad_attr_series) + assert guarded_eval('data["a"]', context) == 1 + with pytest.raises(GuardRejection): + guarded_eval('data.a', context) + + +@dec.skip_without('pandas') +def test_pandas_dataframe_loc(): + import pandas as pd + from pandas.testing import assert_series_equal + data = pd.DataFrame([{'a': 1}]) + context = limitted(data=data) + assert_series_equal( + guarded_eval('data.loc[:, "a"]', context), + data['a'] + ) + + +def test_named_tuple(): + + class GoodNamedTuple(NamedTuple): + a: str + pass + + class BadNamedTuple(NamedTuple): + a: str + def __getitem__(self, key): + return None + + good = GoodNamedTuple(a='x') + bad = BadNamedTuple(a='x') + + context = limitted(data=good) + assert guarded_eval('data[0]', context) == 'x' + + context = limitted(data=bad) + with pytest.raises(GuardRejection): + guarded_eval('data[0]', context) + + +def test_dict(): + context = limitted( + data={'a': 
1, 'b': {'x': 2}, ('x', 'y'): 3} + ) + assert guarded_eval('data["a"]', context) == 1 + assert guarded_eval('data["b"]', context) == {'x': 2} + assert guarded_eval('data["b"]["x"]', context) == 2 + assert guarded_eval('data["x", "y"]', context) == 3 + + assert guarded_eval('data.keys', context) + + +def test_set(): + context = limitted(data={'a', 'b'}) + assert guarded_eval('data.difference', context) + + +def test_list(): + context = limitted(data=[1, 2, 3]) + assert guarded_eval('data[1]', context) == 2 + assert guarded_eval('data.copy', context) + + +def test_dict_literal(): + context = limitted() + assert guarded_eval('{}', context) == {} + assert guarded_eval('{"a": 1}', context) == {"a": 1} + + +def test_list_literal(): + context = limitted() + assert guarded_eval('[]', context) == [] + assert guarded_eval('[1, "a"]', context) == [1, "a"] + + +def test_set_literal(): + context = limitted() + assert guarded_eval('set()', context) == set() + assert guarded_eval('{"a"}', context) == {"a"} + + +def test_if_expression(): + context = limitted() + assert guarded_eval('2 if True else 3', context) == 2 + assert guarded_eval('4 if False else 5', context) == 5 + + +def test_object(): + obj = object() + context = limitted(obj=obj) + assert guarded_eval('obj.__dir__', context) == obj.__dir__ + + +@pytest.mark.parametrize( + "code,expected", + [ + [ + 'int.numerator', + int.numerator + ], + [ + 'float.is_integer', + float.is_integer + ], + [ + 'complex.real', + complex.real + ] + ] +) +def test_number_attributes(code, expected): + assert guarded_eval(code, limitted()) == expected + + +def test_method_descriptor(): + context = limitted() + assert guarded_eval('list.copy.__name__', context) == 'copy' + + +@pytest.mark.parametrize( + "data,good,bad,expected", + [ + [ + [1, 2, 3], + 'data.index(2)', + 'data.append(4)', + 1 + ], + [ + {'a': 1}, + 'data.keys().isdisjoint({})', + 'data.update()', + True + ] + ] +) +def test_calls(data, good, bad, expected): + context = 
def test_assumption_instance_attr_do_not_matter():
    """Check that dunder methods set on an instance are ignored.

    This is semi-specified in Python documentation.

    However, since the specification says 'not guaranteed
    to work' rather than 'is forbidden to work', future
    versions could invalidate this assumption. This test
    is meant to catch such a change if it ever comes true.
    """
    class T:
        def __getitem__(self, k):
            return 'a'

        def __getattr__(self, k):
            return 'a'

    t = T()
    # instance-level overrides which must NOT be picked up by the interpreter
    t.__getitem__ = lambda f: 'b'
    t.__getattr__ = lambda f: 'b'
    # subscription still dispatches to the class-level `__getitem__`
    assert t[1] == 'a'
    # missing attribute lookup still dispatches to the class-level `__getattr__`
    assert t.a == 'a'