Merge pull request #14697 from l-bat:Slice_ONNX
* Support Slice layer in ONNX importer * Add IE support * Fix ONNX importer * Fix Slice
This commit is contained in:
parent
254f88f805
commit
5e80191d27
@ -174,16 +174,16 @@ public:
|
||||
for (int i = 0; i < outputs.size(); ++i)
|
||||
{
|
||||
CV_Assert(sliceRanges[i].size() <= inpShape.dims());
|
||||
// Clamp.
|
||||
for (int j = 0; j < sliceRanges[i].size(); ++j)
|
||||
{
|
||||
sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]);
|
||||
}
|
||||
// Fill the rest of ranges.
|
||||
for (int j = sliceRanges[i].size(); j < inpShape.dims(); ++j)
|
||||
{
|
||||
sliceRanges[i].push_back(Range::all());
|
||||
}
|
||||
// Clamp.
|
||||
for (int j = 0; j < sliceRanges[i].size(); ++j)
|
||||
{
|
||||
sliceRanges[i][j] = clamp(sliceRanges[i][j], inpShape[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -401,6 +401,47 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("pool", layer_type == "GlobalAveragePool" ? "AVE" : "MAX");
|
||||
layerParams.set("global_pooling", true);
|
||||
}
|
||||
else if (layer_type == "Slice")
|
||||
{
|
||||
if (layerParams.has("steps")) {
|
||||
DictValue steps = layerParams.get("steps");
|
||||
for (int i = 0; i < steps.size(); ++i) {
|
||||
if (steps.get<int>(i) != 1)
|
||||
CV_Error(Error::StsNotImplemented,
|
||||
"Slice layer only supports steps = 1");
|
||||
}
|
||||
}
|
||||
|
||||
int axis = 0;
|
||||
if (layerParams.has("axes")) {
|
||||
DictValue axes = layerParams.get("axes");
|
||||
for (int i = 1; i < axes.size(); ++i) {
|
||||
CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
|
||||
}
|
||||
axis = axes.get<int>(0);
|
||||
}
|
||||
layerParams.set("axis", axis);
|
||||
|
||||
DictValue starts = layerParams.get("starts");
|
||||
DictValue ends = layerParams.get("ends");
|
||||
CV_Assert(starts.size() == ends.size());
|
||||
|
||||
std::vector<int> begin;
|
||||
std::vector<int> end;
|
||||
if (axis > 0) {
|
||||
begin.resize(axis, 0);
|
||||
end.resize(axis, -1);
|
||||
}
|
||||
|
||||
for (int i = 0; i < starts.size(); ++i)
|
||||
{
|
||||
begin.push_back(starts.get<int>(i));
|
||||
int finish = ends.get<int>(i);
|
||||
end.push_back((finish < 0) ? --finish : finish); // numpy doesn't include last dim
|
||||
}
|
||||
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
|
||||
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
|
||||
}
|
||||
else if (layer_type == "Add" || layer_type == "Sum")
|
||||
{
|
||||
if (layer_id.find(node_proto.input(1)) == layer_id.end())
|
||||
|
||||
@ -245,6 +245,11 @@ TEST_P(Test_ONNX_layers, Reshape)
|
||||
testONNXModels("unsqueeze");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Slice)
|
||||
{
|
||||
testONNXModels("slice");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Softmax)
|
||||
{
|
||||
testONNXModels("softmax");
|
||||
|
||||
Loading…
Reference in New Issue
Block a user