-
Notifications
You must be signed in to change notification settings - Fork 178
[tools/onnx-subgraph] support arbitrary input shape for onnx #14757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ class ModelInference: | |
| Description: | ||
| Subgraphsiostxt_path is a txt file that describes the structure of the model graph and | ||
| is used to get input/output node names.The model_path contains paths to multiple onnx files. | ||
| The load_sessions function will sort the onnx models in the model_path according to the | ||
| The load_sessions function will sort the onnx models in the model_path according to the | ||
| order specified in subgraphsiostxt_path. | ||
| """ | ||
| def __init__(self, model_path, subgraphsiostxt_path): | ||
|
|
@@ -45,6 +45,52 @@ def infer_single_onnx_model(model_file, input_data): | |
| return output_dict | ||
|
|
||
|
|
||
| def prepare_initial_input_data(onnx_model_path, default_input_data): | ||
| """ | ||
| Prepares initial input data for inference. | ||
| Args: | ||
| onnx_model_path (str): Path to the ONNX model file. | ||
| default_input_data (dict): Dictionary containing default input data. | ||
| Returns: | ||
| dict: Dictionary with user-specified or default shaped and typed input data. | ||
| """ | ||
| session = ort.InferenceSession(onnx_model_path) | ||
| input_info = {input.name: input.shape for input in session.get_inputs()} | ||
|
|
||
| initial_input_data = {} | ||
| dtype_map = {'f': np.float32, 'i': np.int64} | ||
|
|
||
| for input_name, shape in input_info.items(): | ||
| custom_shape_str = input( | ||
| f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: " | ||
| ) | ||
| custom_dtype_str = input( | ||
| f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: " | ||
|
||
| ) | ||
|
|
||
| if not custom_shape_str: | ||
| new_shape = default_input_data[input_name].shape | ||
| else: | ||
| try: | ||
| new_shape = [int(dim) for dim in custom_shape_str.split(',')] | ||
| except ValueError: | ||
| print("Invalid input, please ensure you enter comma-separated integers.") | ||
| continue | ||
|
|
||
| if not custom_dtype_str: | ||
| dtype = default_input_data[input_name].dtype | ||
| else: | ||
| dtype = dtype_map.get(custom_dtype_str.strip(), None) | ||
| if dtype is None: | ||
| print("Invalid data type, please enter 'f' or 'i'.") | ||
| continue | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q) if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, thank you for suggestion, I add dtype default setting for exception case, please check |
||
|
|
||
| input_data = np.random.rand(*new_shape).astype(dtype) | ||
| initial_input_data[input_name] = input_data | ||
|
|
||
| return initial_input_data | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| arg_parser = argparse.ArgumentParser() | ||
| arg_parser.add_argument('-s', | ||
|
|
@@ -68,8 +114,8 @@ def infer_single_onnx_model(model_file, input_data): | |
| default_input_data = { | ||
| "x": np.random.rand(1, 3, 256, 256).astype(np.float32), | ||
| } | ||
|
|
||
| initial_input_data = prepare_initial_input_data(args.single, default_input_data) | ||
| # Perform inference using a single ONNX model | ||
| output_single = ModelInference.infer_single_onnx_model(args.single, | ||
| default_input_data) | ||
| initial_input_data) | ||
| print("Single model inference completed!") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plz split this to another PR. this change is not related to your commit message.