Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import PrettyRunReport
from .runtime import *
from .thunk import *

Comment thread
kaeun97 marked this conversation as resolved.
Expand Down Expand Up @@ -953,36 +954,36 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> PrettyRunReport: ...
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> PrettyRunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
) -> PrettyRunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> PrettyRunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return PrettyRunReport.from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> PrettyRunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return PrettyRunReport.from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
26 changes: 20 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand Down Expand Up @@ -247,6 +249,14 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> str:
"""
Translate an egglog rule name to its Python representation.
"""
if egglog_key in self.egg_rule_to_command_decl:
return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key])
return egglog_key
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -289,13 +299,15 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)

self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
return egg_cmd
case RuleDecl(head, body, name):
return bindings.RuleCommand(
egg_cmd = bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
Expand All @@ -304,6 +316,8 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
str(ruleset),
)
)
self.egg_rule_to_command_decl[str(egg_cmd)] = cmd
return egg_cmd
# TODO: Replace with just constants value and looking at REF of function
case DefaultRewriteDecl(ref, expr, subsume):
sig = self.__egg_decls__.get_callable_decl(ref).signature
Expand Down
86 changes: 86 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import timedelta

from . import bindings
from .egraph_state import EGraphState


@dataclass
class PrettyRuleReport:
plan: bindings.Plan | None
search_and_apply_time: timedelta
num_matches: int

@classmethod
def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport:
return cls(
plan=report.plan,
search_and_apply_time=report.search_and_apply_time,
num_matches=report.num_matches,
)


@dataclass
class PrettyRuleSetReport:
changed: bool
rule_reports: dict[str, list[PrettyRuleReport]]
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
search_and_apply_time: timedelta
merge_time: timedelta

@classmethod
def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport:
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
return cls(
changed=report.changed,
rule_reports={
translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v]
for k, v in report.rule_reports.items()
},
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
)


@dataclass
class PrettyIterationReport:
rule_set_report: PrettyRuleSetReport
rebuild_time: timedelta

@classmethod
def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> PrettyIterationReport:
return cls(
rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key),
rebuild_time=report.rebuild_time,
)


@dataclass
class PrettyRunReport:
"""Python-friendly wrapper around bindings.RunReport."""

iterations: list[PrettyIterationReport]
updated: bool
search_and_apply_time_per_rule: dict[str, timedelta]
num_matches_per_rule: dict[str, int]
search_and_apply_time_per_ruleset: dict[str, timedelta]
merge_time_per_ruleset: dict[str, timedelta]
rebuild_time_per_ruleset: dict[str, timedelta]
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

@classmethod
def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport:
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
return cls(
iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations],
updated=report.updated,
search_and_apply_time_per_rule={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
},
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
search_and_apply_time_per_ruleset={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items()
},
merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()},
rebuild_time_per_ruleset={
state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items()
},
)
200 changes: 200 additions & 0 deletions python/tests/test_run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# mypy: disable-error-code="empty-body"
from __future__ import annotations

from datetime import timedelta

from egglog import *


class TestPrettyRunReport:
def _setup_simple_egraph(self):
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
return egraph

def test_run_returns_pretty_report(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)
assert type(report).__name__ == "PrettyRunReport"

def test_stats_returns_pretty_report(self):
egraph = self._setup_simple_egraph()
egraph.run(10)
report = egraph.stats()
assert type(report).__name__ == "PrettyRunReport"

def test_rule_names_translated_in_top_level_dicts(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for key in report.search_and_apply_time_per_rule:
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"

for key in report.num_matches_per_rule:
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"
assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}"

def test_rule_names_translated_in_iterations(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)

assert len(report.iterations) > 0
for iteration in report.iterations:
for key in iteration.rule_set_report.rule_reports:
assert "__main__" not in key, f"Iteration rule key not translated: {key}"
assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}"

def test_updated_field(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)
assert isinstance(report.updated, bool)
assert report.updated is True

def test_num_matches(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)

total_matches = sum(report.num_matches_per_rule.values())
assert total_matches > 0

def test_timedelta_types(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for v in report.search_and_apply_time_per_rule.values():
assert isinstance(v, timedelta)
for v in report.search_and_apply_time_per_ruleset.values():
assert isinstance(v, timedelta)
for v in report.merge_time_per_ruleset.values():
assert isinstance(v, timedelta)
for v in report.rebuild_time_per_ruleset.values():
assert isinstance(v, timedelta)

def test_iteration_reports_are_pretty(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)

for it in report.iterations:
assert type(it).__name__ == "PrettyIterationReport"
assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport"
for rule_reports in it.rule_set_report.rule_reports.values():
for rr in rule_reports:
assert type(rr).__name__ == "PrettyRuleReport"

def test_str_no_egglog_sexprs(self):
egraph = self._setup_simple_egraph()
report = egraph.run(10)
output = str(report)

assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}"
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

def test_multiple_rules(self):
egraph = EGraph()

class Math(Expr):
def __init__(self, value: i64Like) -> None: ...
def __add__(self, other: Math) -> Math: ...
def __mul__(self, other: Math) -> Math: ...

a, b = vars_("a b", Math)
egraph.register(
rewrite(a + b).to(b + a),
rewrite(a * b).to(b * a),
)
egraph.register(Math(1) + Math(2), Math(3) * Math(4))
report = egraph.run(10)

# should have two distinct translated rule keys
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) == 2
for key in rule_keys:
assert "__main__" not in key, f"Key not translated: {key}"

def test_empty_run(self):
egraph = EGraph()
report = egraph.run(1)
assert type(report).__name__ == "PrettyRunReport"
assert isinstance(report.updated, bool)

def test_named_rule(self):
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x)))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
assert "__main__" not in output, f"str() still contains mangled names:\n{output}"

def test_unnamed_rule_decl(self):
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rule(x + y).then(union(x + y).with_(y + x)))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}"
# Should contain Python rule() syntax somewhere in the keys
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) > 0
for key in rule_keys:
assert "__main__" not in key, f"RuleDecl key not translated: {key}"

def test_birewrite_decl(self):
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...
def __mul__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(birewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

output = str(report)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}"
rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) > 0
for key in rule_keys:
assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}"
assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}"

def test_rewrite_decl(self):
egraph = EGraph()

class Num(Expr):
def __init__(self, n: i64Like) -> None: ...
def __add__(self, other: Num) -> Num: ...

x, y = vars_("x y", Num)
egraph.register(rewrite(x + y).to(y + x))
egraph.register(Num(1) + Num(2))
report = egraph.run(10)

rule_keys = list(report.search_and_apply_time_per_rule.keys())
assert len(rule_keys) == 1
key = rule_keys[0]
assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}"
assert "__main__" not in key, f"RewriteDecl key not translated: {key}"
Loading