Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

## 13.1.0 (2026-03-25)

- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397)
- Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection.
- Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.
Expand Down
17 changes: 8 additions & 9 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import RunReport
from .runtime import *
from .thunk import *

Comment thread
kaeun97 marked this conversation as resolved.
Expand Down Expand Up @@ -953,36 +954,34 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ...

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> RunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> RunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return RunReport._from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return RunReport._from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
37 changes: 31 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def _normalize_global_let_name(name: str) -> str:
return name if name.startswith("$") else f"${name}"


def _normalize_rule_key(key: str) -> str:
"""Normalize an egglog rule string for consistent matching."""
key = key.replace("'", '"')
return re.sub(r"\s+", " ", key).strip()


@dataclass
class EGraphState:
"""
Expand Down Expand Up @@ -76,6 +82,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand Down Expand Up @@ -247,6 +255,13 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> CommandDecl:
"""
Look up the original Python CommandDecl for an egglog rule key.
"""
normalized = _normalize_rule_key(egglog_key)
return self.egg_rule_to_command_decl[normalized]

def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -289,13 +304,19 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)

normalized = _normalize_rule_key(str(egg_cmd))
self.egg_rule_to_command_decl[normalized] = cmd
if isinstance(cmd, BiRewriteDecl):
self.egg_rule_to_command_decl[normalized + "=>"] = cmd
self.egg_rule_to_command_decl[normalized + "<="] = cmd
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
return egg_cmd
case RuleDecl(head, body, name):
return bindings.RuleCommand(
egg_cmd = bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
Expand All @@ -304,6 +325,10 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
str(ruleset),
)
)
self.egg_rule_to_command_decl[_normalize_rule_key(str(egg_cmd))] = cmd
if name:
self.egg_rule_to_command_decl[name] = cmd
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
return egg_cmd
# TODO: Replace with just constants value and looking at REF of function
case DefaultRewriteDecl(ref, expr, subsume):
sig = self.__egg_decls__.get_callable_decl(ref).signature
Expand Down
121 changes: 121 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import timedelta

from . import bindings
from .declarations import CommandDecl, Declarations
from .egraph_state import EGraphState
from .pretty import pretty_decl


def _format_rule_key(decls: Declarations, key: CommandDecl) -> str:
return pretty_decl(decls, key)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated


@dataclass
class RuleReport:
plan: bindings.Plan | None
search_and_apply_time: timedelta
num_matches: int

@classmethod
def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
return cls(
plan=report.plan,
search_and_apply_time=report.search_and_apply_time,
num_matches=report.num_matches,
)


@dataclass
class RuleSetReport:
changed: bool
rule_reports: dict[CommandDecl, list[RuleReport]]
search_and_apply_time: timedelta
merge_time: timedelta
_decls: Declarations = field(repr=False, default=None)

@classmethod
def _from_bindings(
cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
) -> RuleSetReport:
return cls(
changed=report.changed,
rule_reports={
translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items()
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
},
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
_decls=decls,
)

def __repr__(self) -> str:
rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
return (
f"RuleSetReport(changed={self.changed}, "
f"rule_reports={rule_reports_str}, "
f"search_and_apply_time={self.search_and_apply_time}, "
f"merge_time={self.merge_time})"
)


@dataclass
class IterationReport:
rule_set_report: RuleSetReport
rebuild_time: timedelta

@classmethod
def _from_bindings(
cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
) -> IterationReport:
return cls(
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls),
rebuild_time=report.rebuild_time,
)


@dataclass
class RunReport:
"""Python-friendly wrapper around bindings.RunReport."""

iterations: list[IterationReport]
updated: bool
search_and_apply_time_per_rule: dict[CommandDecl, timedelta]
num_matches_per_rule: dict[CommandDecl, int]
search_and_apply_time_per_ruleset: dict[str, timedelta]
merge_time_per_ruleset: dict[str, timedelta]
rebuild_time_per_ruleset: dict[str, timedelta]
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
_decls: Declarations = field(repr=False, default=None)

def __repr__(self) -> str:
time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
return (
f"RunReport(iterations={self.iterations}, "
f"updated={self.updated}, "
f"search_and_apply_time_per_rule={time_per_rule}, "
f"num_matches_per_rule={matches_per_rule}, "
f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
)

@classmethod
def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
return cls(
iterations=[
IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__)
for it in report.iterations
],
updated=report.updated,
search_and_apply_time_per_rule={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
},
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
merge_time_per_ruleset=report.merge_time_per_ruleset,
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
_decls=state.__egg_decls__,
)
Loading
Loading