diff --git a/modules/ml/include/opencv2/ml.hpp b/modules/ml/include/opencv2/ml.hpp index adc437e808..669e2d004e 100644 --- a/modules/ml/include/opencv2/ml.hpp +++ b/modules/ml/include/opencv2/ml.hpp @@ -1206,6 +1206,17 @@ public: */ CV_WRAP virtual Mat getVarImportance() const = 0; + /** Returns the result of each individual tree in the forest. + In case the model is a regression problem, the method will return each of the trees' + results for each of the sample cases. If the model is a classifier, it will return + a Mat with samples + 1 rows, where the first row gives the class number and the + following rows return the votes each class had for each sample. + @param samples Array containg the samples for which votes will be calculated. + @param results Array where the result of the calculation will be written. + @param flags Flags for defining the type of RTrees. + */ + CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const; + /** Creates the empty model. Use StatModel::train to train the model, StatModel::train to create and train the model, Algorithm::load to load the pre-trained model. diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 65fe6827a7..fa2a23950f 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -349,6 +349,60 @@ public: } } + void getVotes( InputArray input, OutputArray output, int flags ) const + { + CV_Assert( !roots.empty() ); + int nclasses = (int)classLabels.size(), ntrees = (int)roots.size(); + Mat samples = input.getMat(), results; + int i, j, nsamples = samples.rows; + + int predictType = flags & PREDICT_MASK; + if( predictType == PREDICT_AUTO ) + { + predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ? + PREDICT_SUM : PREDICT_MAX_VOTE; + } + + if( predictType == PREDICT_SUM ) + { + output.create(nsamples, ntrees, CV_32F); + results = output.getMat(); + for( i = 0; i < nsamples; i++ ) + { + for( j = 0; j < ntrees; j++ ) + { + float val = predictTrees( Range(j, j+1), samples.row(i), flags); + results.at (i, j) = val; + } + } + } else + { + vector votes; + output.create(nsamples+1, nclasses, CV_32S); + results = output.getMat(); + + for ( j = 0; j < nclasses; j++) + { + results.at (0, j) = classLabels[j]; + } + + for( i = 0; i < nsamples; i++ ) + { + votes.clear(); + for( j = 0; j < ntrees; j++ ) + { + int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags); + votes.push_back(val); + } + + for ( j = 0; j < nclasses; j++) + { + results.at (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]); + } + } + } + } + RTreeParams rparams; double oobError; vector varImportance; @@ -401,6 +455,11 @@ public: impl.read(fn); } + void getVotes_( InputArray samples, OutputArray results, int flags ) const + { + impl.getVotes(samples, results, flags); + } + Mat getVarImportance() const { return Mat_(impl.varImportance, true); } int getVarCount() const { return impl.getVarCount(); } @@ -427,6 +486,14 @@ Ptr RTrees::load(const String& filepath, const String& nodeName) return Algorithm::load(filepath, nodeName); } +void RTrees::getVotes(InputArray input, OutputArray output, int flags) const +{ + const RTreesImpl* this_ = dynamic_cast(this); + if(!this_) + CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl"); + return this_->getVotes_(input, output, flags); +} + }} // End of file. diff --git a/modules/ml/test/test_mltests.cpp b/modules/ml/test/test_mltests.cpp index 70cc0f7ecb..719333140c 100644 --- a/modules/ml/test/test_mltests.cpp +++ b/modules/ml/test/test_mltests.cpp @@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911) EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total()); } +TEST(ML_RTrees, getVotes) +{ + int n = 12; + int count, i; + int label_size = 3; + int predicted_class = 0; + int max_votes = -1; + int val; + // RTrees for classification + Ptr rt = cv::ml::RTrees::create(); + + //data + Mat data(n, 4, CV_32F); + randu(data, 0, 10); + + //labels + Mat labels = (Mat_(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2); + + rt->train(data, ml::ROW_SAMPLE, labels); + + //run function + Mat test(1, 4, CV_32F); + Mat result; + randu(test, 0, 10); + rt->getVotes(test, result, 0); + + //count vote amount and find highest vote + count = 0; + const int* result_row = result.ptr(1); + for( i = 0; i < label_size; i++ ) + { + val = result_row[i]; + //predicted_class = max_votes < val? i; + if( max_votes < val ) + { + max_votes = val; + predicted_class = i; + } + count += val; + } + + EXPECT_EQ(count, (int)rt->getRoots().size()); + EXPECT_EQ(result.at(0, predicted_class), rt->predict(test)); +} + /* End of file. */