Bidirectional LSTM
This commit is contained in:
@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
|
||||
float forgetBias, cellClip;
|
||||
bool useCellClip, usePeephole;
|
||||
bool reverse; // If true, go in negative direction along the time axis
|
||||
bool bidirectional; // If true, produces both forward and reversed directions along time axis
|
||||
|
||||
public:
|
||||
|
||||
@@ -101,6 +102,7 @@ public:
|
||||
{
|
||||
setParamsFrom(params);
|
||||
|
||||
bidirectional = params.get<bool>("bidirectional", false);
|
||||
if (!blobs.empty())
|
||||
{
|
||||
CV_Assert(blobs.size() >= 3);
|
||||
@@ -113,7 +115,7 @@ public:
|
||||
CV_CheckEQ(Wh.dims, 2, "");
|
||||
CV_CheckEQ(Wx.dims, 2, "");
|
||||
CV_CheckEQ(Wh.rows, Wx.rows, "");
|
||||
CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
|
||||
CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
|
||||
CV_CheckEQ(Wh.rows, (int)bias.total(), "");
|
||||
CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
|
||||
|
||||
@@ -136,6 +138,7 @@ public:
|
||||
useCellClip = params.get<bool>("use_cell_clip", false);
|
||||
usePeephole = params.get<bool>("use_peephole", false);
|
||||
reverse = params.get<bool>("reverse", false);
|
||||
CV_Assert(!reverse || !bidirectional);
|
||||
|
||||
allocated = false;
|
||||
outTailShape.clear();
|
||||
@@ -207,6 +210,7 @@ public:
|
||||
|
||||
outResShape.push_back(_numSamples);
|
||||
outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
|
||||
outResShape.back() *= (1 + static_cast<int>(bidirectional));
|
||||
|
||||
size_t noutputs = produceCellOutput ? 2 : 1;
|
||||
outputs.assign(noutputs, outResShape);
|
||||
@@ -253,6 +257,7 @@ public:
|
||||
outTsShape.clear();
|
||||
outTsShape.push_back(numSamples);
|
||||
outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
|
||||
outTsShape.back() *= (1 + static_cast<int>(bidirectional));
|
||||
|
||||
allocated = true;
|
||||
}
|
||||
@@ -273,91 +278,96 @@ public:
|
||||
outputs_arr.getMatVector(output);
|
||||
internals_arr.getMatVector(internals);
|
||||
|
||||
const Mat &Wh = blobs[0];
|
||||
const Mat &Wx = blobs[1];
|
||||
const Mat &bias = blobs[2];
|
||||
|
||||
int numOut = Wh.size[1];
|
||||
|
||||
Mat hInternal = internals[0], cInternal = internals[1],
|
||||
dummyOnes = internals[2], gates = internals[3];
|
||||
hInternal.setTo(0.);
|
||||
cInternal.setTo(0.);
|
||||
dummyOnes.setTo(1.);
|
||||
|
||||
int numSamplesTotal = numTimeStamps*numSamples;
|
||||
Mat xTs = input[0].reshape(1, numSamplesTotal);
|
||||
|
||||
Mat hOutTs = output[0].reshape(1, numSamplesTotal);
|
||||
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
|
||||
|
||||
int tsStart, tsEnd, tsInc;
|
||||
if (reverse) {
|
||||
tsStart = numTimeStamps - 1;
|
||||
tsEnd = -1;
|
||||
tsInc = -1;
|
||||
}
|
||||
else {
|
||||
tsStart = 0;
|
||||
tsEnd = numTimeStamps;
|
||||
tsInc = 1;
|
||||
}
|
||||
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
|
||||
const int numDirs = 1 + static_cast<int>(bidirectional);
|
||||
for (int i = 0; i < numDirs; ++i)
|
||||
{
|
||||
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
|
||||
Mat xCurr = xTs.rowRange(curRowRange);
|
||||
const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
|
||||
const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
|
||||
const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
|
||||
|
||||
gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t
|
||||
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
|
||||
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
|
||||
int numOut = Wh.size[1];
|
||||
|
||||
Mat gateI = gates.colRange(0*numOut, 1*numOut);
|
||||
Mat gateF = gates.colRange(1*numOut, 2*numOut);
|
||||
Mat gateO = gates.colRange(2*numOut, 3*numOut);
|
||||
Mat gateG = gates.colRange(3*numOut, 4*numOut);
|
||||
Mat hInternal = internals[0], cInternal = internals[1],
|
||||
dummyOnes = internals[2], gates = internals[3];
|
||||
hInternal.setTo(0.);
|
||||
cInternal.setTo(0.);
|
||||
dummyOnes.setTo(1.);
|
||||
|
||||
if (forgetBias)
|
||||
add(gateF, forgetBias, gateF);
|
||||
int numSamplesTotal = numTimeStamps*numSamples;
|
||||
Mat xTs = input[0].reshape(1, numSamplesTotal);
|
||||
|
||||
if (usePeephole)
|
||||
{
|
||||
Mat gatesIF = gates.colRange(0, 2*numOut);
|
||||
gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
|
||||
gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
|
||||
sigmoid(gatesIF, gatesIF);
|
||||
Mat hOutTs = output[0].reshape(1, numSamplesTotal);
|
||||
hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
|
||||
Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
|
||||
|
||||
int tsStart, tsEnd, tsInc;
|
||||
if (reverse || i == 1) {
|
||||
tsStart = numTimeStamps - 1;
|
||||
tsEnd = -1;
|
||||
tsInc = -1;
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat gatesIFO = gates.colRange(0, 3*numOut);
|
||||
sigmoid(gatesIFO, gatesIFO);
|
||||
else {
|
||||
tsStart = 0;
|
||||
tsEnd = numTimeStamps;
|
||||
tsInc = 1;
|
||||
}
|
||||
|
||||
tanh(gateG, gateG);
|
||||
|
||||
//compute c_t
|
||||
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
|
||||
multiply(gateI, gateG, gateI); // i_t (*) g_t
|
||||
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
|
||||
|
||||
if (useCellClip)
|
||||
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
|
||||
{
|
||||
min(cInternal, cellClip, cInternal);
|
||||
max(cInternal, -cellClip, cInternal);
|
||||
}
|
||||
if (usePeephole)
|
||||
{
|
||||
gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
|
||||
sigmoid(gateO, gateO);
|
||||
}
|
||||
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
|
||||
Mat xCurr = xTs.rowRange(curRowRange);
|
||||
|
||||
//compute h_t
|
||||
tanh(cInternal, hInternal);
|
||||
multiply(gateO, hInternal, hInternal);
|
||||
gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T); // Wx * x_t
|
||||
gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T); //+Wh * h_{t-1}
|
||||
gemm(dummyOnes, bias, 1, gates, 1, gates); //+b
|
||||
|
||||
//save results in output blobs
|
||||
hInternal.copyTo(hOutTs.rowRange(curRowRange));
|
||||
if (produceCellOutput)
|
||||
cInternal.copyTo(cOutTs.rowRange(curRowRange));
|
||||
Mat gateI = gates.colRange(0*numOut, 1*numOut);
|
||||
Mat gateF = gates.colRange(1*numOut, 2*numOut);
|
||||
Mat gateO = gates.colRange(2*numOut, 3*numOut);
|
||||
Mat gateG = gates.colRange(3*numOut, 4*numOut);
|
||||
|
||||
if (forgetBias)
|
||||
add(gateF, forgetBias, gateF);
|
||||
|
||||
if (usePeephole)
|
||||
{
|
||||
Mat gatesIF = gates.colRange(0, 2*numOut);
|
||||
gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
|
||||
gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
|
||||
sigmoid(gatesIF, gatesIF);
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat gatesIFO = gates.colRange(0, 3*numOut);
|
||||
sigmoid(gatesIFO, gatesIFO);
|
||||
}
|
||||
|
||||
tanh(gateG, gateG);
|
||||
|
||||
//compute c_t
|
||||
multiply(gateF, cInternal, gateF); // f_t (*) c_{t-1}
|
||||
multiply(gateI, gateG, gateI); // i_t (*) g_t
|
||||
add(gateF, gateI, cInternal); // c_t = f_t (*) c_{t-1} + i_t (*) g_t
|
||||
|
||||
if (useCellClip)
|
||||
{
|
||||
min(cInternal, cellClip, cInternal);
|
||||
max(cInternal, -cellClip, cInternal);
|
||||
}
|
||||
if (usePeephole)
|
||||
{
|
||||
gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
|
||||
sigmoid(gateO, gateO);
|
||||
}
|
||||
|
||||
//compute h_t
|
||||
tanh(cInternal, hInternal);
|
||||
multiply(gateO, hInternal, hInternal);
|
||||
|
||||
//save results in output blobs
|
||||
hInternal.copyTo(hOutTs.rowRange(curRowRange));
|
||||
if (produceCellOutput)
|
||||
cInternal.copyTo(cOutTs.rowRange(curRowRange));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -630,37 +630,44 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
Mat Wx = getBlob(node_proto, constBlobs, 1);
|
||||
Mat Wh = getBlob(node_proto, constBlobs, 2);
|
||||
Mat b = getBlob(node_proto, constBlobs, 3);
|
||||
b = b.reshape(1, b.size[0]);
|
||||
|
||||
const int numHidden = lstmParams.get<int>("hidden_size");
|
||||
|
||||
Wx = Wx.reshape(1, Wx.size[1]);
|
||||
Wh = Wh.reshape(1, Wh.size[1]);
|
||||
b = b.reshape(1, 2);
|
||||
reduce(b, b, 0, REDUCE_SUM);
|
||||
const int numDirs = Wx.size[0]; // Is 1 for forward only and 2 for bidirectional LSTM.
|
||||
const int numFeatures = Wx.size[2];
|
||||
Mat bx = b.colRange(0, b.cols / 2);
|
||||
Mat bh = b.colRange(b.cols / 2, b.cols);
|
||||
b = bx + bh;
|
||||
|
||||
// IFGO->IGFO
|
||||
float* WxData = (float*)Wx.data;
|
||||
float* WhData = (float*)Wh.data;
|
||||
float* biasData = (float*)b.data;
|
||||
for (int j = 0; j < numHidden; ++j)
|
||||
for (int k = 0; k < numDirs; ++k)
|
||||
{
|
||||
for (int i = 0; i < Wx.cols; ++i)
|
||||
float* WxData = Wx.ptr<float>(k);
|
||||
float* WhData = Wh.ptr<float>(k);
|
||||
float* biasData = b.ptr<float>(k);
|
||||
for (int j = 0; j < numHidden; ++j)
|
||||
{
|
||||
std::swap(WxData[(numHidden + j) * Wx.cols + i],
|
||||
WxData[(numHidden * 2 + j) * Wx.cols + i]);
|
||||
for (int i = 0; i < numFeatures; ++i)
|
||||
{
|
||||
std::swap(WxData[(numHidden + j) * numFeatures + i],
|
||||
WxData[(numHidden * 2 + j) * numFeatures + i]);
|
||||
}
|
||||
for (int i = 0; i < numHidden; ++i)
|
||||
{
|
||||
std::swap(WhData[(numHidden + j) * numHidden + i],
|
||||
WhData[(numHidden * 2 + j) * numHidden + i]);
|
||||
}
|
||||
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
|
||||
}
|
||||
for (int i = 0; i < Wh.cols; ++i)
|
||||
{
|
||||
std::swap(WhData[(numHidden + j) * Wh.cols + i],
|
||||
WhData[(numHidden * 2 + j) * Wh.cols + i]);
|
||||
}
|
||||
std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
|
||||
}
|
||||
Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
|
||||
Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);
|
||||
|
||||
lstmParams.blobs.resize(3);
|
||||
lstmParams.blobs[0] = Wh;
|
||||
lstmParams.blobs[1] = Wx;
|
||||
lstmParams.blobs[2] = b;
|
||||
lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");
|
||||
|
||||
node_proto.set_output(0, lstmParams.name); // set different name so output shapes will be registered on that name
|
||||
addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
|
||||
|
||||
@@ -456,6 +456,11 @@ TEST_P(Test_ONNX_layers, LSTM)
|
||||
testONNXModels("lstm");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, LSTM_bidirectional)
|
||||
{
|
||||
testONNXModels("lstm_bidirectional");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||
|
||||
class Test_ONNX_nets : public Test_ONNX_layers
|
||||
|
||||
Reference in New Issue
Block a user