Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/api/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# --------------------------------------------------------------------

import symtable
from collections.abc import Generator
from collections.abc import Callable, Generator
from typing import Any, NamedTuple

import src.api.check as chk
Expand All @@ -19,7 +19,8 @@
from src.api.constants import CLASS, CONVENTION, SCOPE, TYPE
from src.api.debug import __DEBUG__
from src.api.errmsg import warning_not_used
from src.ast import Ast, NodeVisitor
from src.ast import Ast
from src.ast.visitor import GenericNodeVisitor
from src.symbols import sym as symbols
from src.symbols.id_ import ref

Expand All @@ -32,12 +33,13 @@ class ToVisit(NamedTuple):
obj: symbols.SYMBOL


class GenericVisitor(NodeVisitor):
class GenericVisitor(GenericNodeVisitor[ToVisit]):
"""A slightly different visitor, that just traverses an AST, but does not return
a translation of it. Used to examine the AST or do transformations
"""

node_type = ToVisit
def __init__(self):
super().__init__(ToVisit)

@property
def O_LEVEL(self):
Expand All @@ -58,18 +60,22 @@ def TYPE(type_):
assert TYPE.is_valid(type_)
return gl.SYMBOL_TABLE.basic_types[type_]

def visit(self, node):
def visit(self, node) -> ToVisit | Generator[ToVisit | None, None, None] | None:
return super().visit(ToVisit(node))

def _visit(self, node: ToVisit):
def _visit(self, node: ToVisit) -> Generator[Ast | None, Any, None] | None:
if node.obj is None:
return None

__DEBUG__(f"Optimizer: Visiting node {node.obj!s}[{node.obj.token}]", 1)
meth = getattr(self, f"visit_{node.obj.token}", self.generic_visit)
meth: Callable[[Ast], Generator[Ast | None, Any, None]] = getattr(
self,
f"visit_{node.obj.token}",
self.generic_visit,
)
return meth(node.obj)

def generic_visit(self, node: Ast) -> Generator[Ast | None, Any, None]:
def generic_visit(self, node: ToVisit) -> Generator[ToVisit | None, None, None]:
for i, child in enumerate(node.children):
node.children[i] = yield self.visit(child)

Expand All @@ -88,6 +94,29 @@ def _visit(self, node: ToVisit):
self.visited.add(node.obj)
return super()._visit(node)

def filter_inorder(
self,
node,
filter_func: Callable[[Any], bool],
child_selector: Callable[[Ast], bool] = lambda x: True,
) -> Generator[Ast, None, None]:
"""Visit the tree inorder, but only those that return true for filter_func and visiting children which
return true for child_selector.
"""
visited = set()
stack = [node]
while stack:
node = stack.pop()
if node in visited:
continue

visited.add(node)
if filter_func(node):
yield self.visit(node)

if isinstance(node, Ast) and child_selector(node):
stack.extend(node.children[::-1])


class UnreachableCodeVisitor(UniqueVisitor):
"""Visitor to optimize unreachable code (and prune it)."""
Expand Down
2 changes: 1 addition & 1 deletion src/arch/z80/visitor/function_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class FunctionTranslator(Translator):
REQUIRES = backend.REQUIRES

def __init__(self, backend: Backend, function_list: list[symbols.ID]):
super().__init__(backend)
if function_list is None:
function_list = []
super().__init__(backend)

assert isinstance(function_list, list)
assert all(x.token == "FUNCTION" for x in function_list)
Expand Down
1 change: 1 addition & 0 deletions src/arch/z80/visitor/translator_inst_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

class TranslatorInstVisitor(NodeVisitor):
def __init__(self, backend: Backend):
super().__init__()
self.backend = backend

def emit(self, *args: str) -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
# See https://www.gnu.org/licenses/agpl-3.0.html for details.
# --------------------------------------------------------------------

from .ast import Ast, NodeVisitor, types
from .ast import Ast, NodeVisitor
from .tree import Tree

__all__ = (
"Ast",
"NodeVisitor",
"Tree",
"types",
)
56 changes: 7 additions & 49 deletions src/ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# See https://www.gnu.org/licenses/agpl-3.0.html for details.
# --------------------------------------------------------------------

import types
from collections.abc import Callable
from typing import Any
from typing import Final

from .tree import Tree
from .visitor import GenericNodeVisitor

__all__: Final[tuple[str, ...]] = "Ast", "NodeVisitor"


# ----------------------------------------------------------------------
Expand All @@ -25,49 +26,6 @@ def token(self):
return self.__class__


class NodeVisitor:
node_type: type = Ast

def visit(self, node):
stack = [node]
last_result = None

while stack:
try:
last = stack[-1]
if isinstance(last, types.GeneratorType):
stack.append(last.send(last_result))
last_result = None
elif isinstance(last, self.node_type):
stack.append(self._visit(stack.pop()))
else:
last_result = stack.pop()
except StopIteration:
stack.pop()

return last_result

def _visit(self, node):
meth = getattr(self, f"visit_{node.token}", self.generic_visit)
return meth(node)

def generic_visit(self, node: Ast):
raise RuntimeError(f"No visit_{node.token}() method defined")

def filter_inorder(
self, node, filter_func: Callable[[Any], bool], child_selector: Callable[[Ast], bool] = lambda x: True
):
"""Visit the tree inorder, but only those that return true for filter_func and visiting children which
return true for child_selector.
"""
visited = set()
stack = [node]
while stack:
node = stack.pop()
if node in visited:
continue
visited.add(node)
if filter_func(node):
yield self.visit(node)
if isinstance(node, Ast) and child_selector(node):
stack.extend(node.children[::-1])
class NodeVisitor(GenericNodeVisitor[Ast]):
def __init__(self):
super().__init__(Ast)
18 changes: 18 additions & 0 deletions src/ast/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Final

from src.api.exception import Error

__all__: Final[tuple[str, ...]] = ("NotAnAstError",)


class NotAnAstError(Error):
"""Thrown when the "pointer" is not
an AST, but another thing.
"""

def __init__(self, instance):
self.instance = instance
self.msg = "Object '%s' is not an Ast instance" % str(instance)

def __str__(self):
return self.msg
17 changes: 1 addition & 16 deletions src/ast/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,7 @@
from collections.abc import Iterable, Iterator
from typing import Any

from src.api.exception import Error

__all__ = "ChildrenList", "NotAnAstError", "Tree"


class NotAnAstError(Error):
"""Thrown when the "pointer" is not
an AST, but another thing.
"""

def __init__(self, instance):
self.instance = instance
self.msg = "Object '%s' is not an Ast instance" % str(instance)

def __str__(self):
return self.msg
__all__ = "ChildrenList", "Tree"


class Tree:
Expand Down
39 changes: 39 additions & 0 deletions src/ast/visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
__doc__ = "Implements a generic visitor class for Trees"

from collections.abc import Generator
from types import GeneratorType
from typing import Generic, TypeVar

_T = TypeVar("_T")


class GenericNodeVisitor(Generic[_T]):
def __init__(self, type_: type[_T]):
self.node_type: type[_T] = type_

def visit(self, node: _T) -> _T | Generator[_T | None, None, None] | None:
stack: list[_T | GeneratorType] = [node]
last_result: _T | None = None

while stack:
try:
stack_top = stack[-1]
if isinstance(stack_top, GeneratorType):
stack.append(stack_top.send(last_result))
last_result = None
elif isinstance(stack_top, self.node_type):
stack.pop()
stack.append(self._visit(stack_top))
else:
last_result = stack.pop()
except StopIteration:
stack.pop()

return last_result

def _visit(self, node):
meth = getattr(self, f"visit_{node.token}", self.generic_visit)
return meth(node)

def generic_visit(self, node: _T) -> Generator[_T | None, None, None]:
raise RuntimeError(f"No visit_{node.token}() method defined")
2 changes: 1 addition & 1 deletion src/zxbasm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from src.api.errmsg import error
from src.ast import Ast
from src.ast.tree import NotAnAstError
from src.ast.exceptions import NotAnAstError
from src.zxbasm.label import Label


Expand Down
Loading