#!/usr/bin/python3
# -*- coding: utf-8 -*-   vim: set fileencoding=utf-8 :

"""Find cyclic Python imports."""

# To do:
# - Produce dot files (and instructions for running dot)

import argparse
import ast
from collections import defaultdict, namedtuple
from functools import total_ordering
import logging
import os
import re
import sys
from typing import (
    Dict,
    List,
    ItemsView,
    Iterable,
    Iterator,
    NoReturn,
    Optional,
    Set,
    Tuple,
)


# Top-level packages that are known not to be part of the analysed code
# base. Imports from these are ignored (they cannot take part in a cycle);
# imports from anything else that is not 'pencil' trigger a warning.
KNOWN_EXTERNAL_MODULES = [
    "__future__",
    "astropy",
    "atexit",
    "code",
    "collections",
    "copy",
    "Cython",
    "datetime",
    "dill",
    "distutils",
    "doctest",
    "eqtools",
    "errno",
    "fileinput",
    "glob",
    "h5py",
    "idlpy",
    "importlib",
    "inspect",
    "itertools",
    "lic_internal",
    "math",
    "matplotlib",
    "mpi4py",
    "multiprocessing",
    "numpy",
    "os",
    "pen",
    "pexpect",
    "pickle",
    "plotly",
    "pylab",
    "re",
    "retrying",
    "scipy",
    "shutil",
    "socket",
    "string",
    "struct",
    "subprocess",
    "sys",
    "tempfile",
    "_thread",
    "thread",
    "time",
    "tqdm",
    "unittest",
    "vtk",
    "warnings",
    "weakref",
]


def main() -> None:
    """Parse the command line, build the import graph and report cycles.

    Exits with status 1 if at least one import cycle was found (or if no
    input files were given); otherwise prints a success message.
    """
    args = parse_arguments()
    if args.recursive:
        python_files = list(find_python_files(args.path))
    else:
        python_files = args.path
    if not python_files:
        # 'path' is declared with nargs="*", so it may legitimately be
        # empty; without this guard python_files[0] raises IndexError.
        die("No Python files to analyze; see --help for usage.")

    graph = DependencyDict()
    pencil_python_directory = os.path.join(
        find_pencil_home(python_files[0]), "python"
    )
    for python_file in python_files:
        deps = parse_imports(python_file, pencil_python_directory)
        if deps:
            graph[Module.from_file(python_file)] = deps
    logging.debug(graph)

    progress = ProgressIndicator(downsample_by=10000)
    # A cycle of length N will appear N times
    all_cycles = [
        cycle for node in graph for cycle in dfs(graph, node, node, progress)
    ]
    progress.finish()
    # … so we deduplicate:
    cycles = {c.canonicalize() for c in all_cycles}

    if cycles:
        print_cycles(cycles)
        print("Found {} import cycles".format(len(cycles)))
        sys.exit(1)
    else:
        print("No import cycles found")


def print_cycles(cycles: Iterable["Cycle"]) -> None:
    """Print the given cycles to stdout, grouped by ascending length."""
    grouped_cycles = _group_by_length(cycles)
    for count in sorted(grouped_cycles.keys()):
        print("Cycles of length {}:".format(count))
        for cycle in grouped_cycles[count]:
            print("  " + str(cycle))


def _group_by_length(
    cycles: Iterable["Cycle"]
) -> Dict[int, List["Cycle"]]:
    """Return a mapping {cycle length => list of cycles of that length}."""
    grouped_cycles = defaultdict(list)  # type: Dict[int, List[Cycle]]
    for cycle in cycles:
        grouped_cycles[len(cycle)].append(cycle)
    return grouped_cycles


def parse_imports(
    filename: str, pencil_python_directory: str
) -> Set["Module"]:
    """Return the set of Modules imported by the given Python file.

    Arguments:
        filename:
            The Python source file to parse.
        pencil_python_directory:
            Directory from which absolute 'pencil...' imports are
            resolved.
    """
    directory, _ = os.path.split(os.path.realpath(filename))
    dependencies = set()  # type: Set[Module]
    # Read bytes, not text: ast.parse() then determines the source
    # encoding itself (coding cookie / BOM, default UTF-8) instead of us
    # decoding with whatever the locale encoding happens to be.
    with open(filename, "rb") as source:
        syntax_tree = ast.parse(source.read(), filename)
    imports, imports_from = extract_imports(syntax_tree)
    for i_node in imports:
        dependencies.update(
            Module.from_import_node(
                i_node, directory, filename, pencil_python_directory
            )
        )
    for if_node in imports_from:
        dependencies.update(
            Module.from_import_from_node(
                if_node, directory, filename, pencil_python_directory
            )
        )
    return dependencies


def extract_imports(
    syntax_tree: ast.Module
) -> Tuple[Iterator[ast.Import], Iterator[ast.ImportFrom]]:
    """Extract all Import and ImportFrom nodes from a Python AST.

    This walks the syntax tree recursively, so in addition to imports at
    the global module level, we also get imports inside functions or
    conditional blocks.

    """
    analyzer = ImportExtractor()
    analyzer.visit(syntax_tree)
    return analyzer.get_imports(), analyzer.get_imports_from()


Line = namedtuple("Line", ["file", "number", "text"])


class ImportExtractor(ast.NodeVisitor):
    """An AST NodeVisitor that records all import/from...import entries."""

    def __init__(self) -> None:
        # Lists (rather than sets) keep the nodes in visiting order, so
        # log output is reproducible between runs. Every node is visited
        # exactly once, so there are no duplicates to remove.
        self.imports = []  # type: List[ast.Import]
        self.imports_from = []  # type: List[ast.ImportFrom]

    def visit_Import(self, node: ast.Import) -> None:
        self.imports.append(node)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        self.imports_from.append(node)
        self.generic_visit(node)

    def get_imports(self) -> Iterator[ast.Import]:
        """Iterate over the recorded 'import ...' nodes."""
        yield from self.imports

    def get_imports_from(self) -> Iterator[ast.ImportFrom]:
        """Iterate over the recorded 'from ... import ...' nodes."""
        yield from self.imports_from


class Package:
    """A Python package.

    E.g. pencil.math.derivatives

    A package is always represented by a __init__.py file in the package
    directory.

    """

    def __init__(self, directory: str):
        """Raise PackageDirectoryDoesNotExist if directory is no directory."""
        self.directory = os.path.realpath(directory)
        if not os.path.isdir(self.directory):
            raise PackageDirectoryDoesNotExist(self.directory)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Package):
            return self.directory == other.directory
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self.directory)

    def __repr__(self) -> str:
        return "Package {}".format(self.directory)


@total_ordering
class Module:
    """A Python module (essentially a .py file).

    For a package, point to the __init__.py file.

    Modules compare equal (and hash) by their file path, and are ordered
    by their pretty-printed path (see __str__).

    """

    def __init__(self, package: Package, name: str) -> None:
        self.package = package
        self.name = name
        self.file = self.find_module_file(package, name)

    @staticmethod
    def find_module_file(package: Package, module_name: str) -> str:
        """Return the path of <module_name>.py or <module_name>/__init__.py.

        The file is looked up in the package's directory; the .py file is
        preferred. Raises Exception if neither exists.
        """
        directory = package.directory
        for cand in [
            module_name + ".py",
            os.path.join(module_name, "__init__.py"),
        ]:
            candidate = os.path.join(directory, cand)
            if os.path.exists(candidate):
                return candidate
        raise Exception(
            "No module file found for package {}, module {}".format(
                package, module_name
            )
        )

    @staticmethod
    def from_file(filename: str) -> "Module":
        """Create the Module represented by the given .py file.

        Terminates the program if the file name has no .py suffix.
        """
        if not filename.endswith(".py"):
            die("File name {} has no .py suffix".format(filename))
        package_dir, basename = os.path.split(
            re.sub(r"\.py$", "", filename)
        )
        return Module(Package(package_dir), basename)

    @staticmethod
    def from_import_node(
        node: ast.Import,
        directory: str,
        filename: str,
        pencil_python_directory: str,
    ) -> Iterator["Module"]:
        """Extract Module's from an "import ..." line.

        There are two cases:
        1. import pencil.module
        2. import some.external.module

        In the first case, we search for the imported file(s) under
        ${PENCIL_HOME}/python/[pencil/] .

        In the second case, we ignore the imports, as there cannot be
        cycles involving external packages. Before this, however, we
        verify that the top package from which we import is a known
        external package.

        """
        logging.debug(
            "from_import_node({}, {}, {})".format(
                _format_import_node(node),
                directory,
                pencil_python_directory,
            )
        )
        for alias in node.names:
            name = alias.name
            start = name.split(".")[0]
            if start == "pencil":
                # Split 'pencil.a.b' into package path 'pencil.a' and
                # module 'b'; a dotted name passed as a single importee
                # would be looked up as the file 'pencil.a.b.py'.
                # For a plain 'import pencil' the package path is empty.
                package_path, _, importee = name.rpartition(".")
                yield Module._from_import(
                    package_path or None, importee, pencil_python_directory
                )
            elif start in KNOWN_EXTERNAL_MODULES:
                # Absolute import of some non-pencil module
                logging.info(
                    "Non-interesting import line 'import {}'".format(name)
                )
            else:
                logging.warning(
                    "Unknown global import 'import {}'".format(name)
                )
                logging.warning(
                    "Consider adding {} to KNOWN_EXTERNAL_MODULES.".format(
                        start
                    )
                )

    @staticmethod
    def from_import_from_node(
        node: ast.ImportFrom,
        directory: str,
        filename: str,
        pencil_python_directory: str,
    ) -> Iterator["Module"]:
        """Extract Module's from a "from ... import ..." line.

        There are three cases to distinguish:
        1. from pencil.module import ...
        2. from ..relative.package.path import ...
        2a. from .. import ...
        3. from some.external.module import ...

        In the first case, we search for the imported file(s) under
        ${PENCIL_HOME}/python/[pencil/] .

        In the second case, we search for the imported file(s) relative
        from the directory of the current file.

        In the third case, there cannot be cycles.

        """
        logging.debug(
            "from_import_from_node({}, {}, {}, {})".format(
                _format_from_import_node(node),
                directory,
                pencil_python_directory,
                filename,
            )
        )
        for alias in node.names:
            name = alias.name
            import_from = node.module
            if import_from is None:
                # An import of the form
                # 'from . import this' or 'from .. import that'
                if node.level < 1:
                    die(
                        "Unexpected: no module, no level"
                        " in {}, {}:{}".format(
                            node, filename, node.lineno
                        )
                    )
                yield Module._from_import(
                    import_from, name, directory, node.level
                )
            else:
                if node.level == 0:
                    yield from Module._from_absolute_import(
                        import_from, name, pencil_python_directory
                    )
                else:
                    # Relative import of the form
                    # 'from .here import ...' or 'from ..here import ...'
                    yield Module._from_import(
                        import_from, name, directory, node.level
                    )

    @staticmethod
    def _from_absolute_import(
        import_from: str, name: str, pencil_python_directory: str
    ) -> Iterator["Module"]:
        """Absolute import of the form
        'from os.path import isdir' or 'from pencil.io import snapshot'

        Yields a Module only for imports from the 'pencil' package;
        other imports are just logged.
        """
        start = import_from.split(".")[0]
        if start == "pencil":
            # 'from pencil.module import ...'
            yield Module._from_import(
                import_from, name, pencil_python_directory
            )
        elif start in KNOWN_EXTERNAL_MODULES:
            # 'from some.external.module import ...'
            logging.info(
                "Non-interesting import 'from {} import {}'".format(
                    import_from, name
                )
            )
        else:
            logging.warning(
                "Unknown global import 'from {} import {}'".format(
                    import_from, name
                )
            )
            logging.warning(
                "Consider adding {} to KNOWN_EXTERNAL_MODULES.".format(
                    start
                )
            )

    @staticmethod
    def _from_import(
        import_from: Optional[str],
        importee: str,
        start_directory: str,
        level: int = 0,
    ) -> "Module":
        """Find a Module based on information from an import line.

        The Module represents either a 'module.py' file or a
        'module/__init__.py' file.
        Finding it is complicated by the fact that the import line can
        describe an import of a module from a package, or the import of a
        symbol from a module.

        Examples of import lines are
        'import B', 'from A import B',
        'from .A import B', 'from ..A import B',
        'from . import B', 'from .. import B'.

        Arguments:
            import_from :
                The package name given between the 'from' and the 'import'
                keywords, without leading dots. In the examples above, "A".
            importee :
                The imported module or symbol. In the examples above, "B".
            start_directory :
                Directory from where to search for the module file.
            level :
                The number of leading dots prefixing the package name.
                In the examples above, the levels are (0, 0, 1, 2, 1, 2),
                in order.

        Raises Exception if no matching file can be found.

        """
        logging.debug(
            "_from_import({}, {}, {}, {})".format(
                import_from, importee, start_directory, level
            )
        )
        if level > 0:
            # One dot is the current directory; each further dot goes up
            # one directory.
            subdir = os.path.join(".", *([".."] * (level - 1)))
        else:
            subdir = "."
        if import_from is not None:
            subdir = os.path.join(subdir, *import_from.split("."))
        logging.debug(
            "_from_import: expanded {} → {}".format(import_from, subdir)
        )
        directory = os.path.realpath(
            os.path.join(start_directory, subdir)
        )
        parent_dir, last_part = os.path.split(directory)
        # First candidate: importee is a module in the package 'directory'.
        # Second candidate: importee is a symbol, and the last component
        # of the 'from' path is the module that defines it.
        candidates = [(directory, importee), (parent_dir, last_part)]
        logging.debug(
            "Module._from_import(): candidates = {}".format(candidates)
        )
        for package_dir, module_name in candidates:
            candidate_file = os.path.join(
                package_dir, module_name + ".py"
            )
            init_file = os.path.join(
                package_dir, module_name, "__init__.py"
            )
            # NB: In Python, a symbol provided by __init__.py wins over a
            # module.py file. So for example 'from pencil.sim import
            # is_sim_dir' will import the function 'is_sim_dir' defined in
            # pencil/sim/__init__.py, not the module is_sim_dir
            # represented by the file pencil/sim/is_sim_dir.py .
            #
            # Here, however, we turn this logic upside down.
            # Reason: To see what 'from pencil.sim import is_sim_dir'
            # resolves to, we need to know the symbols loaded in
            # pencil/sim/__init__.py -- but I don't think we can without
            # executing the code.
            # If we prefer pencil/sim/is_sim_dir.py if it exists, we might
            # get false positives from our cycle check, which is better
            # than false negatives.
            # And, yes, we should stop defining 'function_name' in a module
            # 'function_name.py'.
            if os.path.exists(candidate_file):
                logging.debug("Found: {}".format(candidate_file))
                return Module(Package(package_dir), module_name)
            elif os.path.exists(init_file):
                logging.debug("Found: {}".format(init_file))
                return Module(Package(package_dir), module_name)
        raise Exception(
            "Cannot resolve 'from {} import {}'".format(
                import_from, importee
            )
        )

    def __lt__(self, other: "Module") -> bool:
        return str(self) < str(other)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Module):
            return self.file == other.file
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self.file)

    def __repr__(self) -> str:
        return "Module {} in {}/".format(self.name, self.package)

    def __str__(self) -> str:
        """Pretty-print module path."""
        path = re.sub(r".*/python/pencil/", "pencil/", self.file)
        path = re.sub(r"/__init__\.py$", "", path)
        return path


class Cycle:
    """A sequence of Modules, each importing the next, last importing first.

    Cycles compare equal (and hash) by their string representation, so
    two rotations of the same cycle are only equal after canonicalize().
    """

    def __init__(self, module_sequence: List[Module]) -> None:
        """Arguments:
            module_sequence:
                The sequence of modules that forms a cycle. E.g. the
                cycle A → B → C → A would be represented as [A, B, C].

        Raises Exception if the first module is repeated as last entry.
        """
        self.modules = module_sequence
        if len(self.modules) > 1 and self.modules[-1] == self.modules[0]:
            raise Exception(
                "Cycle() arguments must not repeat first Module as last entry"
            )

    def canonicalize(self) -> "Cycle":
        """Return the rotation of this cycle that starts with the smallest
        module."""
        # Find index of "smallest" module:
        start_idx, _ = min(enumerate(self.modules), key=snd)
        return Cycle(self.modules[start_idx:] + self.modules[:start_idx])

    def __len__(self) -> int:
        return len(self.modules)

    def __iter__(self) -> Iterator[Module]:
        return iter(self.modules)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Cycle):
            # Assume that __str()__ encapsulates everything relevant for
            # comparison
            return str(self) == str(other)
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(str(self))

    def __str__(self) -> str:
        return "[" + ", ".join([str(m) for m in self.modules]) + "]"


class DependencyDict:
    """A Dict[Module, Set[Module]] that returns set() for missing entries.

    We cannot use defaultdict, as that leads to
      RuntimeError: dictionary changed size during iteration

    Note that assigning to an existing key adds to the stored set rather
    than replacing it.
    """

    def __init__(self) -> None:
        self.dict = dict()  # type: Dict[Module, Set[Module]]

    def __setitem__(self, key: Module, item: Set[Module]) -> None:
        if key not in self.dict:
            self.dict[key] = set()
        self.dict[key].update(item)

    def __getitem__(self, key: Module) -> Set[Module]:
        if key in self.dict:
            return self.dict[key]
        else:
            # Deliberately not stored: lookups must not change the dict.
            return set()

    def __iter__(self) -> Iterator[Module]:
        return iter(self.dict)

    def __repr__(self) -> str:
        """List all dependencies as 'module → [imported modules]' lines."""
        lines = [
            "{} → [{}]".format(
                module,
                ", ".join(sorted(str(dep) for dep in self.dict[module])),
            )
            for module in sorted(self.dict, key=str)
        ]
        return "DependencyDict{{{}}}".format("; ".join(lines))

    def items(self) -> ItemsView[Module, Set[Module]]:
        return self.dict.items()


class ProgressIndicator:
    """A simple progress indicator for unspecified total work."""

    def __init__(self, downsample_by: int, line_length: int = 80) -> None:
        """
        Arguments:
            downsample_by :
                Only print a dot every N events. Values below 1 are
                treated as 1.
            line_length :
                Print (up to) this many dots per line.
        """
        self.sample_by = max(1, downsample_by)
        self.line_length = line_length
        # Count-down until the next visible marker.
        self.count = self.sample_by
        self.column = 0
        self.line = 0
        # Characters still to be printed (in place of dots) as line label.
        self.line_indicator = "{}:".format(self.line)

    def tick(self) -> None:
        """Increase update count (not necessarily visible)."""
        self.count -= 1
        if self.count < 1:
            self.display_tick()
            self.count += self.sample_by

    def display_tick(self) -> None:
        """Show update."""
        print(self.get_marker_char(), end="")
        self.column += 1
        if self.column >= self.line_length:
            print()
            self.column = 0
            self.line += 1
            self.line_indicator = "{}:".format(self.line)
        else:
            sys.stdout.flush()

    def get_marker_char(self) -> str:
        """Return the next character to print.

        This is the next character of the line label while that lasts,
        and a dot afterwards.
        """
        if self.line_indicator:
            marker, self.line_indicator = (
                self.line_indicator[0],
                self.line_indicator[1:],
            )
            return marker
        else:
            return "."

    def finish(self) -> None:
        """Clean up"""
        print()


def expand_dot_notation(name: str) -> str:
    """Expand notation like ..parent.sub to ./../parent/sub ."""
    parts = []
    # The first leading dot points to the current directory
    if name.startswith("."):
        parts.append(".")
        name = name[1:]
    # Every further leading dot goes up one directory
    while name.startswith("."):
        parts.append(os.path.pardir)
        name = name[1:]
    parts.extend(name.split("."))
    return os.path.join(*parts)


def find_python_files(root_dirs: Iterable[str]) -> Iterator[str]:
    """Recursively list all .py files under the given root directories.

    Terminates the program if one of the roots is not a directory.
    """
    for root_dir in root_dirs:
        if not os.path.isdir(root_dir):
            die("Not a directory: {}".format(root_dir))
        for root, _, file_names in os.walk(root_dir):
            for file_name in file_names:
                if file_name.endswith(".py"):
                    yield os.path.join(root, file_name)


def find_pencil_home(python_file: str) -> str:
    """Find $PENCIL_HOME by searching upwards from the given file.

    Returns the closest ancestor directory of python_file that looks like
    PENCIL_HOME. Terminates the program if there is no such directory.
    """
    path = os.path.realpath(python_file)
    while True:
        parent = os.path.dirname(path)
        if parent == path:
            # We have reached the file system root: os.path.dirname() (like
            # os.path.split()) maps the root to itself, so without this
            # check the loop would never terminate.
            break
        path = parent
        if _looks_like_pencil_home(path):
            return path
    die(
        "Could not find PENCIL_HOME searching upwards from {}".format(
            python_file
        )
    )


def _looks_like_pencil_home(path: str) -> bool:
    """Does the given path contain some files typical for PENCIL_HOME?"""
    subpaths = [
        "python/pencil",
        "samples",
        "license/GNU_public_license.txt",
    ]
    return all(os.path.exists(os.path.join(path, s)) for s in subpaths)


def dfs(
    graph: DependencyDict,
    start: Module,
    end: Module,
    progress: ProgressIndicator,
) -> Iterator[Cycle]:
    """Depth-first search algorithm for finding cycles in a graph.

    [From: https://stackoverflow.com/a/40834276]

    Arguments:
        graph :
            A dict {module => [included modules]}
        start :
            The start node for searching
        end :
            The end node for searching
        progress :
            Progress indicator that gets ticked once per visited state.

    Yields one Cycle for every path found from start to end.
    """
    fringe = [(start, [])]  # type: List[Tuple[Module, List[Module]]]
    while fringe:
        progress.tick()
        state, path = fringe.pop()
        if path and state == end:
            yield Cycle(path)
            continue
        for next_state in graph[state]:
            if next_state in path:
                # Don't walk through the same module twice
                continue
            fringe.append((next_state, path + [next_state]))


def snd(tup: Tuple[int, Module]) -> Module:
    """A type-specialized version of Haskell's 'snd' function."""
    return tup[1]


def parse_arguments() -> argparse.Namespace:
    """Parse the command line and configure logging accordingly.

    With -h/--help, prints the help text plus usage examples and exits.
    """
    # We handle --help ourselves in order to append the examples below.
    parser = argparse.ArgumentParser(description=__doc__, add_help=False)
    parser.add_argument(
        "-r",
        "--recursive",
        action="store_true",
        help="Recursively find all .py files under all given PATHs.",
    )
    parser.add_argument(
        "-h",
        "--help",
        action="store_true",
        help="Print this help text and exit.",
    )
    levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    parser.add_argument(
        "-L",
        "--log-level",
        action="store",
        default="WARNING",
        choices=levels,
        help="Set the log level (one of {})".format(", ".join(levels)),
    )
    parser.add_argument(
        "--debug", action="store_true", help="Set log-level to DEBUG."
    )
    parser.add_argument(
        "path",
        nargs="*",
        help="In non-recursive mode: Python file(s) to parse for imports."
        " With -r|--recursive: root directory/ies to search recursively.",
    )

    args = parser.parse_args()
    if args.debug:
        log_level = "DEBUG"
    else:
        log_level = args.log_level
    logging.basicConfig(
        level=log_level, format="%(levelname)s %(message)s"
    )

    if args.help:
        parser.print_help()
        print(
            """
Examples:
  cyclic-imports python/pencil/sim/*.py
  cyclic-imports python/pencil/{math,read}/*.py
  cyclic-imports --recursive python/pencil/
  cd $PENCIL_HOME/python; \\
      find pencil -name \\*.py | xargs utils/cyclic-imports"""
        )
        parser.exit()

    return args


def _format_import_node(node: ast.Import) -> str:
    """Render an 'import ...' AST node for log messages."""
    return "Import[{}]".format([a.name for a in node.names])


def _format_from_import_node(node: ast.ImportFrom) -> str:
    """Render a 'from ... import ...' AST node for log messages."""
    return "ImportFrom[module={}, names={}, level={}]".format(
        node.module, [a.name for a in node.names], node.level
    )


def die(*msg_lines: str) -> NoReturn:
    """Print the given lines and terminate the program with status 1."""
    print("\n".join(msg_lines))
    sys.exit(1)


class AstImportException(Exception):
    """An error tied to a statement, reported as 'filename:lineno msg'."""

    def __init__(
        self, statement: ast.stmt, filename: str, msg: str
    ) -> None:
        super().__init__(
            "{}:{} {}".format(filename, statement.lineno, msg)
        )


class PackageDirectoryDoesNotExist(Exception):
    """We have tried to treat a non-existing path as package directory"""

    def __init__(self, directory: str) -> None:
        super().__init__(
            "Package directory '{}' does not exist".format(directory)
        )


if __name__ == "__main__":
    main()