Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions tools/onnx_subgraph/src/lib/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,120 @@ void determineGraphInput(const onnx::GraphProto &g,
}
}

void determineGraphOutput(const onnx::GraphProto &originalGraph, const onnx::GraphProto &g,
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_1,
std::vector<std::unordered_set<NodeTensor>> &allgraphInputs_2,
std::unordered_set<NodeTensor> &graphOutputs)
{
auto allgraphInputs = allgraphInputs_1;
allgraphInputs.insert(allgraphInputs.end(), allgraphInputs_2.begin(), allgraphInputs_2.end());

for (const auto &node : g.node())
{
const auto &outputs = node.output();

for (const auto &output : outputs)
{
int flag = 0;

for (auto value_info : originalGraph.output())
{
if (value_info.name() == output)
{
NodeTensor nt;
nt.name = value_info.name();
std::cout << nt.name << std::endl;
std::vector<int64_t> shape;

for (const auto &dim : value_info.type().tensor_type().shape().dim())
{
shape.push_back(dim.dim_value());
}

nt.shape = shape;
graphOutputs.insert(nt);
flag = 1;

break;
}
}

if (flag)
{
continue;
}

for (size_t i = 0; i < allgraphInputs.size(); i++)
{
for (auto &input : allgraphInputs[i])
{
if (input.name == output)
{
graphOutputs.insert(input);
flag = 1;

break;
}
}

if (flag)
{
break;
}
}
}
}
}

std::string findInputNode(const onnx::GraphProto &g, const std::string &outputTensorName)
{
std::string node_name = "";

for (const auto &node : g.node())
{
for (const auto &output : node.output())
{
if (output == outputTensorName)
{
node_name = node.name();
}
}
}

return node_name;
}

std::unordered_set<std::string> collectNodeNames(const onnx::GraphProto &graph)
{
std::unordered_set<std::string> nodeNames;

for (const auto &node : graph.node())
{
nodeNames.insert(node.name());
}

return nodeNames;
}

void mergeGraphs(onnx::GraphProto &targetGraph, onnx::GraphProto &sourceGraph)
{
std::cout << "size before merged: " << targetGraph.node_size() << "+" << sourceGraph.node_size()
<< std::endl;
int size_before = targetGraph.node_size() + sourceGraph.node_size();

for (const auto &node : sourceGraph.node())
{
*targetGraph.add_node() = node;
}

std::cout << "size after merged: " << targetGraph.node_size() << std::endl;
if (size_before != targetGraph.node_size())
{
std::cout << "error in mergeGraphs" << std::endl;
std::exit(-1);
}
}

onnx::GraphProto GetGraphFromOnnx(std::string &path)
{
onnx::ModelProto model;
Expand Down