diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp
index 89b4770e8d..8a2ae2337e 100644
--- a/modules/dnn/include/opencv2/dnn/dnn.hpp
+++ b/modules/dnn/include/opencv2/dnn/dnn.hpp
@@ -726,13 +726,17 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
      * @param src Path to origin model from Caffe framework contains single
      * precision floating point weights (usually has `.caffemodel` extension).
      * @param dst Path to destination model with updated weights.
+     * @param layersTypes Set of layers types which parameters will be converted.
+     *                    By default, converts only Convolutional and Fully-Connected layers'
+     *                    weights.
      *
      * @note Shrinked model has no origin float32 weights so it can't be used
      *       in origin Caffe framework anymore. However the structure of data
      *       is taken from NVidia's Caffe fork: https://github.com/NVIDIA/caffe.
      *       So the resulting model may be used there.
      */
-    CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst);
+    CV_EXPORTS_W void shrinkCaffeModel(const String& src, const String& dst,
+                                       const std::vector<String>& layersTypes = std::vector<String>());
 
     /** @brief Performs non maximum suppression given boxes and corresponding scores.
 
diff --git a/modules/dnn/src/caffe/caffe_shrinker.cpp b/modules/dnn/src/caffe/caffe_shrinker.cpp
index f9c50dbafd..98df108c0e 100644
--- a/modules/dnn/src/caffe/caffe_shrinker.cpp
+++ b/modules/dnn/src/caffe/caffe_shrinker.cpp
@@ -17,16 +17,27 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
 
 #ifdef HAVE_PROTOBUF
 
-void shrinkCaffeModel(const String& src, const String& dst)
+void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& layersTypes)
 {
     CV_TRACE_FUNCTION();
 
+    std::vector<String> types(layersTypes);
+    if (types.empty())
+    {
+        types.push_back("Convolution");
+        types.push_back("InnerProduct");
+    }
+
     caffe::NetParameter net;
     ReadNetParamsFromBinaryFileOrDie(src.c_str(), &net);
 
     for (int i = 0; i < net.layer_size(); ++i)
     {
         caffe::LayerParameter* lp = net.mutable_layer(i);
+        if (std::find(types.begin(), types.end(), lp->type()) == types.end())
+        {
+            continue;
+        }
         for (int j = 0; j < lp->blobs_size(); ++j)
         {
             caffe::BlobProto* blob = lp->mutable_blobs(j);
@@ -54,7 +65,7 @@ void shrinkCaffeModel(const String& src, const String& dst)
 
 #else
 
-void shrinkCaffeModel(const String& src, const String& dst)
+void shrinkCaffeModel(const String& src, const String& dst, const std::vector<String>& types)
 {
     CV_Error(cv::Error::StsNotImplemented, "libprotobuf required to import data from Caffe models");
 }
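
Not part of the patch: a minimal usage sketch of the updated API, assuming OpenCV is built with protobuf support. The file names and the extra "Deconvolution" layer type are placeholder choices, not taken from the PR.

    #include <opencv2/dnn.hpp>
    #include <vector>

    int main()
    {
        // Default call: only "Convolution" and "InnerProduct" weights are
        // converted to FP16, matching the fallback added in caffe_shrinker.cpp.
        cv::dnn::shrinkCaffeModel("net.caffemodel", "net_fp16.caffemodel");

        // Explicit list: also shrink deconvolution layer weights.
        std::vector<cv::String> types;
        types.push_back("Convolution");
        types.push_back("InnerProduct");
        types.push_back("Deconvolution");
        cv::dnn::shrinkCaffeModel("net.caffemodel", "net_fp16_custom.caffemodel", types);
        return 0;
    }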