Merge pull request #18126 from danielenricocahall:add-oob-error-sample-weighting

Account for sample weights in calculating OOB Error

* account for sample weights in oob error calculation

* redefine oob error functions

* fix ABI compatibility
This commit is contained in:
Danny
2020-09-05 14:52:10 -04:00
committed by GitHub
parent 3835ab394e
commit c31164bf1e
3 changed files with 79 additions and 2 deletions
+46
View File
@@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
}
// Trains the same random forest regressor twice on a tiny x -> 2x data set,
// first without and then with per-sample weights, and compares the rounded
// out-of-bag (OOB) errors reported after each training run.
TEST(ML_RTrees, 11142_sample_weights_regression)
{
    const int sampleCount = 3;

    // Inputs 1, 2, 3 and their responses 2, 4, 6.
    const Mat samples = (Mat_<float>(sampleCount, 1) << 1, 2, 3);
    const Mat responses = (Mat_<float>(sampleCount, 1) << 2, 4, 6);

    // Forest used for both training runs.
    Ptr<ml::RTrees> forest = cv::ml::RTrees::create();

    // First run: no sample weights supplied.
    Ptr<TrainData> plainData = TrainData::create(samples, ml::ROW_SAMPLE, responses);
    forest->train(plainData);
    const double error_without_weights = round(forest->getOOBError());
    forest->clear();

    // Second run: every sample carries the same weight of 10.
    const Mat sampleWeights = (Mat_<float>(sampleCount, 1) << 10, 10, 10);
    Ptr<TrainData> weightedData =
        TrainData::create(samples, ml::ROW_SAMPLE, responses, Mat(), Mat(), sampleWeights);
    forest->train(weightedData);
    const double error_with_weights = round(forest->getOOBError());

    // The weighted error is expected to be no smaller than the unweighted one.
    EXPECT_GE(error_with_weights, error_without_weights);
}
// Trains the same random forest classifier twice on a small random data set,
// first without and then with per-sample weights, and compares the rounded
// out-of-bag (OOB) errors reported after each training run.
TEST(ML_RTrees, 11142_sample_weights_classification)
{
    int n = 12;
    // RTrees for classification
    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();

    // n samples with 4 float features each, filled by randu(data, 0, 10)
    Mat data(n, 4, CV_32F);
    randu(data, 0, 10);
    // three classes (0, 1, 2) with four samples each
    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
    // every sample carries the same weight of 10
    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);

    // Baseline run without sample weights. The OOB error must be read before
    // clear() is called, so that it reflects the model that was just trained
    // (same ordering as the regression test in this file).
    rt->train(data, ml::ROW_SAMPLE, labels);
    double error_without_weights = round(rt->getOOBError());
    rt->clear();

    // Second run: train on the weighted TrainData so the weights are actually used.
    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights );
    rt->train(trainDataWithWeights);
    double error_with_weights = round(rt->getOOBError());

    // error with weights should be at least as large as error without weights
    EXPECT_GE(error_with_weights, error_without_weights);
}
}} // namespace