diff --git a/modules/dnn/src/layers/recurrent_layers.cpp b/modules/dnn/src/layers/recurrent_layers.cpp
index a6715aefca..9088c13390 100644
--- a/modules/dnn/src/layers/recurrent_layers.cpp
+++ b/modules/dnn/src/layers/recurrent_layers.cpp
@@ -80,12 +80,31 @@ static void sigmoid(const Mat &src, Mat &dst)
     cv::pow(1 + dst, -1, dst);
 }
 
+typedef void (*ActivationFunction)(const Mat &src, Mat &dst);
+static ActivationFunction get_activation_function(const String& activation) {
+    // most used activations for PyTorch and TF : Tanh, Sigmoid
+    // if you need to support more optional activations use std::map instead
+    if (activation == "Tanh")
+    {
+        return tanh;
+    }
+    else if (activation == "Sigmoid")
+    {
+        return sigmoid;
+    }
+    else
+    {
+        CV_Error(Error::StsNotImplemented,
+                 cv::format("Activation function [%s] for layer LSTM is not supported", activation.c_str()));
+    }
+}
+
 class LSTMLayerImpl CV_FINAL : public LSTMLayer
 {
     int numTimeStamps, numSamples;
     bool allocated;
 
-    MatShape outTailShape;  //shape of single output sample
+    MatShape outTailShape;                 //shape of single output sample
     MatShape outTsShape;    //shape of N output samples
 
     bool useTimestampDim;
@@ -95,6 +114,10 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
     bool reverse;   // If true, go in negative direction along the time axis
     bool bidirectional;  // If true, produces both forward and reversed directions along time axis
 
+    ActivationFunction f_activation;
+    ActivationFunction g_activation;
+    ActivationFunction h_activation;
+
 public:
 
     LSTMLayerImpl(const LayerParams& params)
@@ -145,6 +168,20 @@ public:
         reverse = params.get<bool>("reverse", false);
         CV_Assert(!reverse || !bidirectional);
 
+        // read activations
+        DictValue activations = params.get<DictValue>("activations", "");
+        if (activations.size() == 1) // if activations wasn't specified use default
+        {
+            f_activation = sigmoid;
+            g_activation = tanh;
+            h_activation = tanh;
+        } else {
+            CV_Assert(activations.size() == 3);
+            f_activation = get_activation_function(activations.getStringValue(0));
+            g_activation = get_activation_function(activations.getStringValue(1));
+            h_activation = get_activation_function(activations.getStringValue(2));
+        }
+
         allocated = false;
         outTailShape.clear();
     }
@@ -339,15 +376,15 @@ public:
                 Mat gatesIF = gates.colRange(0, 2*numOut);
                 gemm(cInternal, blobs[5], 1, gateI, 1, gateI);
                 gemm(cInternal, blobs[6], 1, gateF, 1, gateF);
-                sigmoid(gatesIF, gatesIF);
+                f_activation(gatesIF, gatesIF);
             }
             else
             {
                 Mat gatesIFO = gates.colRange(0, 3*numOut);
-                sigmoid(gatesIFO, gatesIFO);
+                f_activation(gatesIFO, gatesIFO);
             }
 
-            tanh(gateG, gateG);
+            g_activation(gateG, gateG);
 
             //compute c_t
             multiply(gateF, cInternal, gateF);      // f_t (*) c_{t-1}
@@ -362,11 +399,11 @@ public:
             if (usePeephole)
             {
                 gemm(cInternal, blobs[7], 1, gateO, 1, gateO);
-                sigmoid(gateO, gateO);
+                f_activation(gateO, gateO);
             }
 
             //compute h_t
-            tanh(cInternal, hInternal);
+            h_activation(cInternal, hInternal);
             multiply(gateO, hInternal, hInternal);
 
             //save results in output blobs
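The hunks above swap the hard-coded sigmoid/tanh calls for three function pointers resolved once in the constructor, following the f/g/h activation slots of the ONNX LSTM operator: f drives the input/forget/output gates, g the cell update, h the cell output. The following is a minimal standalone sketch of that dispatch pattern only; scalar float functions stand in for the cv::Mat kernels, and every name in it is illustrative rather than part of the patch.

// Sketch of the function-pointer activation dispatch; not OpenCV code.
#include <cmath>
#include <cstdio>
#include <stdexcept>
#include <string>

typedef float (*ActivationFunction)(float);

static float tanh_act(float x)    { return std::tanh(x); }
static float sigmoid_act(float x) { return 1.f / (1.f + std::exp(-x)); }

// Resolve an activation name to a function pointer once, up front.
static ActivationFunction get_activation_function(const std::string& name)
{
    if (name == "Tanh")    return tanh_act;
    if (name == "Sigmoid") return sigmoid_act;
    throw std::runtime_error("Activation [" + name + "] is not supported");
}

int main()
{
    // f/g/h mirror the ONNX LSTM activation triplet: f for the
    // input/forget/output gates, g for the cell update, h for the cell output.
    ActivationFunction f = get_activation_function("Sigmoid");
    ActivationFunction g = get_activation_function("Tanh");
    ActivationFunction h = get_activation_function("Tanh");
    std::printf("f(0)=%.3f  g(1)=%.3f  h(-1)=%.3f\n", f(0.f), g(1.f), h(-1.f));
    return 0;
}

Resolving the pointers at construction keeps string comparisons out of the per-timestep loop, which is why the patch does the lookup in LSTMLayerImpl's constructor rather than in forward().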
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index b833b2ea44..32b56278bd 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -244,6 +244,10 @@ static DictValue parse(const ::google::protobuf::RepeatedField< ::google::protob
     return DictValue::arrayInt(&dst[0], src.size());
 }
 
+static DictValue parseStr(const ::google::protobuf::RepeatedPtrField< ::std::string>& src) {
+    return DictValue::arrayString(src.begin(), static_cast<int>(src.size()));
+}
+
 LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_proto)
 {
     LayerParams lp;
@@ -301,6 +305,10 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot
             CV_Assert(attribute_proto.ints_size() == 1 || attribute_proto.ints_size() == 2 || attribute_proto.ints_size() == 3);
             lp.set("dilation", parse(attribute_proto.ints()));
         }
+        else if(attribute_name == "activations" && node_proto.op_type() == "LSTM")
+        {
+            lp.set(attribute_name, parseStr(attribute_proto.strings()));
+        }
         else if (attribute_proto.has_i())
        {
             ::google::protobuf::int64 src = attribute_proto.i();
@@ -997,18 +1005,32 @@ void ONNXImporter::parseLSTM(LayerParams& layerParams, const opencv_onnx::NodePr
     lstmParams.name += "/lstm";
 
     // https://pytorch.org/docs/stable/nn.html#lstm
-    CV_Assert(node_proto.input_size() == 7);
+    CV_Assert(node_proto.input_size() >= 7);
     Mat Wx = getBlob(node_proto, 1);
     Mat Wh = getBlob(node_proto, 2);
     Mat b = getBlob(node_proto, 3);
-    Mat h0 = getBlob(node_proto, 5);
-    Mat c0 = getBlob(node_proto, 6);
-
-    b = b.reshape(1, b.size[0]);
 
     const int numHidden = lstmParams.get<int>("hidden_size");
     const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
     const int numFeatures = Wx.size[2];
+
+    Mat h0, c0;
+    if (!node_proto.input(5).empty()) {
+        h0 = getBlob(node_proto, 5);
+        h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
+    } else {
+        // initial_h attribute can be empty in case of keras2onnx producer. fill it with zeros
+        h0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+    if (!node_proto.input(6).empty()) {
+        c0 = getBlob(node_proto, 6);
+        c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+    } else {
+        // initial_c attribute can be empty in case of keras2onnx producer. fill it with zeros
+        c0 = Mat::zeros(numDirs * numFeatures, numHidden, CV_32FC1);
+    }
+
+    b = b.reshape(1, b.size[0]);
     Mat bx = b.colRange(0, b.cols / 2);
     Mat bh = b.colRange(b.cols / 2, b.cols);
     b = bx + bh;
@@ -1036,8 +1058,7 @@
     }
     Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
     Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
-    h0 = h0.reshape(1, h0.size[0] * h0.size[1]);
-    c0 = c0.reshape(1, c0.size[0] * c0.size[1]);
+
 
     lstmParams.blobs.resize(5);
     lstmParams.blobs[0] = Wh;
@@ -1045,6 +1066,9 @@
     lstmParams.blobs[2] = b;
     lstmParams.blobs[3] = h0;
     lstmParams.blobs[4] = c0;
+
+    // read direction attribute
+    lstmParams.set("reverse", lstmParams.get<String>("direction", "") == "reverse");
     lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
 
     node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 05f77730af..a446a37c79 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -665,6 +665,11 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
     testONNXModels("split_max");
 }
 
+TEST_P(Test_ONNX_layers, LSTM_Activations)
+{
+    testONNXModels("lstm_cntk_tanh", pb, 0, 0, false, false);
+}
+
 TEST_P(Test_ONNX_layers, LSTM)
 {
     testONNXModels("lstm", npy, 0, 0, false, false);
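For completeness, the new importer path can also be exercised end to end through the public dnn API once a suitable model exists. The sketch below is hedged: the file name lstm_custom_activations.onnx and the input shape are hypothetical placeholders (the actual test above uses the lstm_cntk_tanh model via testONNXModels), while readNetFromONNX, setInput, and forward are standard cv::dnn calls.

// Sketch only: model path and shape are placeholders, not patch artifacts.
#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    cv::dnn::Net net = cv::dnn::readNetFromONNX("lstm_custom_activations.onnx");

    // ONNX LSTM expects input X shaped [seq_length, batch_size, input_size].
    int dims[] = {5, 1, 16};  // placeholder shape
    cv::Mat input(3, dims, CV_32F, cv::Scalar::all(0.1));

    net.setInput(input);
    cv::Mat out = net.forward();
    std::cout << "output dims: " << out.dims << std::endl;
    return 0;
}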