diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index 58a5cc65694..7442fa520af 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -45,6 +45,53 @@ def infer_single_onnx_model(model_file, input_data): return output_dict +def prepare_initial_input_data(onnx_model_path, default_input_data): + """ + Prepares initial input data for inference. + Args: + onnx_model_path (str): Path to the ONNX model file. + default_input_data (dict): Dictionary containing default input data. + Returns: + dict: Dictionary with user-specified or default shaped and typed input data. + """ + session = ort.InferenceSession(onnx_model_path) + input_info = {input.name: input.shape for input in session.get_inputs()} + + initial_input_data = {} + dtype_map = {'f': np.float32, 'i': np.int64} + + for input_name, shape in input_info.items(): + custom_shape_str = input( + f"Enter new shape for input '{input_name}' (comma-separated integers),\ + or press Enter to use default: ") + custom_dtype_str = input( + f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64),\ + or press Enter to use default: ") + + if not custom_shape_str: + new_shape = default_input_data[input_name].shape + else: + try: + new_shape = [int(dim) for dim in custom_shape_str.split(',')] + except ValueError: + print("Invalid input, please ensure you enter comma-separated integers.") + continue + + if not custom_dtype_str: + dtype = default_input_data[input_name].dtype + else: + dtype = dtype_map.get(custom_dtype_str.strip(), None) + if dtype is None: + print("Invalid data type, set as default, please enter 'f' or 'i'.") + dtype = default_input_data[input_name].dtype + continue + + input_data = np.random.rand(*new_shape).astype(dtype) + initial_input_data[input_name] = input_data + + return initial_input_data + + if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument('-s', @@ -68,8 +115,8 @@ def infer_single_onnx_model(model_file, input_data): default_input_data = { "x": np.random.rand(1, 3, 256, 256).astype(np.float32), } - + initial_input_data = prepare_initial_input_data(args.single, default_input_data) # Perform inference using a single ONNX model output_single = ModelInference.infer_single_onnx_model(args.single, - default_input_data) + initial_input_data) print("Single model inference completed!")