From 17bc8565f6139298e6be56e91888937fc2e6d10a Mon Sep 17 00:00:00 2001 From: Gruhuang <56301098+Crayon-new@users.noreply.github.com> Date: Fri, 10 Dec 2021 19:44:27 +0800 Subject: [PATCH] Merge pull request #21154 from pccvlab:MatMul_with_two_inputs Add BatchMatMul layer support for tf_importer * two inputs * support batch_matmul * refactor: remove useless code * refactor: decrease nesting --- modules/dnn/src/tensorflow/tf_importer.cpp | 20 +++++++++++++++++++- modules/dnn/test/test_tf_importer.cpp | 8 ++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 9fb8f60b41..5fafa2b9d5 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -646,7 +646,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap() dispatch["Conv2D"] = dispatch["SpaceToBatchND"] = dispatch["DepthwiseConv2dNative"] = dispatch["Pad"] = dispatch["MirrorPad"] = dispatch["Conv3D"] = &TFImporter::parseConvolution; dispatch["BiasAdd"] = dispatch["Add"] = dispatch["AddV2"] = dispatch["Sub"] = dispatch["AddN"] = &TFImporter::parseBias; - dispatch["MatMul"] = &TFImporter::parseMatMul; + dispatch["MatMul"] = dispatch["BatchMatMul"] = &TFImporter::parseMatMul; dispatch["Reshape"] = &TFImporter::parseReshape; dispatch["Flatten"] = dispatch["Squeeze"] = &TFImporter::parseFlatten; dispatch["Transpose"] = &TFImporter::parseTranspose; @@ -983,6 +983,24 @@ void TFImporter::parseMatMul(tensorflow::GraphDef& net, const tensorflow::NodeDe layerParams.set("bias_term", false); layerParams.blobs.resize(1); + bool hasConstBlob = false; + for(int i = 0; i < layer.input_size(); i++) { + if (value_id.find(layer.input(i)) != value_id.end()) + hasConstBlob = true; + } + if (!hasConstBlob) + { + layerParams.blobs.clear(); + int id = dstNet.addLayer(name, "InnerProduct", layerParams); + layer_id[name] = id; + + // two inputs + for(int ii=0; ii