Remove Switch and Merge nodes from TensorFlow networks
parent df1f62b34c
commit ec41a4897a
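Switch and Merge nodes typically come from training-phase conditionals (for example, the is_training branch that Slim-style models wrap around batch normalization) and carry no meaning for inference. The new removePhaseSwitches() pass rewires every consumer of a Switch or Merge node to that node's first input and then prunes subgraphs of otherwise-unused nodes that terminate at a removed Merge. The self-contained sketch below (not part of this commit) models only the rewiring step on a toy in-memory graph; the Node struct and all node names are invented for illustration.

#include <iostream>
#include <string>
#include <vector>

struct Node
{
    std::string name, op;
    std::vector<std::string> inputs;
};

int main()
{
    // Toy graph: a training-phase branch (is_training -> Switch -> Merge)
    // guarding a statistics node that feeds batch normalization.
    std::vector<Node> graph = {
        {"is_training", "Placeholder",    {}},
        {"conv",        "Conv2D",         {}},
        {"bn/moments",  "Mean",           {"conv"}},
        {"cond/Switch", "Switch",         {"bn/moments", "is_training"}},
        {"cond/Merge",  "Merge",          {"cond/Switch:1", "cond/Switch"}},
        {"bn",          "FusedBatchNorm", {"conv", "cond/Merge"}}
    };

    // Bypass every Switch/Merge node: each consumer that referenced it (with an
    // optional ":N" output port suffix) now reads the node's first input directly.
    for (const Node& n : graph)
    {
        if (n.op != "Switch" && n.op != "Merge")
            continue;
        for (Node& consumer : graph)
            for (std::string& inp : consumer.inputs)
                if (inp.substr(0, inp.rfind(':')) == n.name)
                    inp = n.inputs[0];
    }

    // After the rewiring, "bn" reads "bn/moments" directly; the committed pass
    // additionally deletes the Switch/Merge nodes themselves and prunes subgraphs
    // of unused nodes terminated by Merge nodes.
    for (const Node& n : graph)
    {
        std::cout << n.name << " <-";
        for (const std::string& inp : n.inputs)
            std::cout << ' ' << inp;
        std::cout << '\n';
    }
    return 0;
}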
@@ -10,6 +10,7 @@
#ifdef HAVE_PROTOBUF

#include "tf_graph_simplifier.hpp"
#include <queue>

namespace cv { namespace dnn {
CV__DNN_EXPERIMENTAL_NS_BEGIN
@@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
        nodesToAdd.pop_back();

        permIds.push_back(nodeToAdd);
        // std::cout << net.node(nodeToAdd).name() << '\n';

        for (int i = 0; i < edges[nodeToAdd].size(); ++i)
        {
@@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
    permute(net.mutable_node(), permIds);
}

// Remove training switches (Switch and Merge nodes and corresponding subgraphs).
void removePhaseSwitches(tensorflow::GraphDef& net)
{
    std::vector<int> nodesToRemove;
    std::map<std::string, int> nodesMap;
    std::map<std::string, int>::iterator nodesMapIt;
    std::queue<int> mergeOpSubgraphNodes;
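    // Pass 1: index every node by name; for each Switch/Merge node, rewire its
    // consumers to the node's first input, mark it for removal, and remember
    // Merge nodes as starting points for subgraph pruning.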
    for (int i = 0; i < net.node_size(); ++i)
    {
        const tensorflow::NodeDef& node = net.node(i);
        nodesMap.insert(std::make_pair(node.name(), i));
        if (node.op() == "Switch" || node.op() == "Merge")
        {
            CV_Assert(node.input_size() > 0);
            // Replace consumers' inputs.
            for (int j = 0; j < net.node_size(); ++j)
            {
                tensorflow::NodeDef* consumer = net.mutable_node(j);
                for (int k = 0; k < consumer->input_size(); ++k)
                {
                    std::string inpName = consumer->input(k);
                    inpName = inpName.substr(0, inpName.rfind(':'));
                    if (inpName == node.name())
                    {
                        consumer->set_input(k, node.input(0));
                    }
                }
            }
            nodesToRemove.push_back(i);
            if (node.op() == "Merge")
                mergeOpSubgraphNodes.push(i);
        }
    }

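    // Pass 2: count the consumers of every node, stripping the '^' prefix of
    // control inputs and the ':N' output port suffix from input names.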
    std::vector<int> numConsumers(net.node_size(), 0);
    for (int i = 0; i < net.node_size(); ++i)
    {
        const tensorflow::NodeDef& node = net.node(i);
        for (int j = 0; j < node.input_size(); ++j)
        {
            std::string inpName = node.input(j);
            inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
            nodesMapIt = nodesMap.find(inpName);
            CV_Assert(nodesMapIt != nodesMap.end());
            numConsumers[nodesMapIt->second] += 1;
        }
    }

    // Remove subgraphs of unused nodes which are terminated by Merge nodes.
    while (!mergeOpSubgraphNodes.empty())
    {
        const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
        mergeOpSubgraphNodes.pop();
        for (int i = 0; i < node.input_size(); ++i)
        {
            std::string inpName = node.input(i);
            inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
            nodesMapIt = nodesMap.find(inpName);
            CV_Assert(nodesMapIt != nodesMap.end());

            int inpNodeId = nodesMapIt->second;
            if (numConsumers[inpNodeId] == 1)
            {
                mergeOpSubgraphNodes.push(inpNodeId);
                nodesToRemove.push_back(inpNodeId);
            }
            else if (numConsumers[inpNodeId] > 0)
                numConsumers[inpNodeId] -= 1;
        }
    }
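    // Finally, delete the marked nodes starting from the highest index so that
    // the remaining indices stay valid.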
    std::sort(nodesToRemove.begin(), nodesToRemove.end());
    for (int i = nodesToRemove.size() - 1; i >= 0; --i)
    {
        if (nodesToRemove[i] < net.node_size()) // Ids might be repeated.
            net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
    }
}

CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv

@@ -27,6 +27,8 @@ void releaseTensor(tensorflow::TensorProto* tensor);

void sortByExecutionOrder(tensorflow::GraphDef& net);

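// Removes training phase switches: Switch and Merge nodes and their corresponding subgraphs.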
void removePhaseSwitches(tensorflow::GraphDef& net);

CV__DNN_EXPERIMENTAL_NS_END
}} // namespace dnn, namespace cv

@@ -657,6 +657,9 @@ static int predictOutputDataLayout(const tensorflow::GraphDef& net,

void TFImporter::populateNet(Net dstNet)
{
    if (!netTxt.ByteSize())
        removePhaseSwitches(netBin);

    RemoveIdentityOps(netBin);
    RemoveIdentityOps(netTxt);

@@ -185,6 +185,16 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
    runTensorFlowNet("mvn_batch_norm_1x1");
}

TEST_P(Test_TensorFlow_layers, slim_batch_norm)
{
    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
        throw SkipTestException("Test is disabled for DLIE");
    // Output values range: [-40.0597, 207.827]
    double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.041 : default_l1;
    double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.33 : default_lInf;
    runTensorFlowNet("slim_batch_norm", false, l1, lInf);
}

TEST_P(Test_TensorFlow_layers, pooling)
{
    runTensorFlowNet("max_pool_even");