Skip to content

Commit 0953762

Browse files
authored
[tools/onnx-subgraph] add onnx process APIs definition (#15032)
add onnx process function definition, implementation will be added in the cpp file in next PR ONE-DCO-1.0-Signed-off-by: Youxin Chen <yx113.chen@samsung.com>
1 parent 444b28b commit 0953762

File tree

1 file changed

+109
-0
lines changed

1 file changed

+109
-0
lines changed

tools/onnx_subgraph/include/graph.h

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,115 @@ template <> struct hash<NodeTensor>
7676

7777
} // namespace std
7878

79+
/**
80+
* @brief Extracts the names and shapes of initializers from the ONNX graph.
81+
*
82+
* @param [in] graph The ONNX graph from which to extract initializers.
83+
* @pre The ONNX graph should be valid and contain initializers.
84+
* @post The names and shapes of the initializers are stored in an unordered set of NodeTensor
85+
* objects.
86+
* @exception None
87+
* @return An unordered set of NodeTensor objects containing the names and shapes of the
88+
* initializers.
89+
*/
90+
std::unordered_set<NodeTensor> getInitializer(const onnx::GraphProto &graph);
91+
92+
/**
93+
* @brief Extracts the names and shapes of inputs, outputs, and value_info from the ONNX graph.
94+
*
95+
* @param [in] graph The ONNX graph from which to extract inputs, outputs, and value_info.
96+
* @pre The ONNX graph should be valid and contain inputs, outputs, and value_info.
97+
* @post The names and shapes of the inputs, outputs, and value_info are stored in an unordered
98+
* set of NodeTensor objects.
99+
* @exception None
100+
* @return An unordered set of NodeTensor objects containing the names and shapes of the inputs,
101+
* outputs, and value_info.
102+
*/
103+
std::unordered_set<NodeTensor> getIOvalue(const onnx::GraphProto &graph);
104+
105+
/**
106+
* @brief Determines the input tensors of the graph that are not produced by any node in the
107+
* graph.
108+
*
109+
* @param [in] g The ONNX GraphProto object representing the graph.
110+
* @param [in] initializerNames A set of NodeTensor objects representing the initializers in the
111+
* graph.
112+
* @param [out] graphInputs A set of NodeTensor objects representing the input tensors of the
113+
* graph.
114+
* @pre The GraphProto object g should be valid and contain nodes with proper input and output
115+
* lists.
116+
* @post The graphInputs set will be populated with NodeTensor objects that are inputs to the
117+
* graph.
118+
* @exception None
119+
* @return None
120+
*/
121+
void determineGraphInput(const onnx::GraphProto &g,
122+
const std::unordered_set<NodeTensor> &initializerNames,
123+
std::unordered_set<NodeTensor> &graphInputs);
124+
125+
/**
126+
* @brief Determines the output tensors of the graph that are either outputs of the original
127+
* graph or are used as inputs in other parts of the graph.
128+
*
129+
* @param [in] originalGraph The original ONNX GraphProto object representing the graph.
130+
* @param [in] g The ONNX GraphProto object representing the graph to analyze.
131+
* @param [in] allgraphInputs_1 A vector of sets of NodeTensor objects representing the first
132+
* set of inputs to the graph.
133+
* @param [in] allgraphInputs_2 A vector of sets of NodeTensor objects representing the second
134+
* set of inputs to the graph.
135+
* @param [out] graphOutputs A set of NodeTensor objects representing the output tensors of the
136+
* graph.
137+
* @pre The GraphProto objects originalGraph and g should be valid and contain nodes with
138+
* proper input and output lists.
139+
* @post The graphOutputs set will be populated with NodeTensor objects that are outputs of the
140+
* graph.
141+
* @exception None
142+
* @return None
143+
*/
144+
void determineGraphOutput(const onnx::GraphProto &originalGraph, const onnx::GraphProto &g,
145+
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_1,
146+
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_2,
147+
std::unordered_set<NodeTensor> &graphOutputs);
148+
149+
/**
150+
* @brief Finds the name of the node that produces a specified output tensor in the given ONNX
151+
* graph.
152+
*
153+
* @param [in] g The ONNX GraphProto object representing the graph.
154+
* @param [in] outputTensorName The name of the output tensor to find the producing node for.
155+
* @pre The GraphProto object g should be valid and contain nodes with proper input and output
156+
* lists.
157+
* @post None
158+
* @exception None
159+
* @return The name of the node that produces the specified output tensor, or an empty string if
160+
* no such node is found.
161+
*/
162+
std::string findInputNode(const onnx::GraphProto &g, const std::string &outputTensorName);
163+
164+
/**
165+
* @brief Collects the names of all nodes in the given ONNX graph.
166+
*
167+
* @param [in] graph The ONNX GraphProto object representing the graph.
168+
* @pre The GraphProto object graph should be valid and contain nodes with proper names.
169+
* @post None
170+
* @exception None
171+
* @return An unordered set containing the names of all nodes in the graph.
172+
*/
173+
std::unordered_set<std::string> collectNodeNames(const onnx::GraphProto &graph);
174+
175+
/**
176+
* @brief Merges nodes from the source graph into the target graph.
177+
*
178+
* @param [in,out] targetGraph The ONNX GraphProto object to which nodes will be added.
179+
* @param [in] sourceGraph The ONNX GraphProto object from which nodes will be copied.
180+
* @pre Both GraphProto objects should be valid.
181+
* @post Nodes from sourceGraph are added to targetGraph.
182+
* @exception Exits the program with an error message if the number of nodes in targetGraph does not
183+
* match the expected size after merging.
184+
* @return None
185+
*/
186+
void mergeGraphs(onnx::GraphProto &targetGraph, onnx::GraphProto &sourceGraph);
187+
79188
/**
80189
* @brief Loads an ONNX model from a file and returns the graph contained within.
81190
*

0 commit comments

Comments
 (0)