diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index da1a286cec..9594fa4f0a 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -22,6 +22,7 @@ import onnx import onnx.reference.ops import onnx_ir as ir +import onnx_ir.passes.common as ir_passes_common import onnxscript.utils.utils as utils from onnxscript._internal.tape_builder import BuilderBase, TapeBuilder @@ -1394,6 +1395,11 @@ def call(self, model: ir.Model) -> FoldConstantsResult: for function in model.functions.values(): # TODO(rama): Should we specialize functions? self.visit_function(function) + if self._modified: + # TapeBuilder may create values with names that clash with existing graph + # values when nodes are inserted via replace_nodes_and_values. + # NameFixPass ensures all value names are unique before returning. + ir_passes_common.NameFixPass()(model) return FoldConstantsResult(model, self._modified, self._state.symbolic_value_map) diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 60e1284066..7080fdb753 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -797,5 +797,61 @@ def test_initializer_as_graph_output_is_not_removed(self): self.assertIn("z", output_names) +def _all_value_names_unique(model: ir.Model) -> bool: + """Return True if all named values in the top-level graph have unique names.""" + names = [] + for v in model.graph.inputs: + if v.name: + names.append(v.name) + for v in model.graph.initializers.values(): + if v.name: + names.append(v.name) + for node in model.graph: + for output in node.outputs: + if output.name: + names.append(output.name) + return len(names) == len(set(names)) + + +class NameClashAfterFoldTest(unittest.TestCase): + """Tests that fold_constants calls NameFixPass to deduplicate value names. + + TapeBuilder may assign names that collide with existing graph values when + new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked + by FoldConstantsPass.call when the model was modified, resolves the + duplicates. + """ + + def test_fold_constants_deduplicates_names(self): + """Duplicate value names present alongside a constant-fold are fixed.""" + model = ir.from_onnx_text( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + extra = Relu(x) + z = Mul(extra, four) + } + """ + ) + + # Simulate the name clash that TapeBuilder can introduce: 'extra' (a + # non-folded node that survives) is given the same name as 'four' (the + # folded Add output) because NameAuthority does not check for conflicts + # when registering pre-named values inserted by TapeBuilder. + four_node = next(n for n in model.graph if n.op_type == "Add") + extra_node = next(n for n in model.graph if n.op_type == "Relu") + extra_node.outputs[0].name = four_node.outputs[0].name # inject clash + + result = _constant_folding.fold_constants(model) + + self.assertTrue(result.modified, "Folding must have modified the model") + self.assertTrue( + _all_value_names_unique(model), + "All value names must be unique after fold_constants", + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 536c2c7117..d1cdf7c5dd 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -13,6 +13,8 @@ TypeVar, ) +import onnx_ir.passes.common as ir_passes_common + import onnxscript.optimizer import onnxscript.rewriter._basics as _basics import onnxscript.rewriter._context as _context @@ -835,6 +837,11 @@ def apply_to_model( ) if self.remove_unused_nodes: onnxscript.optimizer.remove_unused_nodes(model) + if count > 0: + # TapeBuilder may create values with names that clash with existing graph + # values when nodes are inserted via replace_nodes_and_values. + # NameFixPass ensures all value names are unique before returning. + ir_passes_common.NameFixPass()(model) return count def __iter__(self): diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index f296b5320c..b27e34f0e8 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -989,5 +989,70 @@ def test_pattern_builder_context(self): self.assertEqual(ops, ["Op1", "Op2", "Add", "Op3", "Mul"]) +def _all_value_names_unique(model: ir.Model) -> bool: + """Return True if all named values in the top-level graph have unique names.""" + names = [] + for v in model.graph.inputs: + if v.name: + names.append(v.name) + for v in model.graph.initializers.values(): + if v.name: + names.append(v.name) + for node in model.graph: + for output in node.outputs: + if output.name: + names.append(output.name) + return len(names) == len(set(names)) + + +class NameClashAfterRewriteTest(unittest.TestCase): + """Tests that apply_to_model calls NameFixPass to deduplicate value names. + + TapeBuilder may assign names that collide with existing graph values when + new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked + by apply_to_model when at least one rewrite fires, resolves the duplicates. + """ + + def test_apply_to_model_deduplicates_names(self): + """Duplicate value names introduced alongside a rewrite are fixed.""" + model_proto = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y, float[N] p) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + extra = Add(z1, p) + z = Identity(extra) + } + """ + ) + model = ir.serde.deserialize_model(model_proto) + + # Simulate the name clash that TapeBuilder can introduce: two values that + # survive the rewrite (not part of the matched pattern) end up sharing a + # name because NameAuthority does not check for conflicts when registering + # pre-named values from TapeBuilder. + extra_node = next(n for n in model.graph if n.op_type == "Add") + identity_node = next(n for n in model.graph if n.op_type == "Identity") + identity_node.outputs[0].name = extra_node.outputs[0].name # inject clash + + def reciprocal_mul_pattern(op, x, y): + return (1 / x) * y + + def div_replacement(op, x, y): + return op.Div(y, x) + + rule = pattern.RewriteRule(reciprocal_mul_pattern, div_replacement) + count = rule.apply_to_model(model) + + self.assertGreater(count, 0, "Rewrite rule must have fired to exercise the fix") + self.assertTrue( + _all_value_names_unique(model), + "All value names must be unique after apply_to_model", + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 05830d47b4..99e30417d4 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -10,6 +10,7 @@ from typing import Callable, Sequence, Union import onnx_ir.convenience as ir_convenience +import onnx_ir.passes.common as ir_passes_common import onnxscript.utils.metadata_merger as metadata_merger from onnxscript import ir @@ -239,6 +240,7 @@ def groupnormalization_20_21(node: ir.Node, op): class _VersionConverter: def __init__(self, target_version: int): self._target_version = target_version + self._modified: bool = False # Default metadata merger: no merging should be needed; keep the first value. self._default_metadata_merger: metadata_merger.MetadataMerger = ( metadata_merger.MetadataMerger( @@ -269,6 +271,7 @@ def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) ir_convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) + self._modified = True def visit_attribute(self, attr: ir.Attr) -> None: if attr.is_ref(): @@ -341,6 +344,11 @@ def visit_model(self, model: ir.Model) -> None: self.visit_graph_or_function(function) _set_onnx_opset_version(function, self._target_version) _set_onnx_opset_version(model, self._target_version) + if self._modified: + # TapeBuilder may create values with names that clash with existing graph + # values when nodes are inserted via replace_nodes_and_values. + # NameFixPass ensures all value names are unique before returning. + ir_passes_common.NameFixPass()(model) def convert_version(model: ir.Model, target_version: int) -> None: diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index a37e8e262f..2635635557 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -538,5 +538,65 @@ def test_version_convert_compatible(self): version_converter.convert_version(model, target_version=target_version) +def _all_value_names_unique(model: ir.Model) -> bool: + """Return True if all named values in the top-level graph have unique names.""" + names = [] + for v in model.graph.inputs: + if v.name: + names.append(v.name) + for v in model.graph.initializers.values(): + if v.name: + names.append(v.name) + for node in model.graph: + for output in node.outputs: + if output.name: + names.append(output.name) + return len(names) == len(set(names)) + + +class NameClashAfterConversionTest(unittest.TestCase): + """Tests that convert_version calls NameFixPass to deduplicate value names. + + TapeBuilder may assign names that collide with existing graph values when + new nodes are inserted via replace_nodes_and_values. NameFixPass, invoked + inside _VersionConverter.visit_model when nodes were modified, resolves + the duplicates. + """ + + def test_convert_version_deduplicates_names(self): + """Duplicate value names present after conversion are fixed by NameFixPass.""" + model = ir.from_onnx_text( + """ + + agraph (float[4, 512, 512] input_x, float[4, 1024, 1024] input_y) => (float[4, 1024, 1024] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_x, shape_b) + gridsample = GridSample (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (gridsample, shape_c) + } + """ + ) + + # Simulate the name clash that TapeBuilder can introduce: two Constant + # node outputs (not touched by the GridSample adapter) receive the same + # name because NameAuthority does not check for conflicts when registering + # pre-named values inserted by TapeBuilder. + shape_a_output = model.graph.node(0).outputs[0] + shape_c_output = model.graph.node(5).outputs[0] + shape_c_output.name = shape_a_output.name # inject clash + + version_converter.convert_version(model, target_version=20) + + self.assertEqual(model.opset_imports[""], 20) + self.assertTrue( + _all_value_names_unique(model), + "All value names must be unique after convert_version", + ) + + if __name__ == "__main__": unittest.main()