Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,54 @@ class ModelInference:
def __init__(self, model_path, subgraphsiostxt_path):
    """Set up inference over a set of split ONNX subgraph models.

    Args:
        model_path: Directory containing the ``*subgraphN.onnx`` files
            (joined with per-subgraph filenames in ``load_sessions``).
        subgraphsiostxt_path: Path to the text file that lists each
            subgraph and its execution order.
    """
    self.model_path = model_path
    self.subgraphsiostxt_path = subgraphsiostxt_path
    # Eagerly parse the order file and create one ONNX Runtime session
    # per subgraph, sorted by execution order.
    self.sessions, self.sorted_file_paths = self.load_sessions()

def load_sessions(self):
    """Parse the subgraph-order file and load an ONNX session per subgraph.

    The order file contains entries matching
    ``<type>subgraph<N>: order<M>`` (e.g. ``CPUsubgraph0: order1``).
    Subgraphs sharing an order value keep their file-order relative to
    each other.

    Returns:
        tuple: ``(sessions, sorted_file_paths)`` where ``sessions`` is a
        list of ``ort.InferenceSession`` objects and ``sorted_file_paths``
        the corresponding ``.onnx`` paths, both sorted by execution order.
    """
    with open(self.subgraphsiostxt_path, 'r') as file:
        content = file.read()

    # Group subgraph file paths by their integer execution order.
    subgraph_order_map = {}
    for subgraph_type, subgraph_number, order in re.findall(
            r'(\w+)subgraph(\d+): order(\d+)', content):
        file_path = os.path.join(
            self.model_path,
            f"{subgraph_type}subgraph{subgraph_number}.onnx")
        subgraph_order_map.setdefault(int(order), []).append(file_path)

    # Flatten into one list ordered by execution order.
    sorted_file_paths = [
        path
        for order in sorted(subgraph_order_map)
        for path in subgraph_order_map[order]
    ]

    sessions = [ort.InferenceSession(model) for model in sorted_file_paths]
    return sessions, sorted_file_paths

def infer_multiple_onnx_models(self,
initial_input_data,
output_names_to_collect=None):
if output_names_to_collect is None:
return {}
input_data = initial_input_data
collected_outputs = {}

for i, (session,
model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)):
input_names = [inp.name for inp in session.get_inputs()]
output_names = [out.name for out in session.get_outputs()]
model_input_data = {name: input_data[name] for name in input_names}
outputs = session.run(None, model_input_data)
current_model_outputs = dict(zip(output_names, outputs))

for output_name in output_names_to_collect:
if output_name in current_model_outputs:
collected_outputs[output_name] = current_model_outputs[output_name]

if i < len(self.sessions) - 1:
input_data.update(current_model_outputs)
return collected_outputs

def infer_single_onnx_model(model_file, input_data):
session = ort.InferenceSession(model_file)
Expand Down Expand Up @@ -120,3 +168,11 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):
# Run the original (unsplit) model once; its outputs serve as the
# reference for the subgraph pipeline below.
output_single = ModelInference.infer_single_onnx_model(args.single,
                                                       initial_input_data)
print("Single model inference completed!")

# Retrieve all output names from the single model
output_names_list = list(output_single.keys())

# Perform inference using multiple split subgraph models, collecting the
# same output names so the two results can be compared.
output_multiple = model_inference.infer_multiple_onnx_models(
    initial_input_data, output_names_list)
print("Multiple subgraph inference completed!")
Copy link
Copy Markdown
Contributor

@seanshpark seanshpark Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) Is infer_single_onnx_model for the "source" ONNX model and infer_multiple_onnx_models for our "target" split models?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we use this code to verify the split models' inference results, then compare the output data and evaluate the accuracy against the "source" ONNX model; the comparison code will come in the next PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, yes — I have moved the parameter exception checking out of the loop now, thank you :)