diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 7642c944b..bf6a6e5fd 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -118,7 +118,12 @@ pub(crate) fn desugar_command( res } Command::Rewrite(ruleset, rewrite, subsume) => { - desugar_rewrite(ruleset, rule_name, rewrite, subsume, parser) + let resolved_name = if rewrite.name.is_empty() { + rule_name + } else { + rewrite.name.clone() + }; + desugar_rewrite(ruleset, resolved_name, rewrite, subsume, parser) } Command::BiRewrite(ruleset, rewrite) => { desugar_birewrite(ruleset, rule_name, rewrite, parser) @@ -352,22 +357,34 @@ fn desugar_birewrite( parser: &mut Parser, ) -> Vec { let span = rewrite.span.clone(); + let rewrite_name = if rewrite.name.is_empty() { + name + } else { + rewrite.name.clone() + }; let rw2 = Rewrite { span, lhs: rewrite.rhs.clone(), rhs: rewrite.lhs.clone(), conditions: rewrite.conditions.clone(), + name: rewrite_name.clone(), }; - desugar_rewrite(ruleset.clone(), format!("{name}=>"), rewrite, false, parser) - .into_iter() - .chain(desugar_rewrite( - ruleset, - format!("{name}<="), - rw2, - false, - parser, - )) - .collect() + desugar_rewrite( + ruleset.clone(), + format!("{rewrite_name}=>"), + rewrite, + false, + parser, + ) + .into_iter() + .chain(desugar_rewrite( + ruleset, + format!("{rewrite_name}<="), + rw2, + false, + parser, + )) + .collect() } /// Desugar relation by making a new sort and a constructor for it. diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 26370ef41..4257ee409 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1480,6 +1480,7 @@ pub struct GenericRewrite { pub lhs: GenericExpr, pub rhs: GenericExpr, pub conditions: Vec>, + pub name: String, } impl GenericRewrite @@ -1506,6 +1507,7 @@ where .into_iter() .map(|fact| fact.map_symbols(head, leaf)) .collect(), + name: self.name, } } @@ -1808,6 +1810,7 @@ where .into_iter() .map(|fact| fact.visit_exprs(f)) .collect(), + name: rewrite.name, }, subsume, ), @@ -1822,6 +1825,7 @@ where .into_iter() .map(|fact| fact.visit_exprs(f)) .collect(), + name: rewrite.name, }, ), GenericCommand::Action(action) => GenericCommand::Action(action.visit_exprs(f)), diff --git a/src/ast/parse.rs b/src/ast/parse.rs index 3386bdac7..cfe9d24fe 100644 --- a/src/ast/parse.rs +++ b/src/ast/parse.rs @@ -503,6 +503,7 @@ impl Parser { let mut ruleset = String::new(); let mut conditions = Vec::new(); let mut subsume = false; + let mut name = String::new(); for option in self.parse_options(rest)? { match option { (":ruleset", [r]) => ruleset = r.expect_atom("ruleset name")?, @@ -514,6 +515,7 @@ impl Parser { Self::parse_fact, )? } + (":name", [s]) => name = s.expect_string("rule name")?, _ => return error!(span, "could not parse rewrite options"), } } @@ -525,6 +527,7 @@ impl Parser { lhs, rhs, conditions, + name, }, subsume, )] @@ -538,6 +541,7 @@ impl Parser { let mut ruleset = String::new(); let mut conditions = Vec::new(); + let mut name = String::new(); for option in self.parse_options(rest)? { match option { (":ruleset", [r]) => ruleset = r.expect_atom("ruleset name")?, @@ -548,6 +552,7 @@ impl Parser { Self::parse_fact, )? } + (":name", [s]) => name = s.expect_string("rule name")?, _ => return error!(span, "could not parse birewrite options"), } } @@ -559,6 +564,7 @@ impl Parser { lhs, rhs, conditions, + name, }, )] } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 00939a27f..1dfb00f88 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -995,3 +995,106 @@ fn eqsat_basic_term_encoding_roundtrip() { .parse_and_run_program(None, &text_thrice) .expect("final program should execute successfully"); } + +#[test] +fn rewrite_name_basic() { + let mut egraph = EGraph::default(); + egraph + .parse_and_run_program( + None, + " + (datatype Math (Num i64) (Mul Math Math)) + (rewrite (Mul a (Num 1)) a :name \"mul-identity\") + (let $x (Mul (Num 42) (Num 1))) + (run 3) + (check (= $x (Num 42))) + ", + ) + .unwrap(); +} + +#[test] +fn rewrite_name_birewrite() { + let mut egraph = EGraph::default(); + egraph + .parse_and_run_program( + None, + " + (datatype Math (Num i64) (Add Math Math)) + (birewrite (Add a b) (Add b a) :name \"add-comm\") + (let $x (Add (Num 1) (Num 2))) + (run 3) + (check (= $x (Add (Num 2) (Num 1)))) + ", + ) + .unwrap(); +} + +#[test] +fn rewrite_name_desugars_correctly() { + let mut egraph = EGraph::default(); + let desugared = egraph + .resolve_program( + None, + " + (datatype Math (Num i64) (Mul Math Math) (Add Math Math)) + (rewrite (Mul a (Num 1)) a :name \"mul-identity\") + (birewrite (Add a b) (Add b a) :name \"add-comm\") + ", + ) + .unwrap(); + + let joined: String = desugared + .iter() + .map(|cmd| format!("{cmd}")) + .collect::>() + .join("\n"); + + assert!( + joined.contains("mul-identity"), + "expected 'mul-identity' in:\n{joined}" + ); + assert!( + joined.contains("add-comm=>"), + "expected 'add-comm=>' in:\n{joined}" + ); + assert!( + joined.contains("add-comm<="), + "expected 'add-comm<=' in:\n{joined}" + ); +} + +#[test] +fn rewrite_without_name_still_works() { + let mut egraph = EGraph::default(); + egraph + .parse_and_run_program( + None, + " + (datatype Math (Num i64) (Add Math Math)) + (rewrite (Add a b) (Add b a)) + (let $x (Add (Num 1) (Num 2))) + (run 3) + (check (= $x (Add (Num 2) (Num 1)))) + ", + ) + .unwrap(); +} + +#[test] +fn rewrite_name_with_ruleset() { + let mut egraph = EGraph::default(); + egraph + .parse_and_run_program( + None, + " + (datatype Math (Num i64) (Mul Math Math)) + (ruleset my-rules) + (rewrite (Mul a (Num 0)) (Num 0) :name \"mul-zero\" :ruleset my-rules) + (let $x (Mul (Num 99) (Num 0))) + (run my-rules 3) + (check (= $x (Num 0))) + ", + ) + .unwrap(); +}