From f24343456214949cad0c6efed04a0aff2a77aaa1 2013-10-09 20:50:51 From: Martín Gaitán Date: 2013-10-09 20:50:51 Subject: [PATCH] return a specific error for non-Python code --- diff --git a/IPython/core/magics/code.py b/IPython/core/magics/code.py index 6d1fce5..76cc522 100644 --- a/IPython/core/magics/code.py +++ b/IPython/core/magics/code.py @@ -30,7 +30,7 @@ from IPython.testing.skipdoctest import skip_doctest from IPython.utils import py3compat from IPython.utils.contexts import preserve_keys from IPython.utils.path import get_py_filename, unquote_filename -from IPython.utils.warn import warn +from IPython.utils.warn import warn, error from IPython.utils.text import get_text_list #----------------------------------------------------------------------------- @@ -99,11 +99,9 @@ def extract_symbols(code, symbols): (["class A: pass", "def b(): return 42"], ['z']) """ symbols = symbols.split(',') - try: - py_code = ast.parse(code) - except SyntaxError: - # ignores non python code - return [], symbols + + # this will raise SyntaxError if code isn't valid Python + py_code = ast.parse(code) marks = [(getattr(s, 'name', None), s.lineno) for s in py_code.body] code = code.split('\n') @@ -303,7 +301,13 @@ class CodeMagics(Magics): contents = self.shell.find_user_code(args) if 's' in opts: - blocks, not_found = extract_symbols(contents, opts['s']) + try: + blocks, not_found = extract_symbols(contents, opts['s']) + except SyntaxError: + # non python code + error("Unable to parse the input as valid Python code") + return + if len(not_found) == 1: warn('The symbol `%s` was not found' % not_found[0]) elif len(not_found) > 1: diff --git a/IPython/core/tests/test_magic.py b/IPython/core/tests/test_magic.py index fbf34a5..77c9808 100644 --- a/IPython/core/tests/test_magic.py +++ b/IPython/core/tests/test_magic.py @@ -74,12 +74,13 @@ def test_extract_symbols(): nt.assert_equal(code.extract_symbols(source, symbols), exp) -def test_extract_symbols_ignores_non_python_code(): +def 
test_extract_symbols_raises_exception_with_non_python_code(): source = ("=begin A Ruby program :)=end\n" "def hello\n" "puts 'Hello world'\n" "end") - nt.assert_equal(code.extract_symbols(source, "hello"), ([], ['hello'])) + with nt.assert_raises(SyntaxError): + code.extract_symbols(source, "hello") def test_rehashx():