enhance slice layer

refactor the code for parsing Slice layer
add test for Slice layer
let 'begin' and 'end' resize to dims
add opset message comment
This commit is contained in:
zoom 2022-09-22 14:40:39 +08:00
parent 04ebedb6f0
commit 4557971481
2 changed files with 83 additions and 49 deletions

View File

@ -1326,72 +1326,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int axis = 0;
std::vector<int> begin;
std::vector<int> end;
MatShape inpShape = outShapes[node_proto.input(0)];
int dims = inpShape.size();
std::vector<int> begin(dims, 0);
std::vector<int> end(dims, INT_MAX);
std::vector<int> steps;
int inp_size = node_proto.input_size();
int axis = 0;
bool has_axes = false;
DictValue starts_, ends_, axes_, steps_;
// opset = 1
if (inp_size == 1)
{
if (layerParams.has("axes")) {
DictValue axes = layerParams.get("axes");
for (int i = 1; i < axes.size(); ++i) {
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
}
axis = axes.get<int>(0);
}
DictValue starts = layerParams.get("starts");
DictValue ends = layerParams.get("ends");
CV_Assert(starts.size() == ends.size());
if (axis > 0) {
CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
}
for (int i = 0; i < starts.size(); ++i)
starts_ = layerParams.get("starts");
ends_ = layerParams.get("ends");
CV_Assert(starts_.size() == ends_.size());
if (layerParams.has("axes"))
{
begin.push_back(starts.get<int>(i));
end.push_back(ends.get<int>(i));
axes_ = layerParams.get("axes");
CV_Assert(axes_.size() == starts_.size());
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
has_axes = true;
}
} else { // inp_size > 1
}
// opset > 1
else
{
CV_Assert(inp_size >= 3);
for (int i = 1; i < inp_size; i++) {
for (int i = 1; i < inp_size; ++i)
{
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
}
Mat start_blob = getBlob(node_proto, 1);
Mat end_blob = getBlob(node_proto, 2);
Mat end_blob = getBlob(node_proto, 2);
CV_Assert(start_blob.total() == end_blob.total());
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
if (inp_size > 3) {
if (inp_size > 3)
{
Mat axes_blob = getBlob(node_proto, 3);
const int* axes = (int*)axes_blob.data;
for (int i = 1; i < axes_blob.total(); ++i) {
CV_Assert(axes[i - 1] == axes[i] - 1);
}
axis = axes[0];
CV_Assert(axes_blob.total() == start_blob.total());
axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
has_axes = true;
}
const int* starts = start_blob.ptr<int>();
const int* ends = end_blob.ptr<int>();
if (axis > 0) {
begin.resize(axis, 0);
end.resize(axis, INT_MAX);
}
std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
if (inp_size == 5) {
CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
if (inp_size == 5)
{
Mat step_blob = getBlob(node_proto, 4);
const int* steps_ptr = step_blob.ptr<int>();
if (axis > 0)
steps.resize(axis, 1);
std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
CV_Assert(step_blob.total() == start_blob.total());
steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
steps.resize(dims, 1);
// Very strange application for Slice op with tensor reversing.
// We just workaround it for 2d constants.
@ -1411,12 +1398,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
}
}
}
if (!has_axes)
{
// make a default axes [0, 1, 2...]
Mat axes_tmp(1, starts_.size(), CV_32S);
std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
}
int cur_axe;
std::vector<bool> flag(dims, false);
Mat axes(1, starts_.size(), CV_32S);
auto axes_ptr = axes.ptr<int>();
// resize begin and end
for (int i = 0; i < axes_.size(); ++i)
{
// dims should be added to the negative axes
cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
flag[cur_axe] = true;
// change axis to the minimum axe
if (cur_axe < axis) axis = cur_axe;
axes_ptr[i] = cur_axe;
begin[cur_axe] = starts_.getIntValue(i);
end[cur_axe] = ends_.getIntValue(i);
}
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
layerParams.set("axis", axis);
if (!steps.empty())
{
for (int i = 0; i < axes.total(); ++i)
steps[axes_ptr[i]] = steps_.getIntValue(i);
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
}
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{

View File

@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
testONNXModels("slice_opset_11_steps_5d");
}
TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
{
testONNXModels("slice_nonseq_axes");
testONNXModels("slice_nonseq_axes_steps");
testONNXModels("slice_nonseq_miss_axes_steps");
}
TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
{
testONNXModels("slice_neg_axes");
testONNXModels("slice_neg_axes_steps");
testONNXModels("slice_neg_miss_axes_steps");
}
TEST_P(Test_ONNX_layers, Softmax)
{
testONNXModels("softmax");