diff --git a/IPython/core/magics/code.py b/IPython/core/magics/code.py index ae9c45b..1e4a060 100644 --- a/IPython/core/magics/code.py +++ b/IPython/core/magics/code.py @@ -31,6 +31,7 @@ from IPython.utils import py3compat from IPython.utils.contexts import preserve_keys from IPython.utils.path import get_py_filename, unquote_filename from IPython.utils.warn import warn +from IPython.utils.text import get_text_list #----------------------------------------------------------------------------- # Magic implementation classes @@ -48,6 +49,7 @@ range_re = re.compile(r""" (?P\d+)?)? $""", re.VERBOSE) + def extract_code_ranges(ranges_str): """Turn a string of range for %%load into 2-tuples of (start, stop) ready to use as a slice of the content splitted by lines. @@ -80,18 +82,21 @@ def extract_code_ranges(ranges_str): @skip_doctest def extract_symbols(code, symbols): """ - Return a list of code fragments for each symbol parsed from code - For example, suppose code is a string:: + Return a tuple (blocks, not_found) + where ``blocks`` is a list of code fragments + for each symbol parsed from code, and ``not_found`` are + symbols not found in the code. + + For example:: - a = 10 + >>> code = '''a = 10 def b(): return 42 - class A: pass + class A: pass''' >>> extract_symbols(code, 'A,b') - - ["class A: pass", "def b(): return 42"] + (["class A: pass", "def b(): return 42"], []) """ try: py_code = ast.parse(code) @@ -115,12 +120,15 @@ def extract_symbols(code, symbols): # fill a list with chunks of codes for each symbol blocks = [] + not_found = [] for symbol in symbols.split(','): if symbol in symbols_lines: start, end = symbols_lines[symbol] blocks.append('\n'.join(code[start:end]) + '\n') + else: + not_found.append(symbol) - return blocks + return blocks, not_found class InteractivelyDefined(Exception): @@ -289,7 +297,15 @@ class CodeMagics(Magics): contents = self.shell.find_user_code(args) if 's' in opts: - contents = '\n'.join(extract_symbols(contents, opts['s'])) + blocks, not_found = extract_symbols(contents, opts['s']) + if len(not_found) == 1: + warn('The symbol `%s` was not found' % not_found[0]) + elif len(not_found) > 1: + warn('The symbols %s were not found' % get_text_list(not_found, + wrap_item_with='`') + ) + + contents = '\n'.join(blocks) if 'r' in opts: ranges = opts['r'].replace(',', ' ') diff --git a/IPython/core/tests/test_magic.py b/IPython/core/tests/test_magic.py index 71345d6..cb457e9 100644 --- a/IPython/core/tests/test_magic.py +++ b/IPython/core/tests/test_magic.py @@ -64,12 +64,12 @@ def test_extract_code_ranges(): def test_extract_symbols(): source = """import foo\na = 10\ndef b():\n return 42\n\n\nclass A: pass\n\n\n""" symbols_args = ["a", "b", "A", "A,b", "A,a", "z"] - expected = [[], - ["def b():\n return 42\n"], - ["class A: pass\n"], - ["class A: pass\n", "def b():\n return 42\n"], - ["class A: pass\n"], - []] + expected = [([], ['a']), + (["def b():\n return 42\n"], []), + (["class A: pass\n"], []), + (["class A: pass\n", "def b():\n return 42\n"], []), + (["class A: pass\n"], ['a']), + ([], ['z'])] for symbols, exp in zip(symbols_args, expected): nt.assert_equal(code.extract_symbols(source, symbols), exp)