[tools/onnx-subgraph] add multi subgraphs inference code #14769
Changes from 1 commit
@@ -36,6 +36,53 @@ class ModelInference:

```python
  def __init__(self, model_path, subgraphsiostxt_path):
    self.model_path = model_path
    self.subgraphsiostxt_path = subgraphsiostxt_path
    self.sessions, self.sorted_file_paths = self.load_sessions()

  def load_sessions(self):
    # Parse the subgraphs-IOs description file to find the subgraph
    # .onnx files and the order in which they must run.
    with open(self.subgraphsiostxt_path, 'r') as file:
      content = file.read()
    subgraph_order_map = {}
    matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content)

    for match in matches:
      subgraph_type, subgraph_number, order = match
      file_path = os.path.join(self.model_path,
                               f"{subgraph_type}subgraph{subgraph_number}.onnx")
      # Group subgraph files that share the same execution order.
      if int(order) in subgraph_order_map:
        subgraph_order_map[int(order)].append(file_path)
      else:
        subgraph_order_map[int(order)] = [file_path]

    # Flatten into a single list sorted by execution order.
    sorted_file_paths = []
    for order in sorted(subgraph_order_map.keys()):
      sorted_file_paths.extend(subgraph_order_map[order])

    sessions = [ort.InferenceSession(model) for model in sorted_file_paths]
    return sessions, sorted_file_paths
```
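For context, `load_sessions` only picks up lines in the subgraphs-IOs text file that match `(\w+)subgraph(\d+): order(\d+)`. A quick illustration of the parsing (the type names, numbering, and file content below are assumptions for the example, not taken from this PR):

```python
import re

# Hypothetical subgraphs_ios.txt content (illustrative only):
content = """CPUsubgraph0: order0
NPUsubgraph0: order1
CPUsubgraph1: order2"""

print(re.findall(r'(\w+)subgraph(\d+): order(\d+)', content))
# -> [('CPU', '0', '0'), ('NPU', '0', '1'), ('CPU', '1', '2')]
# load_sessions() would then open CPUsubgraph0.onnx, NPUsubgraph0.onnx,
# and CPUsubgraph1.onnx under model_path, in that order.
```

The diff continues with the chained-inference method: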
```python
  def infer_multiple_onnx_models(self,
                                 initial_input_data,
                                 output_names_to_collect=None):
    input_data = initial_input_data
    collected_outputs = {}

    for i, (session,
            model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)):
      input_names = [inp.name for inp in session.get_inputs()]
      output_names = [out.name for out in session.get_outputs()]
      # Feed each subgraph only the tensors it declares as inputs.
      model_input_data = {name: input_data[name] for name in input_names}
      outputs = session.run(None, model_input_data)
      current_model_outputs = dict(zip(output_names, outputs))
      if output_names_to_collect is not None:
        for output_name in output_names_to_collect:
          if output_name in current_model_outputs:
            collected_outputs[output_name] = current_model_outputs[output_name]

      # Make this subgraph's outputs available to later subgraphs.
      if i < len(self.sessions) - 1:
        input_data.update(current_model_outputs)
    return collected_outputs

  # Called on the class below (not on an instance), so model_file binds
  # to the first argument and this behaves like a static method.
  def infer_single_onnx_model(model_file, input_data):
    session = ort.InferenceSession(model_file)
```
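To make the chaining explicit: after each subgraph runs, its outputs are merged into `input_data`, so any later subgraph whose declared inputs use those names picks them up automatically. A walkthrough under assumed tensor names (not from this PR):

```python
# Assumed three-subgraph chain (tensor names hypothetical):
#   subgraph 1: inputs {"x"}      -> outputs {"a"}
#   subgraph 2: inputs {"a"}      -> outputs {"b"}
#   subgraph 3: inputs {"x", "b"} -> outputs {"y"}
#
# Iteration 1: model_input_data = {"x": ...}; input_data gains "a"
# Iteration 2: model_input_data = {"a": ...}; input_data gains "b"
# Iteration 3: model_input_data = {"x": ..., "b": ...}
#
# With output_names_to_collect=["y"], the method returns {"y": <tensor>}.
```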
@@ -116,7 +163,16 @@ def prepare_initial_input_data(onnx_model_path, default_input_data):

```python
    "x": np.random.rand(1, 3, 256, 256).astype(np.float32),
}
initial_input_data = prepare_initial_input_data(args.single, default_input_data)

# Perform inference using a single ONNX model
output_single = ModelInference.infer_single_onnx_model(args.single,
                                                       initial_input_data)
print("Single model inference completed!")

# Retrieve all output names from the single model
output_names_list = list(output_single.keys())

# Perform inference using multiple split subgraph models
output_multiple = model_inference.infer_multiple_onnx_models(
    initial_input_data, output_names_list)
print("Multiple subgraph inference completed!")
```
Contributor
Q) is
Contributor (Author)
Yes, we use this code to verify the split models' inference results, then compare the output data and evaluate the accuracy against the "source" ONNX model. The comparing code will come in the next PR.
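For readers wondering what that comparison could look like, a minimal sketch building on `output_single` and `output_multiple` from the script above (`np.allclose` and the tolerances are assumptions; the actual code lands in the follow-up PR):

```python
import numpy as np

# Hypothetical accuracy check between the unsplit model's outputs and
# the chained subgraph outputs; tolerances are illustrative.
for name in output_single:
  ok = np.allclose(output_single[name], output_multiple[name],
                   rtol=1e-5, atol=1e-6)
  print(f"{name}: {'OK' if ok else 'MISMATCH'}")
```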
Contributor (Author)
Oh, yes, I've moved the parameter exception checking out of the loop now, thank you :)
We can hoist this `if` line above the `for` loop. When `output_names_to_collect` is `None`, the return is `{}`.
The updated change doesn't reflect my comment correctly. Please read again.
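A sketch of the restructuring the reviewer appears to be asking for (illustrative, not the PR's final code): hoisting the check gives an explicit early exit and removes the per-iteration test.

```python
  def infer_multiple_onnx_models(self,
                                 initial_input_data,
                                 output_names_to_collect=None):
    # Hoisted out of the for loop: with nothing to collect, the
    # result is always {}.
    if output_names_to_collect is None:
      return {}

    input_data = initial_input_data
    collected_outputs = {}
    for i, (session, model_file) in enumerate(
        zip(self.sessions, self.sorted_file_paths)):
      input_names = [inp.name for inp in session.get_inputs()]
      output_names = [out.name for out in session.get_outputs()]
      model_input_data = {name: input_data[name] for name in input_names}
      outputs = session.run(None, model_input_data)
      current_model_outputs = dict(zip(output_names, outputs))
      # No per-iteration None check needed here anymore.
      for output_name in output_names_to_collect:
        if output_name in current_model_outputs:
          collected_outputs[output_name] = current_model_outputs[output_name]
      if i < len(self.sessions) - 1:
        input_data.update(current_model_outputs)
    return collected_outputs
```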