diff --git a/modules/dnn/src/layers/fully_connected_layer.cpp b/modules/dnn/src/layers/fully_connected_layer.cpp
index 03349253c0..4746403504 100644
--- a/modules/dnn/src/layers/fully_connected_layer.cpp
+++ b/modules/dnn/src/layers/fully_connected_layer.cpp
@@ -116,7 +116,6 @@ public:
             CV_CheckEQ(inputs.size(), (size_t)2, "");
             numOutput = inputs[1].back();
             cAxis = inputs[0].size() - 1;
-            CV_CheckEQ(numOutput, inputs[0][cAxis - 1], "");
             int dims = inputs[0].size();
             CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
             CV_CheckGE(dims, 2, "");
diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
index 61ef8b7da6..e8b237cab4 100644
--- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -262,6 +262,24 @@ public:
     }
 };
 
+class ExpandSubgraph : public Subgraph
+{
+public:
+    ExpandSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int values = addNodeToMatch("");
+        int init = addNodeToMatch("ConstantOfShape", values);
+        int coeff = addNodeToMatch("Constant");
+        int mul = addNodeToMatch("Mul", init, coeff);
+        int shape = addNodeToMatch("Constant");
+        int condition = addNodeToMatch("Equal", shape, mul);
+        int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
+        addNodeToMatch("Expand", input, where);
+        setFusedNode("Expand", input, shape);
+    }
+};
+
 class MulCastSubgraph : public Subgraph
 {
 public:
@@ -459,6 +477,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
+    subgraphs.push_back(makePtr<ExpandSubgraph>());
 
     simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
 }
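A note on the ExpandSubgraph pattern above: this is the shape-arithmetic chain PyTorch's ONNX exporter emits for expand()-style broadcasting. ConstantOfShape yields a tensor of ones, Mul scales it by a constant coefficient, and Equal/Where pick between those ones and the requested shape entry by entry. Everything except the input and the shape constant is itself constant, so the whole chain folds away, which is why setFusedNode keeps only Expand(input, shape). Below is a standalone sketch of that constant folding with hypothetical values; the -1 "keep this dimension" convention is an assumption, inferred from the targetShape[i] == -1 handling added to onnx_importer.cpp further down.

// sketch_expand_fold.cpp -- hypothetical standalone illustration, not part of the patch.
#include <iostream>
#include <vector>

int main()
{
    std::vector<long long> shape = {2, -1, 4};     // "Constant": requested shape, -1 = keep input dim (assumed)
    std::vector<long long> init(shape.size(), 1);  // "ConstantOfShape": ones
    long long coeff = -1;                          // "Constant": coefficient (assumed)
    for (size_t i = 0; i < shape.size(); ++i)
    {
        bool equal = (shape[i] == init[i] * coeff);    // "Equal(shape, Mul(init, coeff))"
        long long where = equal ? init[i] : shape[i];  // "Where(condition, init, shape)"
        std::cout << where << (i + 1 < shape.size() ? ' ' : '\n');  // prints: 2 1 4
    }
    return 0;
}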
"MAX" : "AVE"); + String pool; + if (layer_type == "GlobalMaxPool") + pool = "MAX"; + else if (layer_type == "ReduceSum") + pool = "SUM"; + else + pool = "AVE"; + layerParams.set("pool", pool); layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool"); - - if (layer_type == "ReduceMean") + if (layer_type == "ReduceMean" || layer_type == "ReduceSum") { - if (layerParams.get("keepdims") == 0 || !layerParams.has("axes")) - CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation."); + if (!layerParams.has("axes")) + CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation."); MatShape inpShape = outShapes[node_proto.input(0)]; DictValue axes = layerParams.get("axes"); + bool keepdims = layerParams.get("keepdims"); + MatShape targetShape = inpShape; + for (int i = 0; i < axes.size(); i++) { + int axis = clamp(axes.get(i), inpShape.size()); + if (keepdims) { + targetShape[axis] = 1; + } else { + targetShape.erase(targetShape.begin() + axis); + } + } + if (inpShape.size() == 3 && axes.size() <= 2) { - int axis = axes.get(0); + int axis = clamp(axes.get(0), inpShape.size()); CV_CheckNE(axis, 0, ""); - outShapes[layerParams.name] = inpShape; - outShapes[layerParams.name][axis] = 1; LayerParams reshapeLp; reshapeLp.name = layerParams.name + "/reshape"; @@ -426,13 +442,12 @@ void ONNXImporter::populateNet(Net dstNet) avgLp.name = layerParams.name + "/avg"; avgLp.type = "Pooling"; CV_Assert(layer_id.find(avgLp.name) == layer_id.end()); - avgLp.set("pool", "ave"); + avgLp.set("pool", pool); if (axes.size() == 2) { - CV_CheckEQ(axes.get(0), 1, "Unsupported ReduceMean mode"); - CV_CheckEQ(axes.get(1), 2, "Unsupported ReduceMean mode"); + CV_CheckEQ(clamp(axes.get(0), inpShape.size()), 1, ("Unsupported " + layer_type + " mode").c_str()); + CV_CheckEQ(clamp(axes.get(1), inpShape.size()), 2, ("Unsupported " + layer_type + " mode").c_str()); avgLp.set("global_pooling", true); - outShapes[layerParams.name][axes.get(1)] = 1; } else { @@ -443,28 +458,33 @@ void ONNXImporter::populateNet(Net dstNet) node_proto.set_input(0, reshapeLp.name); node_proto.set_output(0, avgLp.name); addLayer(dstNet, avgLp, node_proto, layer_id, outShapes); - - layerParams.type = "Flatten"; - layerParams.set("axis", 0); - layerParams.set("end_axis", 1); - - node_proto.set_input(0, avgLp.name); - node_proto.set_output(0, layerParams.name); } else { if (inpShape.size() != 4 && inpShape.size() != 5) - CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation."); + CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation."); CV_Assert(axes.size() <= inpShape.size() - 2); std::vector kernel_size(inpShape.size() - 2, 1); for (int i = 0; i < axes.size(); i++) { - int axis = axes.get(i); + int axis = clamp(axes.get(i), inpShape.size()); CV_Assert_N(axis >= 2 + i, axis < inpShape.size()); kernel_size[axis - 2] = inpShape[axis]; } - layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size())); + LayerParams poolLp = layerParams; + poolLp.name = layerParams.name + "/avg"; + CV_Assert(layer_id.find(poolLp.name) == layer_id.end()); + poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size())); + + node_proto.set_output(0, poolLp.name); + addLayer(dstNet, poolLp, node_proto, layer_id, outShapes); } + + layerParams.type = "Reshape"; + layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size())); + + node_proto.set_input(0, 
@@ -1001,15 +1021,10 @@ void ONNXImporter::populateNet(Net dstNet)
             {
                 Mat inp0 = getBlob(node_proto, constBlobs, 0);
                 Mat inp1 = getBlob(node_proto, constBlobs, 1);
-                if (inp0.size != inp1.size)
+                if (inp0.size != inp1.size && inp1.total() != 1)
                     CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
 
-                Mat out;
-                if (isDiv)
-                    divide(inp0, inp1, out);
-                else
-                    multiply(inp0, inp1, out);
-
+                Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
                 out = out.reshape(1, inp0.dims, inp0.size);
                 out.dims = inp0.dims; // to workaround dims == 1
                 addConstant(layerParams.name, out, constBlobs, outShapes);
@@ -1180,9 +1195,45 @@ void ONNXImporter::populateNet(Net dstNet)
             Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
 
-            shapeIt = outShapes.find(node_proto.input(0));
-            CV_Assert(shapeIt != outShapes.end());
-            MatShape inpShape = shapeIt->second;
+            MatShape inpShape;
+            bool haveVariables = constBlobs.find(node_proto.input(0)) == constBlobs.end();
+            if (haveVariables)
+            {
+                shapeIt = outShapes.find(node_proto.input(0));
+                CV_Assert(shapeIt != outShapes.end());
+                inpShape = shapeIt->second;
+            }
+            else
+            {
+                inpShape = shape(getBlob(node_proto, constBlobs, 0));
+            }
+
+            String srcName = node_proto.input(0);
+            // Unsqueeze and repeat along new axis
+            if (targetShape.size() == inpShape.size() + 1)
+            {
+                for (int i = 0; i < targetShape.size(); i++)
+                {
+                    if (targetShape[i] == -1 && i < inpShape.size())
+                        targetShape[i] = inpShape[i];
+                    else if (i < inpShape.size() && targetShape[i] != inpShape[i])
+                        inpShape.insert(inpShape.begin() + i, 1);
+                }
+                if (haveVariables)
+                {
+                    LayerParams reshapeLp;
+                    reshapeLp.name = layerParams.name + "/reshape";
+                    reshapeLp.type = "Reshape";
+                    CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
+                    reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
+
+                    opencv_onnx::NodeProto proto;
+                    proto.add_input(node_proto.input(0));
+                    proto.add_output(reshapeLp.name);
+                    addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
+                    srcName = reshapeLp.name;
+                }
+            }
             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
 
             std::vector<int> broadcast_axes;
@@ -1197,6 +1248,19 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
             }
 
+            if (!haveVariables)
+            {
+                if (broadcast_axes.size() != 1)
+                    CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
+
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
+                Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
+                output = output.reshape(0, targetShape);
+                addConstant(layerParams.name, output, constBlobs, outShapes);
+                continue;
+            }
+
             if (broadcast_axes.size() == 2 && broadcast_axes[0] == broadcast_axes[1] - 1 &&
                 broadcast_axes[1] == inpShape.size() - 1)
             {
@@ -1231,6 +1295,7 @@ void ONNXImporter::populateNet(Net dstNet)
                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
                     input_names.push_back(copyLP.name);
 
+                    node_proto.set_input(0, srcName);
                     node_proto.set_output(0, copyLP.name);
                     addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
                 }
@@ -1241,6 +1306,7 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
                 layerParams.set("axis", broadcast_axes[0]);
                 layerParams.type = "Concat";
+                node_proto.set_output(0, layerParams.name);
             }
             else
                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
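The new constant-input branch of Expand avoids building any layers: it flattens the blob to 2-D with every dimension before the broadcast axis as rows and everything from the (size-1) broadcast axis onward as columns, tiles the columns with cv::repeat, then reshapes to the target. A minimal sketch of that trick against the public cv::Mat API, with hypothetical shapes:

// sketch_expand_const.cpp -- hypothetical standalone illustration, not part of the patch.
#include <opencv2/core.hpp>
#include <iostream>

int main()
{
    float data[6] = {1, 2, 3, 4, 5, 6};
    int inpSize[] = {2, 1, 3};                     // broadcast axis 1 has size 1
    cv::Mat input(3, inpSize, CV_32F, data);

    int axis = 1;                                  // the single axis where shapes differ
    int repeats = 4;                               // targetShape[axis]
    int flatSize[] = {2, 3};                       // {dims before axis, everything else}
    cv::Mat flat = input.reshape(1, 2, flatSize);
    cv::Mat tiled = cv::repeat(flat, 1, repeats);  // each row: 1 2 3 1 2 3 1 2 3 1 2 3

    int outSize[] = {2, 4, 3};
    cv::Mat output = tiled.reshape(1, 3, outSize);
    std::cout << output.size[0] << "x" << output.size[1] << "x"
              << output.size[2] << std::endl;      // prints: 2x4x3
    return 0;
}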
Expand op"); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index a317be71fb..25efcbb3ca 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -257,6 +257,11 @@ TEST_P(Test_ONNX_layers, ReduceMean) testONNXModels("reduce_mean_axis2"); } +TEST_P(Test_ONNX_layers, ReduceSum) +{ + testONNXModels("reduce_sum"); +} + TEST_P(Test_ONNX_layers, ReduceMean3D) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU) @@ -417,6 +422,7 @@ TEST_P(Test_ONNX_layers, Expand) { testONNXModels("expand_batch"); testONNXModels("expand_channels"); + testONNXModels("expand_neg_batch"); } TEST_P(Test_ONNX_layers, ExpandHW)