forked from Samsung/ONE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsingle_vs_multiple_onnx.py
More file actions
122 lines (106 loc) · 4.88 KB
/
single_vs_multiple_onnx.py
File metadata and controls
122 lines (106 loc) · 4.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnxruntime as ort
import numpy as np
import os
import re
import argparse
class ModelInference:
    """
    Runs inference over split (multiple) ONNX models.

    Parameters:
        model_path: Path to the directory containing the split ONNX model files.
        subgraphsiostxt_path: Path to the txt file that describes the structure
            of the model graph; it is used to get input/output node names and to
            order the ONNX subgraph models found under ``model_path``.

    Description:
        The load_sessions function (not visible in this chunk) sorts the onnx
        models in ``model_path`` according to the order specified in
        ``subgraphsiostxt_path``.
    """
    def __init__(self, model_path, subgraphsiostxt_path):
        # Paths are stored as-is; loading/ordering of the subgraph sessions
        # happens elsewhere (load_sessions, per the class description).
        self.model_path = model_path
        self.subgraphsiostxt_path = subgraphsiostxt_path

    # BUG FIX: the original defined this method without `self` and without
    # @staticmethod.  That only works when it is invoked through the class
    # object (as __main__ does); calling it on an instance would silently
    # pass the instance as `model_file`.  @staticmethod makes both call
    # forms correct and is backward-compatible with existing callers.
    @staticmethod
    def infer_single_onnx_model(model_file, input_data):
        """
        Run one ONNX model end-to-end and return its outputs.

        Args:
            model_file: Path to the ONNX model file.
            input_data: Dict mapping input names to numpy arrays, as
                expected by ``onnxruntime.InferenceSession.run``.

        Returns:
            dict: Maps each model output name to its computed array.
        """
        session = ort.InferenceSession(model_file)
        # run(None, ...) computes every graph output.
        outputs = session.run(None, input_data)
        output_names = [output.name for output in session.get_outputs()]
        return dict(zip(output_names, outputs))
def prepare_initial_input_data(onnx_model_path, default_input_data):
    """
    Prepares initial input data for inference.

    For every input of the model, interactively prompts the user for a shape
    and a dtype; empty or invalid answers fall back to the corresponding
    entry in ``default_input_data``.

    Args:
        onnx_model_path (str): Path to the ONNX model file.
        default_input_data (dict): Maps input name -> default numpy array
            (its shape/dtype are used as the fallback).

    Returns:
        dict: Maps every input name to a randomly generated numpy array with
        the user-specified (or default) shape and dtype.
    """
    session = ort.InferenceSession(onnx_model_path)
    input_info = {input.name: input.shape for input in session.get_inputs()}
    initial_input_data = {}
    dtype_map = {'f': np.float32, 'i': np.int64}
    for input_name, shape in input_info.items():
        custom_shape_str = input(
            f"Enter new shape for input '{input_name}' (comma-separated integers),\
 or press Enter to use default: ")
        custom_dtype_str = input(
            f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64),\
 or press Enter to use default: ")
        if not custom_shape_str:
            new_shape = default_input_data[input_name].shape
        else:
            try:
                new_shape = [int(dim) for dim in custom_shape_str.split(',')]
            except ValueError:
                # BUG FIX: the original `continue`d here, which skipped this
                # input entirely and left it missing from the returned dict
                # (guaranteeing a failure at inference time).  Fall back to
                # the default shape instead.
                print("Invalid input, please ensure you enter comma-separated integers.")
                new_shape = default_input_data[input_name].shape
        if not custom_dtype_str:
            dtype = default_input_data[input_name].dtype
        else:
            dtype = dtype_map.get(custom_dtype_str.strip(), None)
            if dtype is None:
                # BUG FIX: the original set the default dtype here and then
                # `continue`d, skipping data generation for this input even
                # though a valid fallback dtype had just been chosen.
                print("Invalid data type, set as default, please enter 'f' or 'i'.")
                dtype = default_input_data[input_name].dtype
        # Random data is sufficient for a numerical single-vs-multi comparison.
        input_data = np.random.rand(*new_shape).astype(dtype)
        initial_input_data[input_name] = input_data
    return initial_input_data
if __name__ == "__main__":
    # Command-line interface: locations of the single ONNX model, the split
    # subgraph models, and the subgraph i/o description file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-s',
                        '--single',
                        default='./resnet-test.onnx',
                        help="set single ONNX model path")
    parser.add_argument('-m',
                        '--multi',
                        default='./subgraphs/',
                        help="set split subgraph models path")
    parser.add_argument('-n',
                        '--node',
                        default='./scripts/subgraphs_ios.txt',
                        help="set subgraphs node i/o information")
    args = parser.parse_args()

    # Initialize ModelInference instance for inference
    model_inference = ModelInference(args.multi, args.node)

    # Default input data dictionary
    default_input_data = {
        "x": np.random.rand(1, 3, 256, 256).astype(np.float32),
    }
    initial_input_data = prepare_initial_input_data(args.single, default_input_data)

    # Perform inference using a single ONNX model
    output_single = ModelInference.infer_single_onnx_model(args.single,
                                                           initial_input_data)
    print("Single model inference completed!")