Merge pull request #17570 from HannibalAPE:text_det_recog_demo

[GSoC] High Level API and Samples for Scene Text Detection and Recognition

* APIs and samples for scene text detection and recognition

* update APIs and tutorial for Text Detection and Recognition

* API updates:
(1) put decodeType into struct Voc
(2) optimize the post-processing of DB

* sample update:
(1) add transformation into scene_text_spotting.cpp
(2) modify text_detection.cpp with API update

* update tutorial

* simplify text recognition API
update tutorial

* update impl usage in recognize() and detect()

* dnn: refactoring public API of TextRecognitionModel/TextDetectionModel

* update provided models
update opencv.bib

* dnn: adjust text rectangle angle

* remove points ordering operation in model.cpp

* update gts of DB test in test_model.cpp

* dnn: ensure to keep text rectangle angle

- avoid 90/180 degree turns

* dnn(text): use quadrangle result in TextDetectionModel API

* dnn: update Text Detection API
(1) keep points' order consistent with (bl, tl, tr, br) in unclip
(2) update contourScore with boundingRect
This commit is contained in:
Wenqing Zhang
2020-12-04 02:47:40 +08:00
committed by GitHub
parent 5ecf693774
commit 22d64ae08f
19 changed files with 2339 additions and 181 deletions
+92 -177
View File
@@ -2,22 +2,23 @@
Text detection model: https://github.com/argman/EAST
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
CRNN Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
Text recognition models can be downloaded directly here:
Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
and doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown
How to convert from pb to onnx:
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
More converted onnx text recognition models can be downloaded directly here:
Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
And these models taken from here:https://github.com/clovaai/deep-text-recognition-benchmark
import torch
from models.crnn import CRNN
model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
For more information, please refer to doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown and doc/tutorials/dnn/dnn_OCR/dnn_OCR.markdown
*/
#include <iostream>
#include <fstream>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
@@ -27,21 +28,20 @@ using namespace cv;
using namespace cv::dnn;
const char* keys =
"{ help h | | Print help message. }"
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
"{ model m | | Path to a binary .pb file contains trained detector network.}"
"{ ocr | | Path to a binary .pb or .onnx file contains trained recognition network.}"
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
"{ thr | 0.5 | Confidence threshold. }"
"{ nms | 0.4 | Non-maximum suppression threshold. }";
"{ help h | | Print help message. }"
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
"{ detModel dmp | | Path to a binary .pb file contains trained detector network.}"
"{ width | 320 | Preprocess input image by resizing to a specific width. It should be multiple by 32. }"
"{ height | 320 | Preprocess input image by resizing to a specific height. It should be multiple by 32. }"
"{ thr | 0.5 | Confidence threshold. }"
"{ nms | 0.4 | Non-maximum suppression threshold. }"
"{ recModel rmp | | Path to a binary .onnx file contains trained CRNN text recognition model. "
"Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}"
"{ RGBInput rgb |0| 0: imread with flags=IMREAD_GRAYSCALE; 1: imread with flags=IMREAD_COLOR. }"
"{ vocabularyPath vp | alphabet_36.txt | Path to benchmarks for evaluation. "
"Download links are provided in doc/tutorials/dnn/dnn_text_spotting/dnn_text_spotting.markdown}";
void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
std::vector<RotatedRect>& detections, std::vector<float>& confidences);
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result);
void decodeText(const Mat& scores, std::string& text);
void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result);
int main(int argc, char** argv)
{
@@ -57,10 +57,12 @@ int main(int argc, char** argv)
float confThreshold = parser.get<float>("thr");
float nmsThreshold = parser.get<float>("nms");
int inpWidth = parser.get<int>("width");
int inpHeight = parser.get<int>("height");
String modelDecoder = parser.get<String>("model");
String modelRecognition = parser.get<String>("ocr");
int width = parser.get<int>("width");
int height = parser.get<int>("height");
int imreadRGB = parser.get<int>("RGBInput");
String detModelPath = parser.get<String>("detModel");
String recModelPath = parser.get<String>("recModel");
String vocPath = parser.get<String>("vocabularyPath");
if (!parser.check())
{
@@ -68,14 +70,39 @@ int main(int argc, char** argv)
return 1;
}
CV_Assert(!modelDecoder.empty());
// Load networks.
Net detector = readNet(modelDecoder);
Net recognizer;
CV_Assert(!detModelPath.empty() && !recModelPath.empty());
TextDetectionModel_EAST detector(detModelPath);
detector.setConfidenceThreshold(confThreshold)
.setNMSThreshold(nmsThreshold);
if (!modelRecognition.empty())
recognizer = readNet(modelRecognition);
TextRecognitionModel recognizer(recModelPath);
// Load vocabulary
CV_Assert(!vocPath.empty());
std::ifstream vocFile;
vocFile.open(samples::findFile(vocPath));
CV_Assert(vocFile.is_open());
String vocLine;
std::vector<String> vocabulary;
while (std::getline(vocFile, vocLine)) {
vocabulary.push_back(vocLine);
}
recognizer.setVocabulary(vocabulary);
recognizer.setDecodeType("CTC-greedy");
// Parameters for Recognition
double recScale = 1.0 / 127.5;
Scalar recMean = Scalar(127.5, 127.5, 127.5);
Size recInputSize = Size(100, 32);
recognizer.setInputParams(recScale, recInputSize, recMean);
// Parameters for Detection
double detScale = 1.0;
Size detInputSize = Size(width, height);
Scalar detMean = Scalar(123.68, 116.78, 103.94);
bool swapRB = true;
detector.setInputParams(detScale, detInputSize, detMean, swapRB);
// Open a video file or an image file or a camera stream.
VideoCapture cap;
@@ -83,15 +110,8 @@ int main(int argc, char** argv)
CV_Assert(openSuccess);
static const std::string kWinName = "EAST: An Efficient and Accurate Scene Text Detector";
namedWindow(kWinName, WINDOW_NORMAL);
std::vector<Mat> outs;
std::vector<String> outNames(2);
outNames[0] = "feature_fusion/Conv_7/Sigmoid";
outNames[1] = "feature_fusion/concat_3";
Mat frame, blob;
TickMeter tickMeter;
Mat frame;
while (waitKey(1) < 0)
{
cap >> frame;
@@ -101,162 +121,57 @@ int main(int argc, char** argv)
break;
}
blobFromImage(frame, blob, 1.0, Size(inpWidth, inpHeight), Scalar(123.68, 116.78, 103.94), true, false);
detector.setInput(blob);
tickMeter.start();
detector.forward(outs, outNames);
tickMeter.stop();
std::cout << frame.size << std::endl;
Mat scores = outs[0];
Mat geometry = outs[1];
// Detection
std::vector< std::vector<Point> > detResults;
detector.detect(frame, detResults);
// Decode predicted bounding boxes.
std::vector<RotatedRect> boxes;
std::vector<float> confidences;
decodeBoundingBoxes(scores, geometry, confThreshold, boxes, confidences);
// Apply non-maximum suppression procedure.
std::vector<int> indices;
NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);
Point2f ratio((float)frame.cols / inpWidth, (float)frame.rows / inpHeight);
// Render text.
for (size_t i = 0; i < indices.size(); ++i)
{
RotatedRect& box = boxes[indices[i]];
Point2f vertices[4];
box.points(vertices);
for (int j = 0; j < 4; ++j)
{
vertices[j].x *= ratio.x;
vertices[j].y *= ratio.y;
if (detResults.size() > 0) {
// Text Recognition
Mat recInput;
if (!imreadRGB) {
cvtColor(frame, recInput, cv::COLOR_BGR2GRAY);
} else {
recInput = frame;
}
if (!modelRecognition.empty())
std::vector< std::vector<Point> > contours;
for (uint i = 0; i < detResults.size(); i++)
{
const auto& quadrangle = detResults[i];
CV_CheckEQ(quadrangle.size(), (size_t)4, "");
contours.emplace_back(quadrangle);
std::vector<Point2f> quadrangle_2f;
for (int j = 0; j < 4; j++)
quadrangle_2f.emplace_back(quadrangle[j]);
Mat cropped;
fourPointsTransform(frame, vertices, cropped);
fourPointsTransform(recInput, &quadrangle_2f[0], cropped);
cvtColor(cropped, cropped, cv::COLOR_BGR2GRAY);
std::string recognitionResult = recognizer.recognize(cropped);
std::cout << i << ": '" << recognitionResult << "'" << std::endl;
Mat blobCrop = blobFromImage(cropped, 1.0/127.5, Size(), Scalar::all(127.5));
recognizer.setInput(blobCrop);
tickMeter.start();
Mat result = recognizer.forward();
tickMeter.stop();
std::string wordRecognized = "";
decodeText(result, wordRecognized);
putText(frame, wordRecognized, vertices[1], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255));
putText(frame, recognitionResult, quadrangle[3], FONT_HERSHEY_SIMPLEX, 1.5, Scalar(0, 0, 255), 2);
}
for (int j = 0; j < 4; ++j)
line(frame, vertices[j], vertices[(j + 1) % 4], Scalar(0, 255, 0), 1);
polylines(frame, contours, true, Scalar(0, 255, 0), 2);
}
// Put efficiency information.
std::string label = format("Inference time: %.2f ms", tickMeter.getTimeMilli());
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
imshow(kWinName, frame);
tickMeter.reset();
}
return 0;
}
void decodeBoundingBoxes(const Mat& scores, const Mat& geometry, float scoreThresh,
std::vector<RotatedRect>& detections, std::vector<float>& confidences)
{
detections.clear();
CV_Assert(scores.dims == 4); CV_Assert(geometry.dims == 4); CV_Assert(scores.size[0] == 1);
CV_Assert(geometry.size[0] == 1); CV_Assert(scores.size[1] == 1); CV_Assert(geometry.size[1] == 5);
CV_Assert(scores.size[2] == geometry.size[2]); CV_Assert(scores.size[3] == geometry.size[3]);
const int height = scores.size[2];
const int width = scores.size[3];
for (int y = 0; y < height; ++y)
{
const float* scoresData = scores.ptr<float>(0, 0, y);
const float* x0_data = geometry.ptr<float>(0, 0, y);
const float* x1_data = geometry.ptr<float>(0, 1, y);
const float* x2_data = geometry.ptr<float>(0, 2, y);
const float* x3_data = geometry.ptr<float>(0, 3, y);
const float* anglesData = geometry.ptr<float>(0, 4, y);
for (int x = 0; x < width; ++x)
{
float score = scoresData[x];
if (score < scoreThresh)
continue;
// Decode a prediction.
// Multiple by 4 because feature maps are 4 time less than input image.
float offsetX = x * 4.0f, offsetY = y * 4.0f;
float angle = anglesData[x];
float cosA = std::cos(angle);
float sinA = std::sin(angle);
float h = x0_data[x] + x2_data[x];
float w = x1_data[x] + x3_data[x];
Point2f offset(offsetX + cosA * x1_data[x] + sinA * x2_data[x],
offsetY - sinA * x1_data[x] + cosA * x2_data[x]);
Point2f p1 = Point2f(-sinA * h, -cosA * h) + offset;
Point2f p3 = Point2f(-cosA * w, sinA * w) + offset;
RotatedRect r(0.5f * (p1 + p3), Size2f(w, h), -angle * 180.0f / (float)CV_PI);
detections.push_back(r);
confidences.push_back(score);
}
}
}
void fourPointsTransform(const Mat& frame, Point2f vertices[4], Mat& result)
void fourPointsTransform(const Mat& frame, const Point2f vertices[], Mat& result)
{
const Size outputSize = Size(100, 32);
Point2f targetVertices[4] = {Point(0, outputSize.height - 1),
Point(0, 0), Point(outputSize.width - 1, 0),
Point(outputSize.width - 1, outputSize.height - 1),
};
Point2f targetVertices[4] = {
Point(0, outputSize.height - 1),
Point(0, 0), Point(outputSize.width - 1, 0),
Point(outputSize.width - 1, outputSize.height - 1)
};
Mat rotationMatrix = getPerspectiveTransform(vertices, targetVertices);
warpPerspective(frame, result, rotationMatrix, outputSize);
}
void decodeText(const Mat& scores, std::string& text)
{
static const std::string alphabet = "0123456789abcdefghijklmnopqrstuvwxyz";
Mat scoresMat = scores.reshape(1, scores.size[0]);
std::vector<char> elements;
elements.reserve(scores.size[0]);
for (int rowIndex = 0; rowIndex < scoresMat.rows; ++rowIndex)
{
Point p;
minMaxLoc(scoresMat.row(rowIndex), 0, 0, 0, &p);
if (p.x > 0 && static_cast<size_t>(p.x) <= alphabet.size())
{
elements.push_back(alphabet[p.x - 1]);
}
else
{
elements.push_back('-');
}
}
if (elements.size() > 0 && elements[0] != '-')
text += elements[0];
for (size_t elementIndex = 1; elementIndex < elements.size(); ++elementIndex)
{
if (elementIndex > 0 && elements[elementIndex] != '-' &&
elements[elementIndex - 1] != elements[elementIndex])
{
text += elements[elementIndex];
}
}
}