From e8aadae76f5fbe8cf8d8bcabfac259460dd00d7f Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Thu, 27 Feb 2025 21:15:26 +0800 Subject: [PATCH 1/3] [tools/onnx-subgraph] support arbitrary input shape for onnx 1. fixed input shape as default value 2. user can set the input shape dynamically with guide ONE-DCO-1.0-Signed-off-by: Youxin Chen --- .../onnx_subgraph/single_vs_multiple_onnx.py | 52 +++++++++++++++++-- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index 58a5cc65694..df4132a6ba8 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -30,7 +30,7 @@ class ModelInference: Description: Subgraphsiostxt_path is a txt file that describes the structure of the model graph and is used to get input/output node names.The model_path contains paths to multiple onnx files. - The load_sessions function will sort the onnx models in the model_path according to the + The load_sessions function will sort the onnx models in the model_path according to the order specified in subgraphsiostxt_path. """ def __init__(self, model_path, subgraphsiostxt_path): @@ -45,6 +45,52 @@ def infer_single_onnx_model(model_file, input_data): return output_dict +def prepare_initial_input_data(onnx_model_path, default_input_data): + """ + Prepares initial input data for inference. + Args: + onnx_model_path (str): Path to the ONNX model file. + default_input_data (dict): Dictionary containing default input data. + Returns: + dict: Dictionary with user-specified or default shaped and typed input data. + """ + session = ort.InferenceSession(onnx_model_path) + input_info = {input.name: input.shape for input in session.get_inputs()} + + initial_input_data = {} + dtype_map = {'f': np.float32, 'i': np.int64} + + for input_name, shape in input_info.items(): + custom_shape_str = input( + f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: " + ) + custom_dtype_str = input( + f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: " + ) + + if not custom_shape_str: + new_shape = default_input_data[input_name].shape + else: + try: + new_shape = [int(dim) for dim in custom_shape_str.split(',')] + except ValueError: + print("Invalid input, please ensure you enter comma-separated integers.") + continue + + if not custom_dtype_str: + dtype = default_input_data[input_name].dtype + else: + dtype = dtype_map.get(custom_dtype_str.strip(), None) + if dtype is None: + print("Invalid data type, please enter 'f' or 'i'.") + continue + + input_data = np.random.rand(*new_shape).astype(dtype) + initial_input_data[input_name] = input_data + + return initial_input_data + + if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument('-s', @@ -68,8 +114,8 @@ def infer_single_onnx_model(model_file, input_data): default_input_data = { "x": np.random.rand(1, 3, 256, 256).astype(np.float32), } - + initial_input_data = prepare_initial_input_data(args.single, default_input_data) # Perform inference using a single ONNX model output_single = ModelInference.infer_single_onnx_model(args.single, - default_input_data) + initial_input_data) print("Single model inference completed!") From a872bf0e5cf07d36d07b8ac1a299c166e484883b Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Tue, 4 Mar 2025 10:57:21 +0800 Subject: [PATCH 2/3] Update single_vs_multiple_onnx.py update code as review, split too long lines, and add dtype default setting for exception case --- tools/onnx_subgraph/single_vs_multiple_onnx.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index df4132a6ba8..3927f1ceef4 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -30,7 +30,7 @@ class ModelInference: Description: Subgraphsiostxt_path is a txt file that describes the structure of the model graph and is used to get input/output node names.The model_path contains paths to multiple onnx files. - The load_sessions function will sort the onnx models in the model_path according to the + The load_sessions function will sort the onnx models in the model_path according to the order specified in subgraphsiostxt_path. """ def __init__(self, model_path, subgraphsiostxt_path): @@ -62,10 +62,12 @@ def prepare_initial_input_data(onnx_model_path, default_input_data): for input_name, shape in input_info.items(): custom_shape_str = input( - f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: " + f"Enter new shape for input '{input_name}' (comma-separated integers), + or press Enter to use default: " ) custom_dtype_str = input( - f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: " + f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), + or press Enter to use default: " ) if not custom_shape_str: @@ -82,7 +84,8 @@ def prepare_initial_input_data(onnx_model_path, default_input_data): else: dtype = dtype_map.get(custom_dtype_str.strip(), None) if dtype is None: - print("Invalid data type, please enter 'f' or 'i'.") + print("Invalid data type, set as default, please enter 'f' or 'i'.") + dtype = default_input_data[input_name].dtype continue input_data = np.random.rand(*new_shape).astype(dtype) From 401d7fdde43e6d78cc30e23ac9a6513d410b0cfc Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Tue, 4 Mar 2025 11:25:19 +0800 Subject: [PATCH 3/3] Update single_vs_multiple_onnx.py fix format checking issue --- tools/onnx_subgraph/single_vs_multiple_onnx.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index 3927f1ceef4..7442fa520af 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -62,13 +62,11 @@ def prepare_initial_input_data(onnx_model_path, default_input_data): for input_name, shape in input_info.items(): custom_shape_str = input( - f"Enter new shape for input '{input_name}' (comma-separated integers), - or press Enter to use default: " - ) + f"Enter new shape for input '{input_name}' (comma-separated integers),\ + or press Enter to use default: ") custom_dtype_str = input( - f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), - or press Enter to use default: " - ) + f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64),\ + or press Enter to use default: ") if not custom_shape_str: new_shape = default_input_data[input_name].shape