diff --git a/contrib/import-checker.py b/contrib/import-checker.py --- a/contrib/import-checker.py +++ b/contrib/import-checker.py @@ -71,15 +71,31 @@ def list_stdlib_modules(): stdlib_modules = set(list_stdlib_modules()) -def imported_modules(source): +def imported_modules(source, ignore_nested=False): """Given the source of a file as a string, yield the names imported by that file. - >>> list(imported_modules( + Args: + source: The python source to examine as a string. + ignore_nested: If true, import statements that do not start in + column zero will be ignored. + + Returns: + A list of module names imported by the given source. + + >>> sorted(imported_modules( ... 'import foo ; from baz import bar; import foo.qux')) - ['foo', 'baz.bar', 'foo.qux'] + ['baz.bar', 'foo', 'foo.qux'] + >>> sorted(imported_modules( + ... '''import foo + ... def wat(): + ... import bar + ... ''', ignore_nested=True)) + ['foo'] """ for node in ast.walk(ast.parse(source)): + if ignore_nested and getattr(node, 'col_offset', 0) > 0: + continue if isinstance(node, ast.Import): for n in node.names: yield n.name @@ -171,7 +187,8 @@ def main(argv): f = open(source_path) modname = dotted_name_of_path(source_path) src = f.read() - used_imports[modname] = sorted(imported_modules(src)) + used_imports[modname] = sorted( + imported_modules(src, ignore_nested=True)) for error in verify_stdlib_on_own_line(src): any_errors = True print source_path, error