Update tutorials. A new cv::dnn::readNet function
This commit is contained in:
@@ -2,8 +2,9 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <opencv2/dnn.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
#include <opencv2/highgui.hpp>
|
||||
|
||||
const char* keys =
|
||||
"{ help h | | Print help message. }"
|
||||
@@ -33,8 +34,6 @@ using namespace dnn;
|
||||
|
||||
std::vector<std::string> classes;
|
||||
|
||||
Net readNet(const std::string& model, const std::string& config = "", const std::string& framework = "");
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
CommandLineParser parser(argc, argv, keys);
|
||||
@@ -49,6 +48,11 @@ int main(int argc, char** argv)
|
||||
bool swapRB = parser.get<bool>("rgb");
|
||||
int inpWidth = parser.get<int>("width");
|
||||
int inpHeight = parser.get<int>("height");
|
||||
String model = parser.get<String>("model");
|
||||
String config = parser.get<String>("config");
|
||||
String framework = parser.get<String>("framework");
|
||||
int backendId = parser.get<int>("backend");
|
||||
int targetId = parser.get<int>("target");
|
||||
|
||||
// Parse mean values.
|
||||
Scalar mean;
|
||||
@@ -77,22 +81,24 @@ int main(int argc, char** argv)
|
||||
}
|
||||
}
|
||||
|
||||
// Load a model.
|
||||
CV_Assert(parser.has("model"));
|
||||
Net net = readNet(parser.get<String>("model"), parser.get<String>("config"), parser.get<String>("framework"));
|
||||
net.setPreferableBackend(parser.get<int>("backend"));
|
||||
net.setPreferableTarget(parser.get<int>("target"));
|
||||
//! [Read and initialize network]
|
||||
Net net = readNet(model, config, framework);
|
||||
net.setPreferableBackend(backendId);
|
||||
net.setPreferableTarget(targetId);
|
||||
//! [Read and initialize network]
|
||||
|
||||
// Create a window
|
||||
static const std::string kWinName = "Deep learning image classification in OpenCV";
|
||||
namedWindow(kWinName, WINDOW_NORMAL);
|
||||
|
||||
// Open a video file or an image file or a camera stream.
|
||||
//! [Open a video file or an image file or a camera stream]
|
||||
VideoCapture cap;
|
||||
if (parser.has("input"))
|
||||
cap.open(parser.get<String>("input"));
|
||||
else
|
||||
cap.open(0);
|
||||
//! [Open a video file or an image file or a camera stream]
|
||||
|
||||
// Process frames.
|
||||
Mat frame, blob;
|
||||
@@ -105,24 +111,29 @@ int main(int argc, char** argv)
|
||||
break;
|
||||
}
|
||||
|
||||
// Create a 4D blob from a frame.
|
||||
//! [Create a 4D blob from a frame]
|
||||
blobFromImage(frame, blob, scale, Size(inpWidth, inpHeight), mean, swapRB, false);
|
||||
//! [Create a 4D blob from a frame]
|
||||
|
||||
// Run a model.
|
||||
//! [Set input blob]
|
||||
net.setInput(blob);
|
||||
Mat out = net.forward();
|
||||
out = out.reshape(1, 1);
|
||||
//! [Set input blob]
|
||||
//! [Make forward pass]
|
||||
Mat prob = net.forward();
|
||||
//! [Make forward pass]
|
||||
|
||||
// Get a class with a highest score.
|
||||
//! [Get a class with a highest score]
|
||||
Point classIdPoint;
|
||||
double confidence;
|
||||
minMaxLoc(out, 0, &confidence, 0, &classIdPoint);
|
||||
minMaxLoc(prob.reshape(1, 1), 0, &confidence, 0, &classIdPoint);
|
||||
int classId = classIdPoint.x;
|
||||
//! [Get a class with a highest score]
|
||||
|
||||
// Put efficiency information.
|
||||
std::vector<double> layersTimes;
|
||||
double t = net.getPerfProfile(layersTimes);
|
||||
std::string label = format("Inference time: %.2f", t * 1000 / getTickFrequency());
|
||||
double freq = getTickFrequency() / 1000;
|
||||
double t = net.getPerfProfile(layersTimes) / freq;
|
||||
std::string label = format("Inference time: %.2f ms", t);
|
||||
putText(frame, label, Point(0, 15), FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 255, 0));
|
||||
|
||||
// Print predicted class.
|
||||
@@ -135,19 +146,3 @@ int main(int argc, char** argv)
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
Net readNet(const std::string& model, const std::string& config, const std::string& framework)
|
||||
{
|
||||
std::string modelExt = model.substr(model.rfind('.'));
|
||||
if (framework == "caffe" || modelExt == ".caffemodel")
|
||||
return readNetFromCaffe(config, model);
|
||||
else if (framework == "tensorflow" || modelExt == ".pb")
|
||||
return readNetFromTensorflow(model, config);
|
||||
else if (framework == "torch" || modelExt == ".t7" || modelExt == ".net")
|
||||
return readNetFromTorch(model);
|
||||
else if (framework == "darknet" || modelExt == ".weights")
|
||||
return readNetFromDarknet(config, model);
|
||||
else
|
||||
CV_Error(Error::StsError, "Cannot determine an origin framework of model from file " + model);
|
||||
return Net();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user