diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 0ca909597e..e08f7b0e11 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -431,9 +431,20 @@ void ONNXImporter::populateNet(Net dstNet) { bool isSub = layer_type == "Sub"; CV_CheckEQ(node_proto.input_size(), 2, ""); - if (layer_id.find(node_proto.input(1)) == layer_id.end()) + bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end(); + bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end(); + if (is_const_0 && is_const_1) { - Mat blob = getBlob(node_proto, constBlobs, 1); + Mat blob_0 = getBlob(node_proto, constBlobs, 0); + Mat blob_1 = getBlob(node_proto, constBlobs, 1); + CV_Assert(blob_0.size == blob_1.size); + Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1); + constBlobs.insert(std::make_pair(layerParams.name, output)); + continue; + } + else if (is_const_0 || is_const_1) + { + Mat blob = getBlob(node_proto, constBlobs, is_const_0 ? 
0 : 1); blob = blob.reshape(1, 1); if (blob.total() == 1) { layerParams.type = "Power"; @@ -808,6 +819,21 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.set("end_axis", axis); layerParams.type = "Flatten"; } + else if (layer_type == "Flatten") + { + CV_CheckEQ(node_proto.input_size(), 1, ""); + if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) + { + Mat input = getBlob(node_proto, constBlobs, 0); + int axis = clamp(layerParams.get("axis", 1), input.dims); + + std::vector<int> out_size(&input.size[0], &input.size[0] + axis); + out_size.push_back(input.total(axis)); + Mat output = input.reshape(1, out_size); + constBlobs.insert(std::make_pair(layerParams.name, output)); + continue; + } + } else if (layer_type == "Unsqueeze") { CV_Assert(node_proto.input_size() == 1); @@ -896,6 +922,31 @@ void ONNXImporter::populateNet(Net dstNet) constBlobs.insert(std::make_pair(layerParams.name, shapeMat)); continue; } + else if (layer_type == "Cast") + { + if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) + { + Mat blob = getBlob(node_proto, constBlobs, 0); + int type; + switch (layerParams.get<int>("to")) + { + case opencv_onnx::TensorProto_DataType_FLOAT: type = CV_32F; break; + case opencv_onnx::TensorProto_DataType_UINT8: type = CV_8U; break; + case opencv_onnx::TensorProto_DataType_UINT16: type = CV_16U; break; + case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break; + case opencv_onnx::TensorProto_DataType_INT8: + case opencv_onnx::TensorProto_DataType_INT16: + case opencv_onnx::TensorProto_DataType_INT32: + case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break; + default: type = blob.type(); + } + blob.convertTo(blob, type); + constBlobs.insert(std::make_pair(layerParams.name, blob)); + continue; + } + else + layerParams.type = "Identity"; + } else if (layer_type == "Gather") { CV_Assert(node_proto.input_size() == 2); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 
f284eed45b..769862d53d 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -187,6 +187,11 @@ TEST_P(Test_ONNX_layers, MaxPooling_Sigmoid) testONNXModels("maxpooling_sigmoid"); } +TEST_P(Test_ONNX_layers, Cast) +{ + testONNXModels("cast"); +} + TEST_P(Test_ONNX_layers, Concatenation) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) @@ -377,6 +382,7 @@ TEST_P(Test_ONNX_layers, DynamicReshape) testONNXModels("dynamic_reshape"); testONNXModels("dynamic_reshape_opset_11"); testONNXModels("flatten_by_prod"); + testONNXModels("flatten_const"); } TEST_P(Test_ONNX_layers, Reshape)