support axis in concat layer ocl path
Signed-off-by: Li Peng <peng.li@intel.com>
This commit is contained in:
parent
07bec6bdcd
commit
85b1c4060c
@ -185,12 +185,13 @@ public:
|
||||
outs.getUMatVector(outputs);
|
||||
|
||||
int cAxis = clamp(axis, inputs[0].dims);
|
||||
if (!(cAxis == 1 && outputs[0].dims == 4 && !padding))
|
||||
if (padding)
|
||||
return false;
|
||||
|
||||
int bottom_concat_axis;
|
||||
int concat_size = inputs[0].size[2] * inputs[0].size[3];
|
||||
int top_concat_axis = outputs[0].size[1];
|
||||
int concat_size = total(shape(inputs[0]), cAxis + 1);
|
||||
int top_concat_axis = outputs[0].size[cAxis];
|
||||
int num_concats = total(shape(inputs[0]), 0, cAxis);
|
||||
int offset_concat_axis = 0;
|
||||
UMat& outMat = outputs[0];
|
||||
String buildopt = String("-DDtype=") + ocl::typeToStr(inputs[0].type()) + String(" ");
|
||||
@ -202,12 +203,12 @@ public:
|
||||
return false;
|
||||
|
||||
UMat& inpMat = inputs[i];
|
||||
bottom_concat_axis = inputs[i].size[1];
|
||||
bottom_concat_axis = inputs[i].size[cAxis];
|
||||
size_t nthreads = inputs[i].total();
|
||||
|
||||
kernel.set(0, (int)nthreads);
|
||||
kernel.set(1, ocl::KernelArg::PtrReadOnly(inpMat));
|
||||
kernel.set(2, (int)inputs[i].size[0]);
|
||||
kernel.set(2, (int)num_concats);
|
||||
kernel.set(3, (int)concat_size);
|
||||
kernel.set(4, (int)top_concat_axis);
|
||||
kernel.set(5, (int)bottom_concat_axis);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user