diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index c6e54d6a92..091d2d4ae9 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -531,6 +531,35 @@ public: } }; +// softplus(x) = log(exp(x) + 1) +class SoftplusSubgraph: public Subgraph +{ +public: + SoftplusSubgraph() + { + int input = addNodeToMatch(""); + int exp = addNodeToMatch("Exp", input); + int addVal = addNodeToMatch(""); + int add = addNodeToMatch("Add", addVal, exp); + addNodeToMatch("Log", add); + setFusedNode("Softplus", input); + } +}; + +class SoftplusSubgraph2: public Subgraph +{ +public: + SoftplusSubgraph2() + { + int input = addNodeToMatch(""); + int exp = addNodeToMatch("Exp", input); + int addVal = addNodeToMatch(""); + int add = addNodeToMatch("Add", exp, addVal); + addNodeToMatch("Log", add); + setFusedNode("Softplus", input); + } +}; + class MulCastSubgraph : public Subgraph { public: @@ -734,6 +763,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 702a69da2a..9b231521c4 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1325,6 +1325,7 @@ TEST_P(Test_ONNX_layers, ResizeOpset11_Torch1_6) TEST_P(Test_ONNX_layers, Mish) { testONNXModels("mish"); + testONNXModels("mish_no_softplus"); } TEST_P(Test_ONNX_layers, CalculatePads)