Merge pull request #22554 from WanliZhong:slice_axes_no_seq
DNN: Let Slice layer support non-sequential and negative axes
This commit is contained in:
commit
96844b0ca5
@ -1299,72 +1299,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
|
||||
|
||||
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
int axis = 0;
|
||||
std::vector<int> begin;
|
||||
std::vector<int> end;
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
int dims = inpShape.size();
|
||||
std::vector<int> begin(dims, 0);
|
||||
std::vector<int> end(dims, INT_MAX);
|
||||
std::vector<int> steps;
|
||||
int inp_size = node_proto.input_size();
|
||||
int axis = 0;
|
||||
bool has_axes = false;
|
||||
DictValue starts_, ends_, axes_, steps_;
|
||||
|
||||
// opset = 1
|
||||
if (inp_size == 1)
|
||||
{
|
||||
if (layerParams.has("axes")) {
|
||||
DictValue axes = layerParams.get("axes");
|
||||
for (int i = 1; i < axes.size(); ++i) {
|
||||
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
|
||||
}
|
||||
axis = axes.get<int>(0);
|
||||
}
|
||||
|
||||
DictValue starts = layerParams.get("starts");
|
||||
DictValue ends = layerParams.get("ends");
|
||||
CV_Assert(starts.size() == ends.size());
|
||||
|
||||
if (axis > 0) {
|
||||
CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
|
||||
begin.resize(axis, 0);
|
||||
end.resize(axis, INT_MAX);
|
||||
}
|
||||
for (int i = 0; i < starts.size(); ++i)
|
||||
starts_ = layerParams.get("starts");
|
||||
ends_ = layerParams.get("ends");
|
||||
CV_Assert(starts_.size() == ends_.size());
|
||||
if (layerParams.has("axes"))
|
||||
{
|
||||
begin.push_back(starts.get<int>(i));
|
||||
end.push_back(ends.get<int>(i));
|
||||
axes_ = layerParams.get("axes");
|
||||
CV_Assert(axes_.size() == starts_.size());
|
||||
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
|
||||
has_axes = true;
|
||||
}
|
||||
} else { // inp_size > 1
|
||||
}
|
||||
// opset > 1
|
||||
else
|
||||
{
|
||||
CV_Assert(inp_size >= 3);
|
||||
for (int i = 1; i < inp_size; i++) {
|
||||
for (int i = 1; i < inp_size; ++i)
|
||||
{
|
||||
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
|
||||
}
|
||||
Mat start_blob = getBlob(node_proto, 1);
|
||||
Mat end_blob = getBlob(node_proto, 2);
|
||||
Mat end_blob = getBlob(node_proto, 2);
|
||||
CV_Assert(start_blob.total() == end_blob.total());
|
||||
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
|
||||
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
|
||||
|
||||
if (inp_size > 3) {
|
||||
if (inp_size > 3)
|
||||
{
|
||||
Mat axes_blob = getBlob(node_proto, 3);
|
||||
const int* axes = (int*)axes_blob.data;
|
||||
for (int i = 1; i < axes_blob.total(); ++i) {
|
||||
CV_Assert(axes[i - 1] == axes[i] - 1);
|
||||
}
|
||||
axis = axes[0];
|
||||
CV_Assert(axes_blob.total() == start_blob.total());
|
||||
axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
|
||||
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
|
||||
has_axes = true;
|
||||
}
|
||||
|
||||
const int* starts = start_blob.ptr<int>();
|
||||
const int* ends = end_blob.ptr<int>();
|
||||
if (axis > 0) {
|
||||
begin.resize(axis, 0);
|
||||
end.resize(axis, INT_MAX);
|
||||
}
|
||||
std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
|
||||
std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
|
||||
|
||||
if (inp_size == 5) {
|
||||
CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
|
||||
if (inp_size == 5)
|
||||
{
|
||||
Mat step_blob = getBlob(node_proto, 4);
|
||||
const int* steps_ptr = step_blob.ptr<int>();
|
||||
|
||||
if (axis > 0)
|
||||
steps.resize(axis, 1);
|
||||
|
||||
std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
|
||||
CV_Assert(step_blob.total() == start_blob.total());
|
||||
steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
|
||||
steps.resize(dims, 1);
|
||||
|
||||
// Very strange application for Slice op with tensor reversing.
|
||||
// We just workaround it for 2d constants.
|
||||
@ -1384,12 +1371,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_axes)
|
||||
{
|
||||
// make a default axes [0, 1, 2...]
|
||||
Mat axes_tmp(1, starts_.size(), CV_32S);
|
||||
std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
|
||||
axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
|
||||
}
|
||||
|
||||
int cur_axe;
|
||||
std::vector<bool> flag(dims, false);
|
||||
Mat axes(1, starts_.size(), CV_32S);
|
||||
auto axes_ptr = axes.ptr<int>();
|
||||
// resize begin and end
|
||||
for (int i = 0; i < axes_.size(); ++i)
|
||||
{
|
||||
// dims should be added to the negative axes
|
||||
cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
|
||||
CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
|
||||
CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
|
||||
CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
|
||||
flag[cur_axe] = true;
|
||||
// change axis to the minimum axe
|
||||
if (cur_axe < axis) axis = cur_axe;
|
||||
axes_ptr[i] = cur_axe;
|
||||
begin[cur_axe] = starts_.getIntValue(i);
|
||||
end[cur_axe] = ends_.getIntValue(i);
|
||||
}
|
||||
|
||||
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
|
||||
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
|
||||
layerParams.set("axis", axis);
|
||||
|
||||
if (!steps.empty())
|
||||
{
|
||||
for (int i = 0; i < axes.total(); ++i)
|
||||
steps[axes_ptr[i]] = steps_.getIntValue(i);
|
||||
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
|
||||
}
|
||||
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
|
||||
@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
|
||||
testONNXModels("slice_opset_11_steps_5d");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
|
||||
{
|
||||
testONNXModels("slice_nonseq_axes");
|
||||
testONNXModels("slice_nonseq_axes_steps");
|
||||
testONNXModels("slice_nonseq_miss_axes_steps");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
|
||||
{
|
||||
testONNXModels("slice_neg_axes");
|
||||
testONNXModels("slice_neg_axes_steps");
|
||||
testONNXModels("slice_neg_miss_axes_steps");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Softmax)
|
||||
{
|
||||
testONNXModels("softmax");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user