Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,53 @@ def infer_single_onnx_model(model_file, input_data):
return output_dict


def prepare_initial_input_data(onnx_model_path, default_input_data):
"""
Prepares initial input data for inference.
Args:
onnx_model_path (str): Path to the ONNX model file.
default_input_data (dict): Dictionary containing default input data.
Returns:
dict: Dictionary with user-specified or default shaped and typed input data.
"""
session = ort.InferenceSession(onnx_model_path)
input_info = {input.name: input.shape for input in session.get_inputs()}

initial_input_data = {}
dtype_map = {'f': np.float32, 'i': np.int64}

for input_name, shape in input_info.items():
custom_shape_str = input(
f"Enter new shape for input '{input_name}' (comma-separated integers),\
or press Enter to use default: ")
custom_dtype_str = input(
f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64),\
or press Enter to use default: ")

if not custom_shape_str:
new_shape = default_input_data[input_name].shape
else:
try:
new_shape = [int(dim) for dim in custom_shape_str.split(',')]
except ValueError:
print("Invalid input, please ensure you enter comma-separated integers.")
continue

if not custom_dtype_str:
dtype = default_input_data[input_name].dtype
else:
dtype = dtype_map.get(custom_dtype_str.strip(), None)
if dtype is None:
print("Invalid data type, set as default, please enter 'f' or 'i'.")
dtype = default_input_data[input_name].dtype
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) if dtype is invalid contunue seems to skip input_name input. is it OK with it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thank you for suggestion, I add dtype default setting for exception case, please check


input_data = np.random.rand(*new_shape).astype(dtype)
initial_input_data[input_name] = input_data

return initial_input_data


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-s',
Expand All @@ -68,8 +115,8 @@ def infer_single_onnx_model(model_file, input_data):
default_input_data = {
"x": np.random.rand(1, 3, 256, 256).astype(np.float32),
}

initial_input_data = prepare_initial_input_data(args.single, default_input_data)
# Perform inference using a single ONNX model
output_single = ModelInference.infer_single_onnx_model(args.single,
default_input_data)
initial_input_data)
print("Single model inference completed!")
Loading