From 6e84abc746757c6de75e7c5a303b7c68f721575f Mon Sep 17 00:00:00 2001 From: Alexander Alekhin Date: Fri, 17 Aug 2018 16:45:31 +0300 Subject: [PATCH] ml: don't use "getSubVector()" with 2D matrix It is designed for 1D vectors only --- modules/ml/include/opencv2/ml.hpp | 13 +++- modules/ml/src/data.cpp | 112 +++++++++++++++--------------- 2 files changed, 67 insertions(+), 58 deletions(-) diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index 31e7427735..f2ca78fe91 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -239,7 +239,18 @@ public: /** @brief Returns vector of symbolic names captured in loadFromCSV() */ CV_WRAP void getNames(std::vector& names) const; - CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx); + /** @brief Extract from 1D vector elements specified by passed indexes. + @param vec input vector (supported types: CV_32S, CV_32F, CV_64F) + @param idx 1D index vector + */ + static CV_WRAP Mat getSubVector(const Mat& vec, const Mat& idx); + + /** @brief Extract from matrix rows/cols specified by passed indexes. + @param matrix input matrix (supported types: CV_32S, CV_32F, CV_64F) + @param idx 1D index vector + @param layout specifies to extract rows (cv::ml::ROW_SAMPLES) or to extract columns (cv::ml::COL_SAMPLES) + */ + static CV_WRAP Mat getSubMatrix(const Mat& matrix, const Mat& idx, int layout); /** @brief Reads the dataset from a .csv file and returns the ready-to-use training data. diff --git a/modules/ml/src/data.cpp b/modules/ml/src/data.cpp index 852d5c6a5b..a5dd101f1d 100644 --- a/modules/ml/src/data.cpp +++ b/modules/ml/src/data.cpp @@ -43,6 +43,8 @@ #include #include +#include + namespace cv { namespace ml { static const float MISSED_VAL = TrainData::missingValue(); @@ -54,69 +56,65 @@ Mat TrainData::getTestSamples() const { Mat idx = getTestSampleIdx(); Mat samples = getSamples(); - return idx.empty() ? Mat() : getSubVector(samples, idx); + return idx.empty() ? Mat() : getSubMatrix(samples, idx, getLayout()); } Mat TrainData::getSubVector(const Mat& vec, const Mat& idx) { - if( idx.empty() ) - return vec; - int i, j, n = idx.checkVector(1, CV_32S); - int type = vec.type(); - CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F ); - int dims = 1, m; + if (!(vec.cols == 1 || vec.rows == 1)) + CV_LOG_WARNING(NULL, "'getSubVector(const Mat& vec, const Mat& idx)' call with non-1D input is deprecated. It is not designed to work with 2D matrixes (especially with 'cv::ml::COL_SAMPLE' layout)."); + return getSubMatrix(vec, idx, vec.rows == 1 ? cv::ml::COL_SAMPLE : cv::ml::ROW_SAMPLE); +} - if( vec.cols == 1 || vec.rows == 1 ) +template +Mat getSubMatrixImpl(const Mat& m, const Mat& idx, int layout) +{ + int nidx = idx.checkVector(1, CV_32S); + int dims = m.cols, nsamples = m.rows; + + Mat subm; + if (layout == COL_SAMPLE) { - dims = 1; - m = vec.cols + vec.rows - 1; + std::swap(dims, nsamples); + subm.create(dims, nidx, m.type()); } else { - dims = vec.cols; - m = vec.rows; + subm.create(nidx, dims, m.type()); } - Mat subvec; + for (int i = 0; i < nidx; i++) + { + int k = idx.at(i); CV_CheckGE(k, 0, "Bad idx"); CV_CheckLT(k, nsamples, "Bad idx or layout"); + if (dims == 1) + { + subm.at(i) = m.at(k); // at() has "transparent" access for 1D col-based / row-based vectors. + } + else if (layout == COL_SAMPLE) + { + for (int j = 0; j < dims; j++) + subm.at(j, i) = m.at(j, k); + } + else + { + for (int j = 0; j < dims; j++) + subm.at(i, j) = m.at(k, j); + } + } + return subm; +} - if( vec.cols == m ) - subvec.create(dims, n, type); - else - subvec.create(n, dims, type); - if( type == CV_32S ) - for( i = 0; i < n; i++ ) - { - int k = idx.at(i); - CV_Assert( 0 <= k && k < m ); - if( dims == 1 ) - subvec.at(i) = vec.at(k); - else - for( j = 0; j < dims; j++ ) - subvec.at(i, j) = vec.at(k, j); - } - else if( type == CV_32F ) - for( i = 0; i < n; i++ ) - { - int k = idx.at(i); - CV_Assert( 0 <= k && k < m ); - if( dims == 1 ) - subvec.at(i) = vec.at(k); - else - for( j = 0; j < dims; j++ ) - subvec.at(i, j) = vec.at(k, j); - } - else - for( i = 0; i < n; i++ ) - { - int k = idx.at(i); - CV_Assert( 0 <= k && k < m ); - if( dims == 1 ) - subvec.at(i) = vec.at(k); - else - for( j = 0; j < dims; j++ ) - subvec.at(i, j) = vec.at(k, j); - } - return subvec; +Mat TrainData::getSubMatrix(const Mat& m, const Mat& idx, int layout) +{ + if (idx.empty()) + return m; + int type = m.type(); + CV_CheckType(type, type == CV_32S || type == CV_32F || type == CV_64F, ""); + if (type == CV_32S || type == CV_32F) // 32-bit + return getSubMatrixImpl(m, idx, layout); + if (type == CV_64F) // 64-bit + return getSubMatrixImpl(m, idx, layout); + CV_Error(Error::StsInternal, ""); } class TrainDataImpl CV_FINAL : public TrainData @@ -172,30 +170,30 @@ public: } Mat getTrainSampleWeights() const CV_OVERRIDE { - return getSubVector(sampleWeights, getTrainSampleIdx()); + return getSubVector(sampleWeights, getTrainSampleIdx()); // 1D-vector } Mat getTestSampleWeights() const CV_OVERRIDE { Mat idx = getTestSampleIdx(); - return idx.empty() ? Mat() : getSubVector(sampleWeights, idx); + return idx.empty() ? Mat() : getSubVector(sampleWeights, idx); // 1D-vector } Mat getTrainResponses() const CV_OVERRIDE { - return getSubVector(responses, getTrainSampleIdx()); + return getSubMatrix(responses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData() } Mat getTrainNormCatResponses() const CV_OVERRIDE { - return getSubVector(normCatResponses, getTrainSampleIdx()); + return getSubMatrix(normCatResponses, getTrainSampleIdx(), cv::ml::ROW_SAMPLE); // like 'responses' } Mat getTestResponses() const CV_OVERRIDE { Mat idx = getTestSampleIdx(); - return idx.empty() ? Mat() : getSubVector(responses, idx); + return idx.empty() ? Mat() : getSubMatrix(responses, idx, cv::ml::ROW_SAMPLE); // col-based responses are transposed in setData() } Mat getTestNormCatResponses() const CV_OVERRIDE { Mat idx = getTestSampleIdx(); - return idx.empty() ? Mat() : getSubVector(normCatResponses, idx); + return idx.empty() ? Mat() : getSubMatrix(normCatResponses, idx, cv::ml::ROW_SAMPLE); // like 'responses' } Mat getNormCatResponses() const CV_OVERRIDE { return normCatResponses; } Mat getClassLabels() const CV_OVERRIDE { return classLabels; }