diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 11dd4ea9e2..051098c744 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -174,16 +174,16 @@ public: for (int i = 0; i < outputs.size(); ++i) { CV_Assert(sliceRanges[i].size() <= inpShape.dims()); - // Clamp. - for (int j = 0; j < sliceRanges[i].size(); ++j) - { - sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]); - } // Fill the rest of ranges. for (int j = sliceRanges[i].size(); j < inpShape.dims(); ++j) { sliceRanges[i].push_back(Range::all()); } + // Clamp. + for (int j = 0; j < sliceRanges[i].size(); ++j) + { + sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]); + } } } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index b7d289d202..fb98778929 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -401,6 +401,47 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX"); layerParams.set("global_pooling", true); } + else if (layer_type == "Slice") + { + if (layerParams.has("steps")) { + DictValue steps = layerParams.get("steps"); + for (int i = 0; i < steps.size(); ++i) { + if (steps.get(i) != 1) + CV_Error(Error::StsNotImplemented, + "Slice layer only supports steps = 1"); + } + } + + int axis = 0; + if (layerParams.has("axes")) { + DictValue axes = layerParams.get("axes"); + for (int i = 1; i < axes.size(); ++i) { + CV_Assert(axes.get(i - 1) == axes.get(i) - 1); + } + axis = axes.get(0); + } + layerParams.set("axis", axis); + + DictValue starts = layerParams.get("starts"); + DictValue ends = layerParams.get("ends"); + CV_Assert(starts.size() == ends.size()); + + std::vector begin; + std::vector end; + if (axis > 0) { + begin.resize(axis, 0); + end.resize(axis, -1); + } + + for (int i = 0; i < starts.size(); ++i) + { + begin.push_back(starts.get(i)); + int finish = ends.get(i); + end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim + } + layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size())); + layerParams.set("end", DictValue::arrayInt(&end[0], end.size())); + } else if (layer_type == "Add" || layer_type == "Sum") { if (layer_id.find(node_proto.input(1)) == layer_id.end()) diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 9de4603ce5..6b6affc107 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -245,6 +245,11 @@ TEST_P(Test_ONNX_layers, Reshape) testONNXModels("unsqueeze"); } +TEST_P(Test_ONNX_layers, Slice) +{ + testONNXModels("slice"); +} + TEST_P(Test_ONNX_layers, Softmax) { testONNXModels("softmax");