Merge pull request #20476 from TolyaTalamanov:at/support-unet-camvid-0001-segm-sample
[G-API] Support postprocessing for not argmaxed outputs * Support postprocessing for not argmaxed outputs * Fix typo * Add assert * Remove static cast * CamelCast to snake_case * Fix windows warning * Add static_cast to uint8_t * Add const to variables
This commit is contained in:
parent
0c2741f7ad
commit
24de676a64
@ -47,6 +47,53 @@ std::string get_weights_path(const std::string &model_path) {
|
||||
CV_Assert(ext == ".xml");
|
||||
return model_path.substr(0u, sz - EXT_LEN) + ".bin";
|
||||
}
|
||||
|
||||
void classesToColors(const cv::Mat &out_blob,
|
||||
cv::Mat &mask_img) {
|
||||
const int H = out_blob.size[0];
|
||||
const int W = out_blob.size[1];
|
||||
|
||||
mask_img.create(H, W, CV_8UC3);
|
||||
GAPI_Assert(out_blob.type() == CV_8UC1);
|
||||
const uint8_t* const classes = out_blob.ptr<uint8_t>();
|
||||
|
||||
for (int rowId = 0; rowId < H; ++rowId) {
|
||||
for (int colId = 0; colId < W; ++colId) {
|
||||
uint8_t class_id = classes[rowId * W + colId];
|
||||
mask_img.at<cv::Vec3b>(rowId, colId) =
|
||||
class_id < colors.size()
|
||||
? colors[class_id]
|
||||
: cv::Vec3b{0, 0, 0}; // NB: sample supports 20 classes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void probsToClasses(const cv::Mat& probs, cv::Mat& classes) {
|
||||
const int C = probs.size[1];
|
||||
const int H = probs.size[2];
|
||||
const int W = probs.size[3];
|
||||
|
||||
classes.create(H, W, CV_8UC1);
|
||||
GAPI_Assert(probs.depth() == CV_32F);
|
||||
float* out_p = reinterpret_cast<float*>(probs.data);
|
||||
uint8_t* classes_p = reinterpret_cast<uint8_t*>(classes.data);
|
||||
|
||||
for (int h = 0; h < H; ++h) {
|
||||
for (int w = 0; w < W; ++w) {
|
||||
double max = 0;
|
||||
int class_id = 0;
|
||||
for (int c = 0; c < C; ++c) {
|
||||
int idx = c * H * W + h * W + w;
|
||||
if (out_p[idx] > max) {
|
||||
max = out_p[idx];
|
||||
class_id = c;
|
||||
}
|
||||
}
|
||||
classes_p[h * W + w] = static_cast<uint8_t>(class_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace custom {
|
||||
@ -57,25 +104,21 @@ G_API_OP(PostProcessing, <cv::GMat(cv::GMat, cv::GMat)>, "sample.custom.post_pro
|
||||
};
|
||||
|
||||
GAPI_OCV_KERNEL(OCVPostProcessing, PostProcessing) {
|
||||
static void run(const cv::Mat &in, const cv::Mat &detected_classes, cv::Mat &out) {
|
||||
// This kernel constructs output image by class table and colors vector
|
||||
|
||||
// The semantic-segmentation-adas-0001 output a blob with the shape
|
||||
// [B, C=1, H=1024, W=2048]
|
||||
const int outHeight = 1024;
|
||||
const int outWidth = 2048;
|
||||
cv::Mat maskImg(outHeight, outWidth, CV_8UC3);
|
||||
const int* const classes = detected_classes.ptr<int>();
|
||||
for (int rowId = 0; rowId < outHeight; ++rowId) {
|
||||
for (int colId = 0; colId < outWidth; ++colId) {
|
||||
size_t classId = static_cast<size_t>(classes[rowId * outWidth + colId]);
|
||||
maskImg.at<cv::Vec3b>(rowId, colId) =
|
||||
classId < colors.size()
|
||||
? colors[classId]
|
||||
: cv::Vec3b{0, 0, 0}; // sample detects 20 classes
|
||||
}
|
||||
static void run(const cv::Mat &in, const cv::Mat &out_blob, cv::Mat &out) {
|
||||
cv::Mat classes;
|
||||
// NB: If output has more than single plane, it contains probabilities
|
||||
// otherwise class id.
|
||||
if (out_blob.size[1] > 1) {
|
||||
probsToClasses(out_blob, classes);
|
||||
} else {
|
||||
out_blob.convertTo(classes, CV_8UC1);
|
||||
classes = classes.reshape(1, out_blob.size[2]);
|
||||
}
|
||||
cv::resize(maskImg, out, in.size());
|
||||
|
||||
cv::Mat mask_img;
|
||||
classesToColors(classes, mask_img);
|
||||
|
||||
cv::resize(mask_img, out, in.size());
|
||||
const float blending = 0.3f;
|
||||
out = in * blending + out * (1 - blending);
|
||||
}
|
||||
@ -104,8 +147,8 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
// Now build the graph
|
||||
cv::GMat in;
|
||||
cv::GMat detected_classes = cv::gapi::infer<SemSegmNet>(in);
|
||||
cv::GMat out = custom::PostProcessing::on(in, detected_classes);
|
||||
cv::GMat out_blob = cv::gapi::infer<SemSegmNet>(in);
|
||||
cv::GMat out = custom::PostProcessing::on(in, out_blob);
|
||||
|
||||
cv::GStreamingCompiled pipeline = cv::GComputation(cv::GIn(in), cv::GOut(out))
|
||||
.compileStreaming(cv::compile_args(kernels, networks));
|
||||
|
||||
Loading…
Reference in New Issue
Block a user