enhance slice layer
refactor the code for parsing Slice layer add test for Slice layer let 'begin' and 'end' resize to dims add opset message comment
This commit is contained in:
parent
04ebedb6f0
commit
4557971481
@ -1326,72 +1326,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
|
||||
|
||||
void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
int axis = 0;
|
||||
std::vector<int> begin;
|
||||
std::vector<int> end;
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
int dims = inpShape.size();
|
||||
std::vector<int> begin(dims, 0);
|
||||
std::vector<int> end(dims, INT_MAX);
|
||||
std::vector<int> steps;
|
||||
int inp_size = node_proto.input_size();
|
||||
int axis = 0;
|
||||
bool has_axes = false;
|
||||
DictValue starts_, ends_, axes_, steps_;
|
||||
|
||||
// opset = 1
|
||||
if (inp_size == 1)
|
||||
{
|
||||
if (layerParams.has("axes")) {
|
||||
DictValue axes = layerParams.get("axes");
|
||||
for (int i = 1; i < axes.size(); ++i) {
|
||||
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
|
||||
}
|
||||
axis = axes.get<int>(0);
|
||||
}
|
||||
|
||||
DictValue starts = layerParams.get("starts");
|
||||
DictValue ends = layerParams.get("ends");
|
||||
CV_Assert(starts.size() == ends.size());
|
||||
|
||||
if (axis > 0) {
|
||||
CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
|
||||
begin.resize(axis, 0);
|
||||
end.resize(axis, INT_MAX);
|
||||
}
|
||||
for (int i = 0; i < starts.size(); ++i)
|
||||
starts_ = layerParams.get("starts");
|
||||
ends_ = layerParams.get("ends");
|
||||
CV_Assert(starts_.size() == ends_.size());
|
||||
if (layerParams.has("axes"))
|
||||
{
|
||||
begin.push_back(starts.get<int>(i));
|
||||
end.push_back(ends.get<int>(i));
|
||||
axes_ = layerParams.get("axes");
|
||||
CV_Assert(axes_.size() == starts_.size());
|
||||
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
|
||||
has_axes = true;
|
||||
}
|
||||
} else { // inp_size > 1
|
||||
}
|
||||
// opset > 1
|
||||
else
|
||||
{
|
||||
CV_Assert(inp_size >= 3);
|
||||
for (int i = 1; i < inp_size; i++) {
|
||||
for (int i = 1; i < inp_size; ++i)
|
||||
{
|
||||
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
|
||||
}
|
||||
Mat start_blob = getBlob(node_proto, 1);
|
||||
Mat end_blob = getBlob(node_proto, 2);
|
||||
Mat end_blob = getBlob(node_proto, 2);
|
||||
CV_Assert(start_blob.total() == end_blob.total());
|
||||
starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
|
||||
ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
|
||||
|
||||
if (inp_size > 3) {
|
||||
if (inp_size > 3)
|
||||
{
|
||||
Mat axes_blob = getBlob(node_proto, 3);
|
||||
const int* axes = (int*)axes_blob.data;
|
||||
for (int i = 1; i < axes_blob.total(); ++i) {
|
||||
CV_Assert(axes[i - 1] == axes[i] - 1);
|
||||
}
|
||||
axis = axes[0];
|
||||
CV_Assert(axes_blob.total() == start_blob.total());
|
||||
axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
|
||||
axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
|
||||
has_axes = true;
|
||||
}
|
||||
|
||||
const int* starts = start_blob.ptr<int>();
|
||||
const int* ends = end_blob.ptr<int>();
|
||||
if (axis > 0) {
|
||||
begin.resize(axis, 0);
|
||||
end.resize(axis, INT_MAX);
|
||||
}
|
||||
std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
|
||||
std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
|
||||
|
||||
if (inp_size == 5) {
|
||||
CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
|
||||
if (inp_size == 5)
|
||||
{
|
||||
Mat step_blob = getBlob(node_proto, 4);
|
||||
const int* steps_ptr = step_blob.ptr<int>();
|
||||
|
||||
if (axis > 0)
|
||||
steps.resize(axis, 1);
|
||||
|
||||
std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
|
||||
CV_Assert(step_blob.total() == start_blob.total());
|
||||
steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
|
||||
steps.resize(dims, 1);
|
||||
|
||||
// Very strange application for Slice op with tensor reversing.
|
||||
// We just workaround it for 2d constants.
|
||||
@ -1411,12 +1398,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!has_axes)
|
||||
{
|
||||
// make a default axes [0, 1, 2...]
|
||||
Mat axes_tmp(1, starts_.size(), CV_32S);
|
||||
std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
|
||||
axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
|
||||
}
|
||||
|
||||
int cur_axe;
|
||||
std::vector<bool> flag(dims, false);
|
||||
Mat axes(1, starts_.size(), CV_32S);
|
||||
auto axes_ptr = axes.ptr<int>();
|
||||
// resize begin and end
|
||||
for (int i = 0; i < axes_.size(); ++i)
|
||||
{
|
||||
// dims should be added to the negative axes
|
||||
cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
|
||||
CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
|
||||
CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
|
||||
CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
|
||||
flag[cur_axe] = true;
|
||||
// change axis to the minimum axe
|
||||
if (cur_axe < axis) axis = cur_axe;
|
||||
axes_ptr[i] = cur_axe;
|
||||
begin[cur_axe] = starts_.getIntValue(i);
|
||||
end[cur_axe] = ends_.getIntValue(i);
|
||||
}
|
||||
|
||||
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
|
||||
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
|
||||
layerParams.set("axis", axis);
|
||||
|
||||
if (!steps.empty())
|
||||
{
|
||||
for (int i = 0; i < axes.total(); ++i)
|
||||
steps[axes_ptr[i]] = steps_.getIntValue(i);
|
||||
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
|
||||
}
|
||||
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
|
||||
@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
|
||||
testONNXModels("slice_opset_11_steps_5d");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
|
||||
{
|
||||
testONNXModels("slice_nonseq_axes");
|
||||
testONNXModels("slice_nonseq_axes_steps");
|
||||
testONNXModels("slice_nonseq_miss_axes_steps");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
|
||||
{
|
||||
testONNXModels("slice_neg_axes");
|
||||
testONNXModels("slice_neg_axes_steps");
|
||||
testONNXModels("slice_neg_miss_axes_steps");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Softmax)
|
||||
{
|
||||
testONNXModels("softmax");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user