Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):
return initial_input_data


def compare_results(output_single, output_multiple):
"""
Compares the Mean Squared Error (MSE) between identically named outputs from
two inference result dictionaries.Ensures each output name is processed only once.
"""
all_keys = set(output_single.keys()).union(set(output_multiple.keys()))
for key in sorted(all_keys):
if key in output_single and key in output_multiple:
single_output = np.array(output_single[key])
multiple_output = np.array(output_multiple[key])
mse = np.mean((single_output - multiple_output)**2)
print(f"Output '{key}' MSE: {mse}")
else:
print(f"Output '{key}' is missing in one of the result sets.")


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-s',
Expand Down Expand Up @@ -176,3 +192,7 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):
output_multiple = model_inference.infer_multiple_onnx_models(
initial_input_data, output_names_list)
print("Multiple subgraph inference completed!")

print("Comparing inference results between single ONNX model \
and multiple subgraphs...")
compare_results(output_single, output_multiple)
Loading