From 85b1c4060cc3a5082dfa4245d3da8e180cec4c13 Mon Sep 17 00:00:00 2001 From: Li Peng Date: Tue, 28 Nov 2017 23:40:46 +0800 Subject: [PATCH] support axis in concat layer ocl path Signed-off-by: Li Peng --- modules/dnn/src/layers/concat_layer.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/dnn/src/layers/concat_layer.cpp b/modules/dnn/src/layers/concat_layer.cpp index e51e1f7824..e49f22db2c 100644 --- a/modules/dnn/src/layers/concat_layer.cpp +++ b/modules/dnn/src/layers/concat_layer.cpp @@ -185,12 +185,13 @@ public: outs.getUMatVector(outputs); int cAxis = clamp(axis, inputs[0].dims); - if (!(cAxis == 1 && outputs[0].dims == 4 && !padding)) + if (padding) return false; int bottom_concat_axis; - int concat_size = inputs[0].size[2] * inputs[0].size[3]; - int top_concat_axis = outputs[0].size[1]; + int concat_size = total(shape(inputs[0]), cAxis + 1); + int top_concat_axis = outputs[0].size[cAxis]; + int num_concats = total(shape(inputs[0]), 0, cAxis); int offset_concat_axis = 0; UMat& outMat = outputs[0]; String buildopt = String("-DDtype=") + ocl::typeToStr(inputs[0].type()) + String(" "); @@ -202,12 +203,12 @@ public: return false; UMat& inpMat = inputs[i]; - bottom_concat_axis = inputs[i].size[1]; + bottom_concat_axis = inputs[i].size[cAxis]; size_t nthreads = inputs[i].total(); kernel.set(0, (int)nthreads); kernel.set(1, ocl::KernelArg::PtrReadOnly(inpMat)); - kernel.set(2, (int)inputs[i].size[0]); + kernel.set(2, (int)num_concats); kernel.set(3, (int)concat_size); kernel.set(4, (int)top_concat_axis); kernel.set(5, (int)bottom_concat_axis);