From 9085b933d81ffd12fbdabffe1ffc69a33580e631 Mon Sep 17 00:00:00 2001
From: Zihao Mu
Date: Tue, 5 Oct 2021 00:37:38 +0800
Subject: [PATCH] Merge pull request #20702 from zihaomu:tf_expand_dim_layer

Add ExpandDims layer to tf_importer.cpp

* Add ExpandDims to tf_importer.

* Add -1 expand test case.

* Support different dimensions of input.

* Compatible with 5-dimensional NDHWC data.

* Code align.

* Support 3-dim input.

* Fix 3-dim bug.

* Fix code formatting errors.
---
 modules/dnn/src/tensorflow/tf_importer.cpp | 134 ++++++++++++++++++++-
 modules/dnn/test/test_tf_importer.cpp      |  13 ++
 2 files changed, 144 insertions(+), 3 deletions(-)
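Note (illustrative, not part of the patch): the change below maps TensorFlow's ExpandDims onto OpenCV's existing Reshape layer. The importer reads the axis from the op's second (constant) input, normalizes a negative axis, and inserts a unit dimension into the inferred output shape. A minimal sketch of that axis convention, using a hypothetical helper name:

#include <cassert>
#include <vector>

// Illustrative helper (not in the patch): mirrors the axis handling in parseExpandDims.
// For a D-dimensional input, axis may be in [-(D+1), D]; a negative axis is mapped to
// D + axis + 1, then a dimension of size 1 is inserted at that position.
std::vector<int> expandDimsShape(std::vector<int> shape, int axis)
{
    const int D = static_cast<int>(shape.size());
    if (axis < 0)
        axis = D + axis + 1;                 // e.g. axis = -1 appends a trailing dimension
    assert(0 <= axis && axis <= D);
    shape.insert(shape.begin() + axis, 1);   // the new dimension always has size 1
    return shape;
}

// expandDimsShape({2, 3, 4}, -1) -> {2, 3, 4, 1}
// expandDimsShape({2, 3, 4},  0) -> {1, 2, 3, 4}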
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index 813407dbd8..9a9cd4692f 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -555,7 +555,7 @@ protected:
     std::map<String, int> layer_id;
 
 private:
-    void addPermuteLayer(const int* order, const std::string& permName, Pin& inpId);
+    void addPermuteLayer(const int* order, const std::string& permName, Pin& inpId, int orderSize = 4);
     void setPadding(LayerParams &layerParams, const tensorflow::NodeDef &layer, std::string& inputName, float value = 0.);
 
     friend class TFLayerHandler;
@@ -595,6 +595,7 @@ private:
     void parseClipByValue       (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
     void parseLeakyRelu         (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
     void parseActivation        (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
+    void parseExpandDims        (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
 
     void parseCustomLayer       (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
 };
@@ -672,6 +673,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
     dispatch["LeakyRelu"] = &TFImporter::parseLeakyRelu;
     dispatch["Abs"] = dispatch["Tanh"] = dispatch["Sigmoid"] = dispatch["Relu"] = dispatch["Elu"] = dispatch["Exp"] =
             dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation;
+    dispatch["ExpandDims"] = &TFImporter::parseExpandDims;
 
     return dispatch;
 }
@@ -1113,6 +1115,123 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
     }
 }
 
+void TFImporter::parseExpandDims(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
+{
+    const std::string& name = layer.name();
+    const int num_inputs = layer.input_size();
+
+    CV_Assert(!netInputShapes.empty());
+
+    CV_CheckGT(num_inputs, 0, "");
+    Pin inpId = parsePin(layer.input(0));
+    DataLayout inpLayout = getDataLayout(layer.input(0), data_layouts);
+
+    // Get input shape
+    std::vector<MatShape> inShape_, outShape_;
+    int inpIdindex = layer_id.find(inpId.name)->second;
+
+    dstNet.getLayerShapes(netInputShapes, inpIdindex, inShape_, outShape_);
+    MatShape inpShape = outShape_[0];
+    std::vector<int> outShape = inpShape;
+
+    int outShapeSize = outShape.size();
+
+    CV_Assert(inpShape.size() >= 1);
+    // 2nd blob is the dims tensor
+    int axis = getConstBlob(layer, value_id, 1).int_val().Get(0);
+
+    // Convert a negative axis to a positive one; axis can be in the range [-(D+1), D].
+    if(axis < 0)
+    {
+        axis = inpShape.size() + axis + 1;
+    }
+
+    CV_Assert(0 <= axis && axis <= inpShape.size());
+
+    // After ExpandDims, 3-dim data becomes 4-dim data, and OpenCV stores 4-dim data in NCHW layout.
+    // Convert OpenCV's NHC to NCH first.
+    if(outShapeSize == 3)
+    {
+        // If axis equals outShapeSize, we expand in the channel dimension, so no permute layer is added.
+        if(axis != outShapeSize)
+        {
+            int order[] = {0, 2, 1}; // From OpenCV's NHC to NCH.
+            addPermuteLayer(order, name + "/nch", inpId, 3);
+
+            std::swap(outShape[1], outShape[2]);
+        }
+        axis = (axis != 0) ? (axis % outShapeSize + 1) : 2;
+    }
+
+    if(inpShape.size() == 4)
+    {
+        if(axis == inpShape.size())
+        {
+            int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
+            addPermuteLayer(order, name + "/nhwc", inpId);
+
+            // Convert shape from OpenCV's NCHW to NHWC.
+            if(inpLayout == DATA_LAYOUT_NHWC)
+            {
+                std::swap(outShape[1], outShape[2]);
+                std::swap(outShape[2], outShape[3]);
+            }
+        }
+        if(inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_NCHW)
+        {
+            // toNCHW
+            axis = (axis != 0) ? (axis % outShapeSize + 1) : 0;
+        }
+    }
+
+    // After ExpandDims, 5-dim data becomes 6-dim data, and OpenCV keeps 6-dim data in its original layout.
+    // Convert OpenCV's NCDHW to NDHWC first.
+    if (inpShape.size() == 5 && (inpLayout == DATA_LAYOUT_NDHWC || inpLayout == DATA_LAYOUT_UNKNOWN))
+    {
+        int order[] = {0, 2, 3, 4, 1}; // From OpenCV's NCDHW to NDHWC.
+        addPermuteLayer(order, name + "/ndhwc", inpId, 5);
+
+        // Convert shape from OpenCV's NCDHW to NDHWC.
+        if(inpLayout == DATA_LAYOUT_NDHWC)
+        {
+            std::swap(outShape[1], outShape[2]);
+            std::swap(outShape[2], outShape[3]);
+            std::swap(outShape[3], outShape[4]);
+        }
+    }
+
+    outShape.insert(outShape.begin() + axis, 1);
+    outShapeSize += 1;
+
+    // From OpenCV's NCDHW to NDHWC.
+    if((inpLayout != DATA_LAYOUT_NHWC && inpLayout != DATA_LAYOUT_NCHW) && outShapeSize == 5)
+    {
+        for(int i = 1; i < outShapeSize - 1; i++)
+        {
+            std::swap(outShape[outShapeSize - i - 1], outShape[outShapeSize - i]);
+        }
+    }
+
+    layerParams.set("dim", DictValue::arrayInt(&outShape[0], outShape.size()));
+    int id = dstNet.addLayer(name, "Reshape", layerParams);
+    layer_id[name] = id;
+
+    connect(layer_id, dstNet, inpId, id, 0);
+
+    if(outShapeSize == 5)
+    {
+        data_layouts[name] = DATA_LAYOUT_NDHWC;
+    }
+    else if(outShapeSize == 4)
+    {
+        data_layouts[name] = DATA_LAYOUT_NCHW;
+    }
+    else
+    {
+        data_layouts[name] = inpLayout;
+    }
+}
+
 // "Flatten" "Squeeze"
 void TFImporter::parseFlatten(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
 {
@@ -1419,6 +1538,15 @@ void TFImporter::parsePlaceholder(tensorflow::GraphDef& net, const tensorflow::N
            if (dims[0] == -1)  // It's OK to have undetermined batch size
                dims[0] = 1;
        }
+
+       if (dims.size() == 5 && predictedLayout == DATA_LAYOUT_NDHWC)
+       {
+           std::swap(dims[3], dims[4]);  // NDHWC->NDHCW
+           std::swap(dims[2], dims[3]);  // NDHCW->NDCHW
+           std::swap(dims[1], dims[2]);  // NDCHW->NCDHW
+           if (dims[0] == -1)  // It's OK to have undetermined batch size
+               dims[0] = 1;
+       }
        bool hasNeg = false;
        for (int i = 0; i < dims.size() && !hasNeg; ++i)
        {
@@ -2882,10 +3010,10 @@ void TFImporter::populateNet()
     CV_LOG_DEBUG(NULL, (DNN_DIAGNOSTICS_RUN? "DNN/TF: diagnostic run completed!" : "DNN/TF: import completed!"));
 }
 
-void TFImporter::addPermuteLayer(const int* order, const std::string& permName, Pin& inpId)
+void TFImporter::addPermuteLayer(const int* order, const std::string& permName, Pin& inpId, int orderSize)
 {
     LayerParams permLP;
-    permLP.set("order", DictValue::arrayInt(order, 4));
+    permLP.set("order", DictValue::arrayInt(order, orderSize));
     CV_Assert(layer_id.find(permName) == layer_id.end());
     int permId = dstNet.addLayer(permName, "Permute", permLP);
     layer_id[permName] = permId;
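Note (illustrative, not part of the patch): because OpenCV DNN stores 4-dim blobs as NCHW while TensorFlow graphs are typically NHWC (and NCDHW vs. NDHWC for 5-dim data), parseExpandDims sometimes inserts a Permute layer before the Reshape so the new unit dimension lands where the TensorFlow graph expects it; this is also why addPermuteLayer gained an orderSize parameter for 3- and 5-element orders. A rough sketch of what such a permute order does to a shape, using a hypothetical helper that is not OpenCV API:

#include <vector>

// Illustrative helper: reorder a shape by a permute order, e.g. {0, 2, 3, 1} maps NCHW -> NHWC.
std::vector<int> permuteShape(const std::vector<int>& shape, const std::vector<int>& order)
{
    std::vector<int> out(order.size());
    for (size_t i = 0; i < order.size(); ++i)
        out[i] = shape[order[i]];
    return out;
}

// permuteShape({1, 3, 10, 10}, {0, 2, 3, 1}) -> {1, 10, 10, 3}   (NCHW -> NHWC)
// permuteShape({1, 10, 3},     {0, 2, 1})    -> {1, 3, 10}       (NHC  -> NCH)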
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index eec1a628f1..8e7f969480 100644
--- a/modules/dnn/test/test_tf_importer.cpp
+++ b/modules/dnn/test/test_tf_importer.cpp
@@ -593,6 +593,19 @@ TEST_P(Test_TensorFlow_layers, BiasAdd)
     runTensorFlowNet("bias_add_1");
 }
 
+TEST_P(Test_TensorFlow_layers, ExpandDims)
+{
+#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_GE(2019010000)
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD
+            && getInferenceEngineVPUType() == CV_DNN_INFERENCE_ENGINE_VPU_TYPE_MYRIAD_X
+    )
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD_X, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER, CV_TEST_TAG_DNN_SKIP_IE_VERSION);
+#endif
+
+    runTensorFlowNet("expand_dims_1");
+    runTensorFlowNet("expand_dims_2");
+}
+
 // TODO: fix it and add to l2_normalize
 TEST_P(Test_TensorFlow_layers, l2_normalize_3d)
 {
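Note (illustrative, not part of the patch): the new test cases run through runTensorFlowNet like the other layer tests, so the expand_dims_1 and expand_dims_2 graphs are expected to ship with the usual DNN test data. A minimal usage sketch of importing a TensorFlow graph that contains an ExpandDims node; the .pb file name and the input shape below are placeholders:

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // Placeholder file name: any frozen TensorFlow graph containing an ExpandDims node.
    cv::dnn::Net net = cv::dnn::readNetFromTensorflow("expand_dims_example.pb");

    // Example 4-dim input blob (NCHW); the shape is illustrative only.
    cv::Mat input(std::vector<int>{1, 3, 10, 10}, CV_32F, cv::Scalar(1.0f));
    net.setInput(input);

    cv::Mat out = net.forward();
    std::cout << "output dims: " << out.dims << std::endl;
    return 0;
}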