Added Mish ONNX subgraph
This commit is contained in:
parent
235e648bf5
commit
af9597f454
@ -314,6 +314,19 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MishSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
MishSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int softplus = addNodeToMatch("Softplus", input);
|
||||
int tanh = addNodeToMatch("Tanh", softplus);
|
||||
addNodeToMatch("Mul", input, tanh);
|
||||
setFusedNode("Mish", input);
|
||||
}
|
||||
};
|
||||
|
||||
class MulCastSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
@ -512,6 +525,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
||||
subgraphs.push_back(makePtr<MishSubgraph>());
|
||||
|
||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||
}
|
||||
|
||||
@ -660,6 +660,11 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6)
|
||||
testONNXModels("resize_opset11_torch1.6");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Mish)
|
||||
{
|
||||
testONNXModels("mish");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Conv1d)
|
||||
{
|
||||
testONNXModels("conv1d");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user