Update get memory shapes
This commit is contained in:
parent
4625337179
commit
9ed372b297
@ -149,11 +149,12 @@ public:
|
||||
out.push_back(outputs[0].size[i]);
|
||||
}
|
||||
kernel_size.resize(out.size());
|
||||
int diff_size = isGlobalPooling.size() - kernel_size.size();
|
||||
for (int i = 0; i < kernel_size.size(); i++)
|
||||
{
|
||||
if (isGlobalPooling[i + diff_size])
|
||||
kernel_size[i] = inp[i];
|
||||
int pool_idx = isGlobalPooling.size() - 1 - i;
|
||||
int kernel_idx = kernel_size.size() - 1 - i;
|
||||
if (isGlobalPooling[pool_idx])
|
||||
kernel_size[kernel_idx] = inp[kernel_idx];
|
||||
}
|
||||
kernel = Size(kernel_size[1], kernel_size[0]);
|
||||
|
||||
@ -1001,20 +1002,27 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
|
||||
std::vector<int> inpShape(inputs[0].begin() + 2, inputs[0].end());
|
||||
std::vector<int> outShape(inputs[0].begin(), inputs[0].begin() + 2);
|
||||
|
||||
if (globalPooling)
|
||||
std::vector<size_t> local_kernel = kernel_size.empty() ?
|
||||
std::vector<size_t>(inpShape.begin(), inpShape.end()) : kernel_size;
|
||||
|
||||
for (int i = 0; i < local_kernel.size(); i++)
|
||||
{
|
||||
outShape.push_back(1);
|
||||
outShape.push_back(1);
|
||||
int pool_idx = isGlobalPooling.size() - 1 - i;
|
||||
int kernel_idx = local_kernel.size() - 1 - i;
|
||||
if (isGlobalPooling[pool_idx])
|
||||
local_kernel[kernel_idx] = inpShape[kernel_idx];
|
||||
}
|
||||
else if (type == ROI || type == PSROI)
|
||||
|
||||
|
||||
if (type == ROI || type == PSROI)
|
||||
{
|
||||
outShape.push_back(pooledSize.height);
|
||||
outShape.push_back(pooledSize.width);
|
||||
}
|
||||
else if (padMode.empty())
|
||||
{
|
||||
for (int i = 0; i < kernel_size.size(); i++) {
|
||||
float dst = (float)(inpShape[i] + pads_begin[i] + pads_end[i] - kernel_size[i]) / strides[i];
|
||||
for (int i = 0; i < local_kernel.size(); i++) {
|
||||
float dst = (float)(inpShape[i] + pads_begin[i] + pads_end[i] - local_kernel[i]) / strides[i];
|
||||
outShape.push_back(1 + (ceilMode ? ceil(dst) : floor(dst)));
|
||||
}
|
||||
|
||||
@ -1029,7 +1037,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
|
||||
}
|
||||
else
|
||||
{
|
||||
getConvPoolOutParams(inpShape, kernel_size, strides, padMode, std::vector<size_t>(kernel_size.size(), 1), outShape);
|
||||
getConvPoolOutParams(inpShape, local_kernel, strides, padMode, std::vector<size_t>(local_kernel.size(), 1), outShape);
|
||||
}
|
||||
if (type == ROI)
|
||||
{
|
||||
@ -1044,13 +1052,6 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
|
||||
outShape[1] = psRoiOutChannels;
|
||||
}
|
||||
|
||||
int diff_size = isGlobalPooling.size() - (outShape.size() - 2);
|
||||
for (int i = 2; i < outShape.size(); i++)
|
||||
{
|
||||
if (isGlobalPooling[i - 2 + diff_size])
|
||||
outShape[i] = 1;
|
||||
}
|
||||
|
||||
int numOutputs = requiredOutputs ? requiredOutputs : (type == MAX ? 2 : 1);
|
||||
CV_Assert(numOutputs == 1 || (numOutputs == 2 && type == MAX));
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user