Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModelInference:
Description:
Subgraphsiostxt_path is a txt file that describes the structure of the model graph and
is used to get input/output node names.The model_path contains paths to multiple onnx files.
The load_sessions function will sort the onnx models in the model_path according to the
The load_sessions function will sort the onnx models in the model_path according to the
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz split this to another PR. this change is not related to your commit message.

order specified in subgraphsiostxt_path.
"""
def __init__(self, model_path, subgraphsiostxt_path):
Expand All @@ -45,6 +45,52 @@ def infer_single_onnx_model(model_file, input_data):
return output_dict


def prepare_initial_input_data(onnx_model_path, default_input_data):
"""
Prepares initial input data for inference.
Args:
onnx_model_path (str): Path to the ONNX model file.
default_input_data (dict): Dictionary containing default input data.
Returns:
dict: Dictionary with user-specified or default shaped and typed input data.
"""
session = ort.InferenceSession(onnx_model_path)
input_info = {input.name: input.shape for input in session.get_inputs()}

initial_input_data = {}
dtype_map = {'f': np.float32, 'i': np.int64}

for input_name, shape in input_info.items():
custom_shape_str = input(
f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: "
)
custom_dtype_str = input(
f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice if you coult split lines so that it comes in 100 cols.

)

if not custom_shape_str:
new_shape = default_input_data[input_name].shape
else:
try:
new_shape = [int(dim) for dim in custom_shape_str.split(',')]
except ValueError:
print("Invalid input, please ensure you enter comma-separated integers.")
continue

if not custom_dtype_str:
dtype = default_input_data[input_name].dtype
else:
dtype = dtype_map.get(custom_dtype_str.strip(), None)
if dtype is None:
print("Invalid data type, please enter 'f' or 'i'.")
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) if dtype is invalid contunue seems to skip input_name input. is it OK with it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thank you for suggestion, I add dtype default setting for exception case, please check


input_data = np.random.rand(*new_shape).astype(dtype)
initial_input_data[input_name] = input_data

return initial_input_data


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-s',
Expand All @@ -68,8 +114,8 @@ def infer_single_onnx_model(model_file, input_data):
default_input_data = {
"x": np.random.rand(1, 3, 256, 256).astype(np.float32),
}

initial_input_data = prepare_initial_input_data(args.single, default_input_data)
# Perform inference using a single ONNX model
output_single = ModelInference.infer_single_onnx_model(args.single,
default_input_data)
initial_input_data)
print("Single model inference completed!")
Loading