Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModelInference:
Description:
Subgraphsiostxt_path is a txt file that describes the structure of the model graph and
is used to get input/output node names.The model_path contains paths to multiple onnx files.
The load_sessions function will sort the onnx models in the model_path according to the
The load_sessions function will sort the onnx models in the model_path according to the
order specified in subgraphsiostxt_path.
"""
def __init__(self, model_path, subgraphsiostxt_path):
Expand All @@ -45,6 +45,52 @@ def infer_single_onnx_model(model_file, input_data):
return output_dict


def prepare_initial_input_data(onnx_model_path, default_input_data):
"""
Prepares initial input data for inference.
Args:
onnx_model_path (str): Path to the ONNX model file.
default_input_data (dict): Dictionary containing default input data.
Returns:
dict: Dictionary with user-specified or default shaped and typed input data.
"""
session = ort.InferenceSession(onnx_model_path)
input_info = {input.name: input.shape for input in session.get_inputs()}

initial_input_data = {}
dtype_map = {'f': np.float32, 'i': np.int64}

for input_name, shape in input_info.items():
custom_shape_str = input(
f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: "
)
custom_dtype_str = input(
f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: "
)

if not custom_shape_str:
new_shape = default_input_data[input_name].shape
else:
try:
new_shape = [int(dim) for dim in custom_shape_str.split(',')]
except ValueError:
print("Invalid input, please ensure you enter comma-separated integers.")
continue

if not custom_dtype_str:
dtype = default_input_data[input_name].dtype
else:
dtype = dtype_map.get(custom_dtype_str.strip(), None)
if dtype is None:
print("Invalid data type, please enter 'f' or 'i'.")
continue

input_data = np.random.rand(*new_shape).astype(dtype)
initial_input_data[input_name] = input_data

return initial_input_data


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-s',
Expand All @@ -68,8 +114,9 @@ def infer_single_onnx_model(model_file, input_data):
default_input_data = {
"x": np.random.rand(1, 3, 256, 256).astype(np.float32),
}

initial_input_data = prepare_initial_input_data(args.single,
default_input_data)
# Perform inference using a single ONNX model
output_single = ModelInference.infer_single_onnx_model(args.single,
default_input_data)
initial_input_data)
print("Single model inference completed!")
Loading