From bb9a6b503a749d570d792435ada8fffb2af6f379 Mon Sep 17 00:00:00 2001 From: Jose Rodriguez Date: Wed, 11 Feb 2026 16:53:46 +0100 Subject: [PATCH 1/2] refact: move Exception to its own class --- src/ast/ast.py | 20 +++++++++++++------- src/ast/exceptions.py | 18 ++++++++++++++++++ src/ast/tree.py | 17 +---------------- src/zxbasm/expr.py | 2 +- 4 files changed, 33 insertions(+), 24 deletions(-) create mode 100644 src/ast/exceptions.py diff --git a/src/ast/ast.py b/src/ast/ast.py index 2fb960f46..566b5074a 100644 --- a/src/ast/ast.py +++ b/src/ast/ast.py @@ -7,10 +7,12 @@ import types from collections.abc import Callable -from typing import Any +from typing import Any, Final from .tree import Tree +__all__: Final[tuple[str, ...]] = "Ast", "NodeVisitor" + # ---------------------------------------------------------------------- # Abstract Syntax Tree class @@ -34,12 +36,13 @@ def visit(self, node): while stack: try: - last = stack[-1] - if isinstance(last, types.GeneratorType): - stack.append(last.send(last_result)) + stack_top = stack[-1] + if isinstance(stack_top, types.GeneratorType): + stack.append(stack_top.send(last_result)) last_result = None - elif isinstance(last, self.node_type): - stack.append(self._visit(stack.pop())) + elif isinstance(stack_top, self.node_type): + stack.pop() + stack.append(self._visit(stack_top)) else: last_result = stack.pop() except StopIteration: @@ -55,7 +58,10 @@ def generic_visit(self, node: Ast): raise RuntimeError(f"No visit_{node.token}() method defined") def filter_inorder( - self, node, filter_func: Callable[[Any], bool], child_selector: Callable[[Ast], bool] = lambda x: True + self, + node, + filter_func: Callable[[Any], bool], + child_selector: Callable[[Ast], bool] = lambda x: True, ): """Visit the tree inorder, but only those that return true for filter_func and visiting children which return true for child_selector. diff --git a/src/ast/exceptions.py b/src/ast/exceptions.py new file mode 100644 index 000000000..23102a633 --- /dev/null +++ b/src/ast/exceptions.py @@ -0,0 +1,18 @@ +from typing import Final + +from src.api.exception import Error + +__all__: Final[tuple[str, ...]] = ("NotAnAstError",) + + +class NotAnAstError(Error): + """Thrown when the "pointer" is not + an AST, but another thing. + """ + + def __init__(self, instance): + self.instance = instance + self.msg = "Object '%s' is not an Ast instance" % str(instance) + + def __str__(self): + return self.msg diff --git a/src/ast/tree.py b/src/ast/tree.py index 389f9b648..d65e384e3 100644 --- a/src/ast/tree.py +++ b/src/ast/tree.py @@ -11,22 +11,7 @@ from collections.abc import Iterable, Iterator from typing import Any -from src.api.exception import Error - -__all__ = "ChildrenList", "NotAnAstError", "Tree" - - -class NotAnAstError(Error): - """Thrown when the "pointer" is not - an AST, but another thing. - """ - - def __init__(self, instance): - self.instance = instance - self.msg = "Object '%s' is not an Ast instance" % str(instance) - - def __str__(self): - return self.msg +__all__ = "ChildrenList", "Tree" class Tree: diff --git a/src/zxbasm/expr.py b/src/zxbasm/expr.py index b0a779f8e..4c9444acc 100644 --- a/src/zxbasm/expr.py +++ b/src/zxbasm/expr.py @@ -7,7 +7,7 @@ from src.api.errmsg import error from src.ast import Ast -from src.ast.tree import NotAnAstError +from src.ast.exceptions import NotAnAstError from src.zxbasm.label import Label From 477f3c433f3bef9b6d6e453e1c0aa3d4cb8b774a Mon Sep 17 00:00:00 2001 From: Jose Rodriguez Date: Wed, 11 Feb 2026 19:42:21 +0100 Subject: [PATCH 2/2] refact: Use a GenericVisitor class This class is then subclassed by the others. --- src/api/optimize.py | 45 +++++++++++--- src/arch/z80/visitor/function_translator.py | 2 +- .../z80/visitor/translator_inst_visitor.py | 1 + src/ast/__init__.py | 3 +- src/ast/ast.py | 58 ++----------------- src/ast/visitor.py | 39 +++++++++++++ 6 files changed, 84 insertions(+), 64 deletions(-) create mode 100644 src/ast/visitor.py diff --git a/src/api/optimize.py b/src/api/optimize.py index ff744fa7d..5ec2a4a21 100644 --- a/src/api/optimize.py +++ b/src/api/optimize.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------- import symtable -from collections.abc import Generator +from collections.abc import Callable, Generator from typing import Any, NamedTuple import src.api.check as chk @@ -19,7 +19,8 @@ from src.api.constants import CLASS, CONVENTION, SCOPE, TYPE from src.api.debug import __DEBUG__ from src.api.errmsg import warning_not_used -from src.ast import Ast, NodeVisitor +from src.ast import Ast +from src.ast.visitor import GenericNodeVisitor from src.symbols import sym as symbols from src.symbols.id_ import ref @@ -32,12 +33,13 @@ class ToVisit(NamedTuple): obj: symbols.SYMBOL -class GenericVisitor(NodeVisitor): +class GenericVisitor(GenericNodeVisitor[ToVisit]): """A slightly different visitor, that just traverses an AST, but does not return a translation of it. Used to examine the AST or do transformations """ - node_type = ToVisit + def __init__(self): + super().__init__(ToVisit) @property def O_LEVEL(self): @@ -58,18 +60,22 @@ def TYPE(type_): assert TYPE.is_valid(type_) return gl.SYMBOL_TABLE.basic_types[type_] - def visit(self, node): + def visit(self, node) -> ToVisit | Generator[ToVisit | None, None, None] | None: return super().visit(ToVisit(node)) - def _visit(self, node: ToVisit): + def _visit(self, node: ToVisit) -> Generator[Ast | None, Any, None] | None: if node.obj is None: return None __DEBUG__(f"Optimizer: Visiting node {node.obj!s}[{node.obj.token}]", 1) - meth = getattr(self, f"visit_{node.obj.token}", self.generic_visit) + meth: Callable[[Ast], Generator[Ast | None, Any, None]] = getattr( + self, + f"visit_{node.obj.token}", + self.generic_visit, + ) return meth(node.obj) - def generic_visit(self, node: Ast) -> Generator[Ast | None, Any, None]: + def generic_visit(self, node: ToVisit) -> Generator[ToVisit | None, None, None]: for i, child in enumerate(node.children): node.children[i] = yield self.visit(child) @@ -88,6 +94,29 @@ def _visit(self, node: ToVisit): self.visited.add(node.obj) return super()._visit(node) + def filter_inorder( + self, + node, + filter_func: Callable[[Any], bool], + child_selector: Callable[[Ast], bool] = lambda x: True, + ) -> Generator[Ast, None, None]: + """Visit the tree inorder, but only those that return true for filter_func and visiting children which + return true for child_selector. + """ + visited = set() + stack = [node] + while stack: + node = stack.pop() + if node in visited: + continue + + visited.add(node) + if filter_func(node): + yield self.visit(node) + + if isinstance(node, Ast) and child_selector(node): + stack.extend(node.children[::-1]) + class UnreachableCodeVisitor(UniqueVisitor): """Visitor to optimize unreachable code (and prune it).""" diff --git a/src/arch/z80/visitor/function_translator.py b/src/arch/z80/visitor/function_translator.py index 143a77d4e..88dcc2bdf 100644 --- a/src/arch/z80/visitor/function_translator.py +++ b/src/arch/z80/visitor/function_translator.py @@ -24,9 +24,9 @@ class FunctionTranslator(Translator): REQUIRES = backend.REQUIRES def __init__(self, backend: Backend, function_list: list[symbols.ID]): + super().__init__(backend) if function_list is None: function_list = [] - super().__init__(backend) assert isinstance(function_list, list) assert all(x.token == "FUNCTION" for x in function_list) diff --git a/src/arch/z80/visitor/translator_inst_visitor.py b/src/arch/z80/visitor/translator_inst_visitor.py index cc218d4c9..306c66512 100644 --- a/src/arch/z80/visitor/translator_inst_visitor.py +++ b/src/arch/z80/visitor/translator_inst_visitor.py @@ -16,6 +16,7 @@ class TranslatorInstVisitor(NodeVisitor): def __init__(self, backend: Backend): + super().__init__() self.backend = backend def emit(self, *args: str) -> None: diff --git a/src/ast/__init__.py b/src/ast/__init__.py index 9edeeea6b..221568150 100644 --- a/src/ast/__init__.py +++ b/src/ast/__init__.py @@ -5,12 +5,11 @@ # See https://www.gnu.org/licenses/agpl-3.0.html for details. # -------------------------------------------------------------------- -from .ast import Ast, NodeVisitor, types +from .ast import Ast, NodeVisitor from .tree import Tree __all__ = ( "Ast", "NodeVisitor", "Tree", - "types", ) diff --git a/src/ast/ast.py b/src/ast/ast.py index 566b5074a..8e05b96e8 100644 --- a/src/ast/ast.py +++ b/src/ast/ast.py @@ -5,11 +5,10 @@ # See https://www.gnu.org/licenses/agpl-3.0.html for details. # -------------------------------------------------------------------- -import types -from collections.abc import Callable -from typing import Any, Final +from typing import Final from .tree import Tree +from .visitor import GenericNodeVisitor __all__: Final[tuple[str, ...]] = "Ast", "NodeVisitor" @@ -27,53 +26,6 @@ def token(self): return self.__class__ -class NodeVisitor: - node_type: type = Ast - - def visit(self, node): - stack = [node] - last_result = None - - while stack: - try: - stack_top = stack[-1] - if isinstance(stack_top, types.GeneratorType): - stack.append(stack_top.send(last_result)) - last_result = None - elif isinstance(stack_top, self.node_type): - stack.pop() - stack.append(self._visit(stack_top)) - else: - last_result = stack.pop() - except StopIteration: - stack.pop() - - return last_result - - def _visit(self, node): - meth = getattr(self, f"visit_{node.token}", self.generic_visit) - return meth(node) - - def generic_visit(self, node: Ast): - raise RuntimeError(f"No visit_{node.token}() method defined") - - def filter_inorder( - self, - node, - filter_func: Callable[[Any], bool], - child_selector: Callable[[Ast], bool] = lambda x: True, - ): - """Visit the tree inorder, but only those that return true for filter_func and visiting children which - return true for child_selector. - """ - visited = set() - stack = [node] - while stack: - node = stack.pop() - if node in visited: - continue - visited.add(node) - if filter_func(node): - yield self.visit(node) - if isinstance(node, Ast) and child_selector(node): - stack.extend(node.children[::-1]) +class NodeVisitor(GenericNodeVisitor[Ast]): + def __init__(self): + super().__init__(Ast) diff --git a/src/ast/visitor.py b/src/ast/visitor.py new file mode 100644 index 000000000..0a433e479 --- /dev/null +++ b/src/ast/visitor.py @@ -0,0 +1,39 @@ +__doc__ = "Implements a generic visitor class for Trees" + +from collections.abc import Generator +from types import GeneratorType +from typing import Generic, TypeVar + +_T = TypeVar("_T") + + +class GenericNodeVisitor(Generic[_T]): + def __init__(self, type_: type[_T]): + self.node_type: type[_T] = type_ + + def visit(self, node: _T) -> _T | Generator[_T | None, None, None] | None: + stack: list[_T | GeneratorType] = [node] + last_result: _T | None = None + + while stack: + try: + stack_top = stack[-1] + if isinstance(stack_top, GeneratorType): + stack.append(stack_top.send(last_result)) + last_result = None + elif isinstance(stack_top, self.node_type): + stack.pop() + stack.append(self._visit(stack_top)) + else: + last_result = stack.pop() + except StopIteration: + stack.pop() + + return last_result + + def _visit(self, node): + meth = getattr(self, f"visit_{node.token}", self.generic_visit) + return meth(node) + + def generic_visit(self, node: _T) -> Generator[_T | None, None, None]: + raise RuntimeError(f"No visit_{node.token}() method defined")