diff --git a/src/api/optimize.py b/src/api/optimize.py index ff744fa7d..5ec2a4a21 100644 --- a/src/api/optimize.py +++ b/src/api/optimize.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------- import symtable -from collections.abc import Generator +from collections.abc import Callable, Generator from typing import Any, NamedTuple import src.api.check as chk @@ -19,7 +19,8 @@ from src.api.constants import CLASS, CONVENTION, SCOPE, TYPE from src.api.debug import __DEBUG__ from src.api.errmsg import warning_not_used -from src.ast import Ast, NodeVisitor +from src.ast import Ast +from src.ast.visitor import GenericNodeVisitor from src.symbols import sym as symbols from src.symbols.id_ import ref @@ -32,12 +33,13 @@ class ToVisit(NamedTuple): obj: symbols.SYMBOL -class GenericVisitor(NodeVisitor): +class GenericVisitor(GenericNodeVisitor[ToVisit]): """A slightly different visitor, that just traverses an AST, but does not return a translation of it. Used to examine the AST or do transformations """ - node_type = ToVisit + def __init__(self): + super().__init__(ToVisit) @property def O_LEVEL(self): @@ -58,18 +60,22 @@ def TYPE(type_): assert TYPE.is_valid(type_) return gl.SYMBOL_TABLE.basic_types[type_] - def visit(self, node): + def visit(self, node) -> ToVisit | Generator[ToVisit | None, None, None] | None: return super().visit(ToVisit(node)) - def _visit(self, node: ToVisit): + def _visit(self, node: ToVisit) -> Generator[Ast | None, Any, None] | None: if node.obj is None: return None __DEBUG__(f"Optimizer: Visiting node {node.obj!s}[{node.obj.token}]", 1) - meth = getattr(self, f"visit_{node.obj.token}", self.generic_visit) + meth: Callable[[Ast], Generator[Ast | None, Any, None]] = getattr( + self, + f"visit_{node.obj.token}", + self.generic_visit, + ) return meth(node.obj) - def generic_visit(self, node: Ast) -> Generator[Ast | None, Any, None]: + def generic_visit(self, node: ToVisit) -> Generator[ToVisit | None, None, None]: for i, child in enumerate(node.children): node.children[i] = yield self.visit(child) @@ -88,6 +94,29 @@ def _visit(self, node: ToVisit): self.visited.add(node.obj) return super()._visit(node) + def filter_inorder( + self, + node, + filter_func: Callable[[Any], bool], + child_selector: Callable[[Ast], bool] = lambda x: True, + ) -> Generator[Ast, None, None]: + """Visit the tree inorder, but only those that return true for filter_func and visiting children which + return true for child_selector. + """ + visited = set() + stack = [node] + while stack: + node = stack.pop() + if node in visited: + continue + + visited.add(node) + if filter_func(node): + yield self.visit(node) + + if isinstance(node, Ast) and child_selector(node): + stack.extend(node.children[::-1]) + class UnreachableCodeVisitor(UniqueVisitor): """Visitor to optimize unreachable code (and prune it).""" diff --git a/src/arch/z80/visitor/function_translator.py b/src/arch/z80/visitor/function_translator.py index 143a77d4e..88dcc2bdf 100644 --- a/src/arch/z80/visitor/function_translator.py +++ b/src/arch/z80/visitor/function_translator.py @@ -24,9 +24,9 @@ class FunctionTranslator(Translator): REQUIRES = backend.REQUIRES def __init__(self, backend: Backend, function_list: list[symbols.ID]): + super().__init__(backend) if function_list is None: function_list = [] - super().__init__(backend) assert isinstance(function_list, list) assert all(x.token == "FUNCTION" for x in function_list) diff --git a/src/arch/z80/visitor/translator_inst_visitor.py b/src/arch/z80/visitor/translator_inst_visitor.py index cc218d4c9..306c66512 100644 --- a/src/arch/z80/visitor/translator_inst_visitor.py +++ b/src/arch/z80/visitor/translator_inst_visitor.py @@ -16,6 +16,7 @@ class TranslatorInstVisitor(NodeVisitor): def __init__(self, backend: Backend): + super().__init__() self.backend = backend def emit(self, *args: str) -> None: diff --git a/src/ast/__init__.py b/src/ast/__init__.py index 9edeeea6b..221568150 100644 --- a/src/ast/__init__.py +++ b/src/ast/__init__.py @@ -5,12 +5,11 @@ # See https://www.gnu.org/licenses/agpl-3.0.html for details. # -------------------------------------------------------------------- -from .ast import Ast, NodeVisitor, types +from .ast import Ast, NodeVisitor from .tree import Tree __all__ = ( "Ast", "NodeVisitor", "Tree", - "types", ) diff --git a/src/ast/ast.py b/src/ast/ast.py index 2fb960f46..8e05b96e8 100644 --- a/src/ast/ast.py +++ b/src/ast/ast.py @@ -5,11 +5,12 @@ # See https://www.gnu.org/licenses/agpl-3.0.html for details. # -------------------------------------------------------------------- -import types -from collections.abc import Callable -from typing import Any +from typing import Final from .tree import Tree +from .visitor import GenericNodeVisitor + +__all__: Final[tuple[str, ...]] = "Ast", "NodeVisitor" # ---------------------------------------------------------------------- @@ -25,49 +26,6 @@ def token(self): return self.__class__ -class NodeVisitor: - node_type: type = Ast - - def visit(self, node): - stack = [node] - last_result = None - - while stack: - try: - last = stack[-1] - if isinstance(last, types.GeneratorType): - stack.append(last.send(last_result)) - last_result = None - elif isinstance(last, self.node_type): - stack.append(self._visit(stack.pop())) - else: - last_result = stack.pop() - except StopIteration: - stack.pop() - - return last_result - - def _visit(self, node): - meth = getattr(self, f"visit_{node.token}", self.generic_visit) - return meth(node) - - def generic_visit(self, node: Ast): - raise RuntimeError(f"No visit_{node.token}() method defined") - - def filter_inorder( - self, node, filter_func: Callable[[Any], bool], child_selector: Callable[[Ast], bool] = lambda x: True - ): - """Visit the tree inorder, but only those that return true for filter_func and visiting children which - return true for child_selector. - """ - visited = set() - stack = [node] - while stack: - node = stack.pop() - if node in visited: - continue - visited.add(node) - if filter_func(node): - yield self.visit(node) - if isinstance(node, Ast) and child_selector(node): - stack.extend(node.children[::-1]) +class NodeVisitor(GenericNodeVisitor[Ast]): + def __init__(self): + super().__init__(Ast) diff --git a/src/ast/exceptions.py b/src/ast/exceptions.py new file mode 100644 index 000000000..23102a633 --- /dev/null +++ b/src/ast/exceptions.py @@ -0,0 +1,18 @@ +from typing import Final + +from src.api.exception import Error + +__all__: Final[tuple[str, ...]] = ("NotAnAstError",) + + +class NotAnAstError(Error): + """Thrown when the "pointer" is not + an AST, but another thing. + """ + + def __init__(self, instance): + self.instance = instance + self.msg = "Object '%s' is not an Ast instance" % str(instance) + + def __str__(self): + return self.msg diff --git a/src/ast/tree.py b/src/ast/tree.py index 389f9b648..d65e384e3 100644 --- a/src/ast/tree.py +++ b/src/ast/tree.py @@ -11,22 +11,7 @@ from collections.abc import Iterable, Iterator from typing import Any -from src.api.exception import Error - -__all__ = "ChildrenList", "NotAnAstError", "Tree" - - -class NotAnAstError(Error): - """Thrown when the "pointer" is not - an AST, but another thing. - """ - - def __init__(self, instance): - self.instance = instance - self.msg = "Object '%s' is not an Ast instance" % str(instance) - - def __str__(self): - return self.msg +__all__ = "ChildrenList", "Tree" class Tree: diff --git a/src/ast/visitor.py b/src/ast/visitor.py new file mode 100644 index 000000000..0a433e479 --- /dev/null +++ b/src/ast/visitor.py @@ -0,0 +1,39 @@ +__doc__ = "Implements a generic visitor class for Trees" + +from collections.abc import Generator +from types import GeneratorType +from typing import Generic, TypeVar + +_T = TypeVar("_T") + + +class GenericNodeVisitor(Generic[_T]): + def __init__(self, type_: type[_T]): + self.node_type: type[_T] = type_ + + def visit(self, node: _T) -> _T | Generator[_T | None, None, None] | None: + stack: list[_T | GeneratorType] = [node] + last_result: _T | None = None + + while stack: + try: + stack_top = stack[-1] + if isinstance(stack_top, GeneratorType): + stack.append(stack_top.send(last_result)) + last_result = None + elif isinstance(stack_top, self.node_type): + stack.pop() + stack.append(self._visit(stack_top)) + else: + last_result = stack.pop() + except StopIteration: + stack.pop() + + return last_result + + def _visit(self, node): + meth = getattr(self, f"visit_{node.token}", self.generic_visit) + return meth(node) + + def generic_visit(self, node: _T) -> Generator[_T | None, None, None]: + raise RuntimeError(f"No visit_{node.token}() method defined") diff --git a/src/zxbasm/expr.py b/src/zxbasm/expr.py index b0a779f8e..4c9444acc 100644 --- a/src/zxbasm/expr.py +++ b/src/zxbasm/expr.py @@ -7,7 +7,7 @@ from src.api.errmsg import error from src.ast import Ast -from src.ast.tree import NotAnAstError +from src.ast.exceptions import NotAnAstError from src.zxbasm.label import Label