Update tutorials. A new cv::dnn::readNet function

This commit is contained in:
Dmitry Kurtaev
2018-03-03 19:29:37 +03:00
parent 8e4fe30db6
commit f2440ceae6
10 changed files with 149 additions and 155 deletions
+27 -32
View File
@@ -2,8 +2,9 @@
#include <iostream>
#include <sstream>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
const char* keys =
"{ help h | | Print help message. }"
@@ -33,8 +34,6 @@ using namespace dnn;
std::vector<std::string> classes;
Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
int main(int argc, char** argv)
{
CommandLineParser parser(argc, argv, keys);
@@ -49,6 +48,11 @@ int main(int argc, char** argv)
bool swapRB = parser.get<bool>("rgb");
int inpWidth = parser.get<int>("width");
int inpHeight = parser.get<int>("height");
String model = parser.get<String>("model");
String config = parser.get<String>("config");
String framework = parser.get<String>("framework");
int backendId = parser.get<int>("backend");
int targetId = parser.get<int>("target");
// Parse mean values.
Scalar mean;
@@ -77,22 +81,24 @@ int main(int argc, char** argv)
}
}
// Load a model.
CV_Assert(parser.has("model"));
Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework"));
net.setPreferableBackend(parser.get<int>("backend"));
net.setPreferableTarget(parser.get<int>("target"));
//! [Read and initialize network]
Net net = readNet(model, config, framework);
net.setPreferableBackend(backendId);
net.setPreferableTarget(targetId);
//! [Read and initialize network]
// Create a window
static const std::string kWinName = "Deep learning image classification in OpenCV";
namedWindow(kWinName, WINDOW_NORMAL);
// Open a video file or an image file or a camera stream.
//! [Open a video file or an image file or a camera stream]
VideoCapture cap;
if (parser.has("input"))
cap.open(parser.get<String>("input"));
else
cap.open(0);
//! [Open a video file or an image file or a camera stream]
// Process frames.
Mat frame, blob;
@@ -105,24 +111,29 @@ int main(int argc, char** argv)
break;
}
// Create a 4D blob from a frame.
//! [Create a 4D blob from a frame]
blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
//! [Create a 4D blob from a frame]
// Run a model.
//! [Set input blob]
net.setInput(blob);
Mat out = net.forward();
out = out.reshape(1, 1);
//! [Set input blob]
//! [Make forward pass]
Mat prob = net.forward();
//! [Make forward pass]
// Get a class with a highest score.
//! [Get a class with a highest score]
Point classIdPoint;
double confidence;
minMaxLoc(out, 0, &confidence, 0, &classIdPoint);
minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
int classId = classIdPoint.x;
//! [Get a class with a highest score]
// Put efficiency information.
std::vector<double> layersTimes;
double t = net.getPerfProfile(layersTimes);
std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
double freq = getTickFrequency() / 1000;
double t = net.getPerfProfile(layersTimes) / freq;
std::string label = format("Inference time: %.2f ms", t);
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
// Print predicted class.
@@ -135,19 +146,3 @@ int main(int argc, char** argv)
}
return 0;
}
Net readNet(const std::string& model, const std::string& config, const std::string& framework)
{
std::string modelExt = model.substr(model.rfind('.'));
if (framework == "caffe" || modelExt == ".caffemodel")
return readNetFromCaffe(config, model);
else if (framework == "tensorflow" || modelExt == ".pb")
return readNetFromTensorflow(model, config);
else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
return readNetFromTorch(model);
else if (framework == "darknet" || modelExt == ".weights")
return readNetFromDarknet(config, model);
else
CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
return Net();
}