diff --git a/components/polylith/interface/parser.py b/components/polylith/interface/parser.py index ab9a68e3..f9c1da3c 100644 --- a/components/polylith/interface/parser.py +++ b/components/polylith/interface/parser.py @@ -1,10 +1,13 @@ import ast from functools import lru_cache from pathlib import Path -from typing import Set, Union +from typing import FrozenSet, List, Set, Union from polylith.imports import SYMBOLS, extract_api, list_imports, parse_module +PACKAGE_INTERFACE = "__init__.py" +ALL_STATEMENT = "__all__" + def target_names(t: ast.AST) -> Set[str]: if isinstance(t, ast.Name): @@ -53,13 +56,25 @@ def extract_public_variables(path: Path) -> Set[str]: def is_the_all_statement(target: ast.expr) -> bool: - return isinstance(target, ast.Name) and target.id == "__all__" + return isinstance(target, ast.Name) and target.id == ALL_STATEMENT def is_string_constant(expression: ast.AST) -> bool: return isinstance(expression, ast.Constant) and isinstance(expression.value, str) +def attribute_expr_to_parts(expr: ast.AST) -> List[str]: + if isinstance(expr, ast.Name): + return [expr.id] + + if isinstance(expr, ast.Attribute): + parent = attribute_expr_to_parts(expr.value) + + return [*parent, expr.attr] if parent else [] + + return [] + + def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]: if not isinstance(statement, ast.Assign): return None @@ -76,12 +91,69 @@ def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]: return {e.value for e in statement.value.elts if isinstance(e, ast.Constant)} -def extract_the_all_variable(path: Path) -> Set[str]: +def find_the_all_pointer(statement: ast.stmt) -> Union[str, None]: + if not isinstance(statement, ast.Assign): + return None + + if not any(is_the_all_statement(t) for t in statement.targets): + return None + + parts = attribute_expr_to_parts(statement.value) + + if not parts: + return None + + *module_path, rest = parts + + if rest != ALL_STATEMENT: + return None + + return ".".join(module_path) + + +def resolve_local_module_path(package_dir: Path, module_ref: str) -> Union[Path, None]: + parts = tuple(p for p in module_ref.split(".") if p) + + if not parts: + return None + + module_file = package_dir.joinpath(*parts).with_suffix(".py") + + if module_file.exists(): + return module_file + + module_init = package_dir.joinpath(*parts, PACKAGE_INTERFACE) + + return module_init if module_init.exists() else None + + +def _extract_the_all_variable(path: Path, visited: FrozenSet[Path]) -> Set[str]: + if path in visited: + return set() + + visited = visited | frozenset({path}) + tree = parse(path) - res = [find_the_all_variable(s) for s in tree.body] + literals = [find_the_all_variable(s) for s in tree.body] + literal = next((r for r in literals if r is not None), None) + + if literal is not None: + return literal + + pointers = (find_the_all_pointer(s) for s in tree.body) + pointer = next((p for p in pointers if p is not None), None) + + if not pointer: + return set() - return next((r for r in res if r is not None), set()) + resolved = resolve_local_module_path(path.parent, pointer) + + return _extract_the_all_variable(resolved, visited) if resolved else set() + + +def extract_the_all_variable(path: Path) -> Set[str]: + return _extract_the_all_variable(path, frozenset()) def extract_imported_api(path: Path) -> Set[str]: @@ -98,7 +170,7 @@ def fetch_api_for_path(path: Path) -> Set[str]: def fetch_api(paths: Set[Path]) -> dict: - interface_paths = [Path(p / "__init__.py") for p in paths] + interface_paths = [Path(p / PACKAGE_INTERFACE) for p in paths] interfaces = [p for p in interface_paths if p.exists()] diff --git a/projects/poetry_polylith_plugin/pyproject.toml b/projects/poetry_polylith_plugin/pyproject.toml index a95a6d2f..f2c02f91 100644 --- a/projects/poetry_polylith_plugin/pyproject.toml +++ b/projects/poetry_polylith_plugin/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "poetry-polylith-plugin" -version = "1.48.1" +version = "1.48.2" description = "A Poetry plugin that adds tooling support for the Polylith Architecture" authors = ["David Vujic"] homepage = "https://davidvujic.github.io/python-polylith-docs/" diff --git a/projects/polylith_cli/pyproject.toml b/projects/polylith_cli/pyproject.toml index f5de6732..6b6ced66 100644 --- a/projects/polylith_cli/pyproject.toml +++ b/projects/polylith_cli/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "polylith-cli" -version = "1.42.1" +version = "1.42.2" description = "Python tooling support for the Polylith Architecture" authors = ['David Vujic'] homepage = "https://davidvujic.github.io/python-polylith-docs/" diff --git a/test/components/polylith/interface/test_parse_api.py b/test/components/polylith/interface/test_parse_api.py index 954e4733..54c0869b 100644 --- a/test/components/polylith/interface/test_parse_api.py +++ b/test/components/polylith/interface/test_parse_api.py @@ -72,6 +72,40 @@ def test_extract_the_all_variable(monkeypatch) -> None: assert res == {"thing", "other", "message"} +def test_extract_the_all_variable_from_module_pointer(tmp_path: Path) -> None: + expected = "pub_func" + + parser.parse.cache_clear() + + package_dir = tmp_path / "comp" + package_dir.mkdir(parents=True) + + init = package_dir / "__init__.py" + core = package_dir / "core.py" + + init.write_text( + """ +from .core import *\ + + +__all__ = core.__all__ +""" + ) + + core.write_text( + f""" +__all__ = ["{expected}"] + + +def pub_func(): + pass +""" + ) + + assert parser.extract_the_all_variable(init) == {expected} + assert parser.fetch_api_for_path(init) == {expected} + + def test_fetch_api_for_path(monkeypatch) -> None: fn = partial(fake_parse, the_interface)