let MatMul can work when both two inputs are const
This commit is contained in:
parent
bc6544c0bc
commit
5044af69d1
@ -2037,9 +2037,25 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
|
|||||||
CV_Assert(node_proto.input_size() == 2);
|
CV_Assert(node_proto.input_size() == 2);
|
||||||
layerParams.type = "InnerProduct";
|
layerParams.type = "InnerProduct";
|
||||||
layerParams.set("bias_term", false);
|
layerParams.set("bias_term", false);
|
||||||
CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
|
int firstInpDims, secondInpDims;
|
||||||
int firstInpDims = outShapes[node_proto.input(0)].size();
|
|
||||||
int secondInpDims;
|
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||||
|
{
|
||||||
|
Mat blob = getBlob(node_proto, 0);
|
||||||
|
firstInpDims = blob.dims;
|
||||||
|
LayerParams constParams;
|
||||||
|
constParams.name = layerParams.name + "/const_0";
|
||||||
|
constParams.type = "Const";
|
||||||
|
constParams.blobs.push_back(blob);
|
||||||
|
|
||||||
|
opencv_onnx::NodeProto tmpProto;
|
||||||
|
tmpProto.add_output(constParams.name);
|
||||||
|
addLayer(constParams, tmpProto);
|
||||||
|
|
||||||
|
node_proto.set_input(0, constParams.name);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
firstInpDims = outShapes[node_proto.input(0)].size();
|
||||||
|
|
||||||
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
||||||
{
|
{
|
||||||
@ -2053,7 +2069,7 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
LayerParams constParams;
|
LayerParams constParams;
|
||||||
constParams.name = layerParams.name + "/const";
|
constParams.name = layerParams.name + "/const_1";
|
||||||
constParams.type = "Const";
|
constParams.type = "Const";
|
||||||
constParams.blobs.push_back(blob);
|
constParams.blobs.push_back(blob);
|
||||||
|
|
||||||
@ -2063,9 +2079,10 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
|
|||||||
|
|
||||||
node_proto.set_input(1, constParams.name);
|
node_proto.set_input(1, constParams.name);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
secondInpDims = outShapes[node_proto.input(1)].size();
|
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
secondInpDims = outShapes[node_proto.input(1)].size();
|
||||||
|
|
||||||
layerParams.set("axis", firstInpDims - secondInpDims + 1);
|
layerParams.set("axis", firstInpDims - secondInpDims + 1);
|
||||||
addLayer(layerParams, node_proto);
|
addLayer(layerParams, node_proto);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -961,6 +961,8 @@ TEST_P(Test_ONNX_layers, MatMul_init)
|
|||||||
testONNXModels("matmul_2d_init");
|
testONNXModels("matmul_2d_init");
|
||||||
testONNXModels("matmul_3d_init");
|
testONNXModels("matmul_3d_init");
|
||||||
testONNXModels("matmul_4d_init");
|
testONNXModels("matmul_4d_init");
|
||||||
|
|
||||||
|
testONNXModels("matmul_init_2");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, MatMulAdd)
|
TEST_P(Test_ONNX_layers, MatMulAdd)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user