Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
793 changes: 439 additions & 354 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 10 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ opentelemetry_sdk = "0.28"
# egglog-core-relations = { path = "../egg-smol/core-relations" }
# egglog-ast = { path = "../egg-smol/egglog-ast" }
# egglog-reports = { path = "../egg-smol/egglog-reports" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug", default-features = false }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }


egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b", default-features = false }
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
serde_json = "1"
Expand All @@ -52,11 +50,11 @@ base64 = "0.22.1"
# egglog-reports = { path = "../egg-smol/egglog-reports" }
# egglog-bridge = { path = "../egg-smol/egglog-bridge" }

egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }

# enable debug symbols for easier profiling
[profile.release]
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

## 13.1.0 (2026-03-25)

- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397)
- Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection.
- Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.
Expand Down
5 changes: 4 additions & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ class Rewrite:
lhs: _Expr
rhs: _Expr
conditions: list[_Fact]
name: str

def __new__(cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ...) -> Rewrite: ...
def __new__(
cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ..., name: str = ...
) -> Rewrite: ...

@final
class RunConfig:
Expand Down
17 changes: 8 additions & 9 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import RunReport
from .runtime import *
from .thunk import *

Comment thread
kaeun97 marked this conversation as resolved.
Expand Down Expand Up @@ -953,36 +954,34 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ...

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> RunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> RunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return RunReport._from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return RunReport._from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
41 changes: 35 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand All @@ -86,6 +88,11 @@ class EGraphState:
# Counter for deterministic synthetic names assigned to unnamed functions.
unnamed_function_counter: int = 0

# Counter for numeric rule names
rule_name_counter: int = 0
# Mapping from numeric name (str) to command decl
rule_name_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)

def copy(self) -> EGraphState:
"""
Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping.
Expand All @@ -102,6 +109,8 @@ def copy(self) -> EGraphState:
cost_callables=self.cost_callables.copy(),
expr_to_let_counter=self.expr_to_let_counter,
unnamed_function_counter=self.unnamed_function_counter,
rule_name_counter=self.rule_name_counter,
rule_name_to_command_decl=self.rule_name_to_command_decl.copy(),
)

def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]:
Expand Down Expand Up @@ -247,6 +256,17 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> CommandDecl | str:
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
"""
Look up the original Python CommandDecl for an egglog rule key.
"""
clean_key = egglog_key.removesuffix("=>").removesuffix("<=")
if clean_key in self.rule_name_to_command_decl:
return self.rule_name_to_command_decl[clean_key]
if egglog_key in self.egg_rule_to_command_decl:
return self.egg_rule_to_command_decl[egglog_key]
return egglog_key
Comment thread
kaeun97 marked this conversation as resolved.
Outdated

def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -283,24 +303,33 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
return bindings.ActionCommand(action_egg)
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
self.type_ref_to_egg(tp)
name = str(self.rule_name_counter)
self.rule_name_counter += 1
Comment thread
kaeun97 marked this conversation as resolved.
self.rule_name_to_command_decl[name] = cmd
rewrite = bindings.Rewrite(
span(),
self._expr_to_egg(lhs),
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
name,
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
egg_cmd: bindings._Command
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)
return egg_cmd
case RuleDecl(head, body, name):
if not name:
name = str(self.rule_name_counter)
self.rule_name_counter += 1
self.rule_name_to_command_decl[name] = cmd
return bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
[self.fact_to_egg(f) for f in body],
name or "",
name,
str(ruleset),
)
)
Expand Down
123 changes: 123 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import timedelta

from . import bindings
from .declarations import CommandDecl, Declarations
from .egraph_state import EGraphState
from .pretty import pretty_decl


def _format_rule_key(decls: Declarations, key: CommandDecl | str) -> str:
if isinstance(key, str):
return key
return pretty_decl(decls, key)


@dataclass
class RuleReport:
plan: bindings.Plan | None
search_and_apply_time: timedelta
num_matches: int

@classmethod
def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
return cls(
plan=report.plan,
search_and_apply_time=report.search_and_apply_time,
num_matches=report.num_matches,
)


@dataclass
class RuleSetReport:
_decls: Declarations = field(repr=False)
changed: bool = False
rule_reports: dict[CommandDecl | str, list[RuleReport]] = field(default_factory=dict)
search_and_apply_time: timedelta = field(default_factory=timedelta)
merge_time: timedelta = field(default_factory=timedelta)

@classmethod
def _from_bindings(
cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
) -> RuleSetReport:
return cls(
_decls=decls,
changed=report.changed,
rule_reports={
translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items()
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
},
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
)

def __repr__(self) -> str:
rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
return (
f"RuleSetReport(changed={self.changed}, "
f"rule_reports={rule_reports_str}, "
f"search_and_apply_time={self.search_and_apply_time}, "
f"merge_time={self.merge_time})"
)


@dataclass
class IterationReport:
rule_set_report: RuleSetReport
rebuild_time: timedelta

@classmethod
def _from_bindings(
cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
) -> IterationReport:
return cls(
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls),
rebuild_time=report.rebuild_time,
)


@dataclass
class RunReport:
"""Python-friendly wrapper around bindings.RunReport."""

_decls: Declarations = field(repr=False)
iterations: list[IterationReport] = field(default_factory=list)
updated: bool = False
search_and_apply_time_per_rule: dict[CommandDecl | str, timedelta] = field(default_factory=dict)
num_matches_per_rule: dict[CommandDecl | str, int] = field(default_factory=dict)
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
search_and_apply_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
merge_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
rebuild_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)

def __repr__(self) -> str:
time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
return (
f"RunReport(iterations={self.iterations}, "
f"updated={self.updated}, "
f"search_and_apply_time_per_rule={time_per_rule}, "
f"num_matches_per_rule={matches_per_rule}, "
f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
)

@classmethod
def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
return cls(
_decls=state.__egg_decls__,
iterations=[
IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__)
for it in report.iterations
],
updated=report.updated,
search_and_apply_time_per_rule={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
},
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
Comment thread
kaeun97 marked this conversation as resolved.
Outdated
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
merge_time_per_ruleset=report.merge_time_per_ruleset,
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
)
Loading
Loading