Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions tools/onnx_subgraph/include/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,115 @@ template <> struct hash<NodeTensor>

} // namespace std

/**
* @brief Extracts the names and shapes of initializers from the ONNX graph.
*
* @param [in] graph The ONNX graph from which to extract initializers.
* @pre The ONNX graph should be valid and contain initializers.
* @post The names and shapes of the initializers are stored in an unordered set of NodeTensor
* objects.
* @exception None
* @return An unordered set of NodeTensor objects containing the names and shapes of the
* initializers.
*/
std::unordered_set<NodeTensor> getInitializer(const onnx::GraphProto &graph);

/**
* @brief Extracts the names and shapes of inputs, outputs, and value_info from the ONNX graph.
*
* @param [in] graph The ONNX graph from which to extract inputs, outputs, and value_info.
* @pre The ONNX graph should be valid and contain inputs, outputs, and value_info.
* @post The names and shapes of the inputs, outputs, and value_info are stored in an unordered
* set of NodeTensor objects.
* @exception None
* @return An unordered set of NodeTensor objects containing the names and shapes of the inputs,
* outputs, and value_info.
*/
std::unordered_set<NodeTensor> getIOvalue(const onnx::GraphProto &graph);

/**
* @brief Determines the input tensors of the graph that are not produced by any node in the
* graph.
*
* @param [in] g The ONNX GraphProto object representing the graph.
* @param [in] initializerNames A set of NodeTensor objects representing the initializers in the
* graph.
* @param [out] graphInputs A set of NodeTensor objects representing the input tensors of the
* graph.
* @pre The GraphProto object g should be valid and contain nodes with proper input and output
* lists.
* @post The graphInputs set will be populated with NodeTensor objects that are inputs to the
* graph.
* @exception None
* @return None
*/
void determineGraphInput(const onnx::GraphProto &g,
const std::unordered_set<NodeTensor> &initializerNames,
std::unordered_set<NodeTensor> &graphInputs);

/**
* @brief Determines the output tensors of the graph that are either outputs of the original
* graph or are used as inputs in other parts of the graph.
*
* @param [in] originalGraph The original ONNX GraphProto object representing the graph.
* @param [in] g The ONNX GraphProto object representing the graph to analyze.
* @param [in] allgraphInputs_1 A vector of sets of NodeTensor objects representing the first
* set of inputs to the graph.
* @param [in] allgraphInputs_2 A vector of sets of NodeTensor objects representing the second
* set of inputs to the graph.
* @param [out] graphOutputs A set of NodeTensor objects representing the output tensors of the
* graph.
* @pre The GraphProto objects originalGraph and g should be valid and contain nodes with
* proper input and output lists.
* @post The graphOutputs set will be populated with NodeTensor objects that are outputs of the
* graph.
* @exception None
* @return None
*/
void determineGraphOutput(const onnx::GraphProto &originalGraph, const onnx::GraphProto &g,
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_1,
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_2,
std::unordered_set<NodeTensor> &graphOutputs);

/**
* @brief Finds the name of the node that produces a specified output tensor in the given ONNX
* graph.
*
* @param [in] g The ONNX GraphProto object representing the graph.
* @param [in] outputTensorName The name of the output tensor to find the producing node for.
* @pre The GraphProto object g should be valid and contain nodes with proper input and output
* lists.
* @post None
* @exception None
* @return The name of the node that produces the specified output tensor, or an empty string if
* no such node is found.
*/
std::string findInputNode(const onnx::GraphProto &g, const std::string &outputTensorName);

/**
* @brief Collects the names of all nodes in the given ONNX graph.
*
* @param [in] graph The ONNX GraphProto object representing the graph.
* @pre The GraphProto object graph should be valid and contain nodes with proper names.
* @post None
* @exception None
* @return An unordered set containing the names of all nodes in the graph.
*/
std::unordered_set<std::string> collectNodeNames(const onnx::GraphProto &graph);

/**
* @brief Merges nodes from the source graph into the target graph.
*
* @param [in,out] targetGraph The ONNX GraphProto object to which nodes will be added.
* @param [in] sourceGraph The ONNX GraphProto object from which nodes will be copied.
* @pre Both GraphProto objects should be valid.
* @post Nodes from sourceGraph are added to targetGraph.
* @exception Exits the program with an error message if the number of nodes in targetGraph does not
* match the expected size after merging.
* @return None
*/
void mergeGraphs(onnx::GraphProto &targetGraph, onnx::GraphProto &sourceGraph);

/**
* @brief Loads an ONNX model from a file and returns the graph contained within.
*
Expand Down