Added Mish ONNX subgraph

This commit is contained in:
Liubov Batanina 2021-01-15 14:01:48 +03:00
parent 235e648bf5
commit af9597f454
2 changed files with 19 additions and 0 deletions

View File

@ -314,6 +314,19 @@ public:
}
};
class MishSubgraph : public Subgraph
{
public:
MishSubgraph()
{
int input = addNodeToMatch("");
int softplus = addNodeToMatch("Softplus", input);
int tanh = addNodeToMatch("Tanh", softplus);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
}
};
class MulCastSubgraph : public Subgraph
{
public:
@ -512,6 +525,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}

View File

@ -660,6 +660,11 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6)
testONNXModels("resize_opset11_torch1.6");
}
TEST_P(Test_ONNX_layers, Mish)
{
testONNXModels("mish");
}
TEST_P(Test_ONNX_layers, Conv1d)
{
testONNXModels("conv1d");