From 7cee8b4a7c20002dfeb1ba00b7a0eecbe7ba60d3 Mon Sep 17 00:00:00 2001 From: Youxin Chen Date: Mon, 7 Apr 2025 18:33:23 +0800 Subject: [PATCH] [tools/onnx-subgraph] add onnx process functions implementation add more onnx process functions implementation ONE-DCO-1.0-Signed-off-by: Youxin Chen --- tools/onnx_subgraph/src/lib/graph.cpp | 114 ++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tools/onnx_subgraph/src/lib/graph.cpp b/tools/onnx_subgraph/src/lib/graph.cpp index bef970990fb..fc943876cb2 100644 --- a/tools/onnx_subgraph/src/lib/graph.cpp +++ b/tools/onnx_subgraph/src/lib/graph.cpp @@ -141,6 +141,120 @@ void determineGraphInput(const onnx::GraphProto &g, } } +void determineGraphOutput(const onnx::GraphProto &originalGraph, const onnx::GraphProto &g, + std::vector> &allgraphInputs_1, + std::vector> &allgraphInputs_2, + std::unordered_set &graphOutputs) +{ + auto allgraphInputs = allgraphInputs_1; + allgraphInputs.insert(allgraphInputs.end(), allgraphInputs_2.begin(), allgraphInputs_2.end()); + + for (const auto &node : g.node()) + { + const auto &outputs = node.output(); + + for (const auto &output : outputs) + { + int flag = 0; + + for (auto value_info : originalGraph.output()) + { + if (value_info.name() == output) + { + NodeTensor nt; + nt.name = value_info.name(); + std::cout << nt.name << std::endl; + std::vector shape; + + for (const auto &dim : value_info.type().tensor_type().shape().dim()) + { + shape.push_back(dim.dim_value()); + } + + nt.shape = shape; + graphOutputs.insert(nt); + flag = 1; + + break; + } + } + + if (flag) + { + continue; + } + + for (size_t i = 0; i < allgraphInputs.size(); i++) + { + for (auto &input : allgraphInputs[i]) + { + if (input.name == output) + { + graphOutputs.insert(input); + flag = 1; + + break; + } + } + + if (flag) + { + break; + } + } + } + } +} + +std::string findInputNode(const onnx::GraphProto &g, const std::string &outputTensorName) +{ + std::string node_name = ""; + + for (const auto &node : g.node()) + { + for (const auto &output : node.output()) + { + if (output == outputTensorName) + { + node_name = node.name(); + } + } + } + + return node_name; +} + +std::unordered_set collectNodeNames(const onnx::GraphProto &graph) +{ + std::unordered_set nodeNames; + + for (const auto &node : graph.node()) + { + nodeNames.insert(node.name()); + } + + return nodeNames; +} + +void mergeGraphs(onnx::GraphProto &targetGraph, onnx::GraphProto &sourceGraph) +{ + std::cout << "size before merged: " << targetGraph.node_size() << "+" << sourceGraph.node_size() + << std::endl; + int size_before = targetGraph.node_size() + sourceGraph.node_size(); + + for (const auto &node : sourceGraph.node()) + { + *targetGraph.add_node() = node; + } + + std::cout << "size after merged: " << targetGraph.node_size() << std::endl; + if (size_before != targetGraph.node_size()) + { + std::cout << "error in mergeGraphs" << std::endl; + std::exit(-1); + } +} + onnx::GraphProto GetGraphFromOnnx(std::string &path) { onnx::ModelProto model;