Merge pull request #21154 from pccvlab:MatMul_with_two_inputs
Add BatchMatMul layer support for tf_importer * two inputs * support batch_matmul * refactor: remove useless code * refactor: decrease nesting
This commit is contained in:
parent
c08954c18b
commit
17bc8565f6
@ -646,7 +646,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
|
||||
dispatch["Conv2D"] = dispatch["SpaceToBatchND"] = dispatch["DepthwiseConv2dNative"] =
|
||||
dispatch["Pad"] = dispatch["MirrorPad"] = dispatch["Conv3D"] = &TFImporter::parseConvolution;
|
||||
dispatch["BiasAdd"] = dispatch["Add"] = dispatch["AddV2"] = dispatch["Sub"] = dispatch["AddN"] = &TFImporter::parseBias;
|
||||
dispatch["MatMul"] = &TFImporter::parseMatMul;
|
||||
dispatch["MatMul"] = dispatch["BatchMatMul"] = &TFImporter::parseMatMul;
|
||||
dispatch["Reshape"] = &TFImporter::parseReshape;
|
||||
dispatch["Flatten"] = dispatch["Squeeze"] = &TFImporter::parseFlatten;
|
||||
dispatch["Transpose"] = &TFImporter::parseTranspose;
|
||||
@ -983,6 +983,24 @@ void TFImporter::parseMatMul(tensorflow::GraphDef& net, const tensorflow::NodeDe
|
||||
layerParams.set("bias_term", false);
|
||||
layerParams.blobs.resize(1);
|
||||
|
||||
bool hasConstBlob = false;
|
||||
for(int i = 0; i < layer.input_size(); i++) {
|
||||
if (value_id.find(layer.input(i)) != value_id.end())
|
||||
hasConstBlob = true;
|
||||
}
|
||||
if (!hasConstBlob)
|
||||
{
|
||||
layerParams.blobs.clear();
|
||||
int id = dstNet.addLayer(name, "InnerProduct", layerParams);
|
||||
layer_id[name] = id;
|
||||
|
||||
// two inputs
|
||||
for(int ii=0; ii<layer.input_size(); ii++){
|
||||
connect(layer_id, dstNet, parsePin(layer.input(ii)), id, ii);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
StrIntVector next_layers = getNextLayers(net, name, "BiasAdd"); // FIXIT Use layers fusion instead
|
||||
if (next_layers.empty())
|
||||
{
|
||||
|
||||
@ -660,6 +660,14 @@ TEST_P(Test_TensorFlow_layers, matmul)
|
||||
double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1;
|
||||
runTensorFlowNet("nhwc_reshape_matmul", false, l1);
|
||||
runTensorFlowNet("matmul_layout");
|
||||
runTensorFlowNet("two_inputs_matmul");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, batch_matmul)
|
||||
{
|
||||
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
|
||||
runTensorFlowNet("batch_matmul");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, reshape)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user