class ModelInference:
    """Run a model that was split into several ONNX subgraph files.

    The subgraph files are located from a text listing that assigns each
    subgraph an integer order; sessions are created once, at construction
    time, and then reused by ``infer_multiple_onnx_models``.
    """

    def __init__(self, model_path, subgraphsiostxt_path):
        """Remember both paths and eagerly create the subgraph sessions.

        Args:
            model_path: Directory that contains the ``*subgraph<N>.onnx``
                files.
            subgraphsiostxt_path: Text file containing entries of the form
                ``<type>subgraph<N>: order<M>``.
        """
        self.model_path = model_path
        self.subgraphsiostxt_path = subgraphsiostxt_path
        self.sessions, self.sorted_file_paths = self.load_sessions()

    def load_sessions(self):
        """Create one inference session per subgraph, sorted by order.

        Every ``<type>subgraph<N>: order<M>`` entry found in
        ``self.subgraphsiostxt_path`` is mapped to
        ``<model_path>/<type>subgraph<N>.onnx``. Files are sorted by their
        integer order; files sharing an order keep the sequence in which
        they appear in the listing.

        Returns:
            A ``(sessions, sorted_file_paths)`` tuple of two parallel lists.
        """
        with open(self.subgraphsiostxt_path, 'r') as file:
            content = file.read()

        # Group file paths by integer order; several subgraphs may share one.
        paths_by_order = {}
        entries = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content)
        for subgraph_type, subgraph_number, order in entries:
            file_path = os.path.join(
                self.model_path,
                f"{subgraph_type}subgraph{subgraph_number}.onnx")
            paths_by_order.setdefault(int(order), []).append(file_path)

        sorted_file_paths = [
            path
            for order in sorted(paths_by_order)
            for path in paths_by_order[order]
        ]

        sessions = [ort.InferenceSession(model) for model in sorted_file_paths]
        return sessions, sorted_file_paths

    def infer_multiple_onnx_models(self,
                                   initial_input_data,
                                   output_names_to_collect=None):
        """Run the subgraph sessions in order, chaining outputs to inputs.

        Each session is fed the values whose names match its declared
        inputs. The outputs of every session except the last are merged
        into the pool of values available to the sessions that follow.

        Args:
            initial_input_data: Mapping from input name to the value fed to
                the sessions. It is not modified.
            output_names_to_collect: Names of the outputs to return. When
                ``None`` nothing is run and an empty dict is returned.

        Returns:
            Dict mapping each requested name that some subgraph produced to
            its value. If several subgraphs produce the same name, the
            value from the latest one wins.

        Raises:
            KeyError: If a subgraph declares an input that is neither in
                ``initial_input_data`` nor produced by an earlier subgraph.
        """
        if output_names_to_collect is None:
            return {}

        # Shallow copy: intermediate outputs are merged into this pool and
        # must not leak into the dict owned by the caller.
        input_data = dict(initial_input_data)
        collected_outputs = {}

        for i, (session,
                model_file) in enumerate(zip(self.sessions,
                                             self.sorted_file_paths)):
            input_names = [inp.name for inp in session.get_inputs()]
            output_names = [out.name for out in session.get_outputs()]

            missing = [name for name in input_names if name not in input_data]
            if missing:
                raise KeyError(
                    f"missing input(s) {missing} for subgraph {model_file}")

            model_input_data = {name: input_data[name] for name in input_names}
            outputs = session.run(None, model_input_data)
            current_model_outputs = dict(zip(output_names, outputs))

            for output_name in output_names_to_collect:
                if output_name in current_model_outputs:
                    collected_outputs[output_name] = current_model_outputs[
                        output_name]

            # Nothing runs after the final session, so its outputs need not
            # be merged back into the pool.
            if i < len(self.sessions) - 1:
                input_data.update(current_model_outputs)

        return collected_outputs

    # NOTE(review): the remaining members of this class (starting with
    # infer_single_onnx_model) lie outside the visible hunk and are not
    # reproduced here.