Merge pull request #16735 from l-bat:flatten_const_onnx
* Supported Flatten for constant nodes * Added default axis * Refactoring * Refactoring * Added cast layer * Fix comments * Add Cast for layers
This commit is contained in:
parent
0e6ce50131
commit
2645ee90ca
@ -431,9 +431,20 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
bool isSub = layer_type == "Sub";
|
||||
CV_CheckEQ(node_proto.input_size(), 2, "");
|
||||
if (layer_id.find(node_proto.input(1)) == layer_id.end())
|
||||
bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end();
|
||||
bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end();
|
||||
if (is_const_0 && is_const_1)
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, 1);
|
||||
Mat blob_0 = getBlob(node_proto, constBlobs, 0);
|
||||
Mat blob_1 = getBlob(node_proto, constBlobs, 1);
|
||||
CV_Assert(blob_0.size == blob_1.size);
|
||||
Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1);
|
||||
constBlobs.insert(std::make_pair(layerParams.name, output));
|
||||
continue;
|
||||
}
|
||||
else if (is_const_0 || is_const_1)
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, is_const_0 ? 0 : 1);
|
||||
blob = blob.reshape(1, 1);
|
||||
if (blob.total() == 1) {
|
||||
layerParams.type = "Power";
|
||||
@ -808,6 +819,21 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("end_axis", axis);
|
||||
layerParams.type = "Flatten";
|
||||
}
|
||||
else if (layer_type == "Flatten")
|
||||
{
|
||||
CV_CheckEQ(node_proto.input_size(), 1, "");
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
Mat input = getBlob(node_proto, constBlobs, 0);
|
||||
int axis = clamp(layerParams.get<int>("axis", 1), input.dims);
|
||||
|
||||
std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
|
||||
out_size.push_back(input.total(axis));
|
||||
Mat output = input.reshape(1, out_size);
|
||||
constBlobs.insert(std::make_pair(layerParams.name, output));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Unsqueeze")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 1);
|
||||
@ -896,6 +922,31 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
|
||||
continue;
|
||||
}
|
||||
else if (layer_type == "Cast")
|
||||
{
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, 0);
|
||||
int type;
|
||||
switch (layerParams.get<int>("to"))
|
||||
{
|
||||
case opencv_onnx::TensorProto_DataType_FLOAT: type = CV_32F; break;
|
||||
case opencv_onnx::TensorProto_DataType_UINT8: type = CV_8U; break;
|
||||
case opencv_onnx::TensorProto_DataType_UINT16: type = CV_16U; break;
|
||||
case opencv_onnx::TensorProto_DataType_FLOAT16: type = CV_16S; break;
|
||||
case opencv_onnx::TensorProto_DataType_INT8:
|
||||
case opencv_onnx::TensorProto_DataType_INT16:
|
||||
case opencv_onnx::TensorProto_DataType_INT32:
|
||||
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
|
||||
default: type = blob.type();
|
||||
}
|
||||
blob.convertTo(blob, type);
|
||||
constBlobs.insert(std::make_pair(layerParams.name, blob));
|
||||
continue;
|
||||
}
|
||||
else
|
||||
layerParams.type = "Identity";
|
||||
}
|
||||
else if (layer_type == "Gather")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 2);
|
||||
|
||||
@ -187,6 +187,11 @@ TEST_P(Test_ONNX_layers, MaxPooling_Sigmoid)
|
||||
testONNXModels("maxpooling_sigmoid");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Cast)
|
||||
{
|
||||
testONNXModels("cast");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Concatenation)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
@ -377,6 +382,7 @@ TEST_P(Test_ONNX_layers, DynamicReshape)
|
||||
testONNXModels("dynamic_reshape");
|
||||
testONNXModels("dynamic_reshape_opset_11");
|
||||
testONNXModels("flatten_by_prod");
|
||||
testONNXModels("flatten_const");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Reshape)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user