Skip to content

Commit bdbcbb5

Browse files
authored
[tools/onnx-subgraph] mse comparing of src model and multi sub models (#14793)
mse comparing of src model and multi sub models ONE-DCO-1.0-Signed-off-by: Youxin Chen <[email protected]>
1 parent c83857a commit bdbcbb5

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tools/onnx_subgraph/single_vs_multiple_onnx.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,22 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):
140140
return initial_input_data
141141

142142

143+
def compare_results(output_single, output_multiple):
144+
"""
145+
Compares the Mean Squared Error (MSE) between identically named outputs from
146+
two inference result dictionaries.Ensures each output name is processed only once.
147+
"""
148+
all_keys = set(output_single.keys()).union(set(output_multiple.keys()))
149+
for key in sorted(all_keys):
150+
if key in output_single and key in output_multiple:
151+
single_output = np.array(output_single[key])
152+
multiple_output = np.array(output_multiple[key])
153+
mse = np.mean((single_output - multiple_output)**2)
154+
print(f"Output '{key}' MSE: {mse}")
155+
else:
156+
print(f"Output '{key}' is missing in one of the result sets.")
157+
158+
143159
if __name__ == "__main__":
144160
arg_parser = argparse.ArgumentParser()
145161
arg_parser.add_argument('-s',
@@ -176,3 +192,7 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):
176192
output_multiple = model_inference.infer_multiple_onnx_models(
177193
initial_input_data, output_names_list)
178194
print("Multiple subgraph inference completed!")
195+
196+
print("Comparing inference results between single ONNX model \
197+
and multiple subgraphs...")
198+
compare_results(output_single, output_multiple)

0 commit comments

Comments
 (0)