diff --git a/modules/ml/test/test_mltests2.cpp b/modules/ml/test/test_mltests2.cpp index b8c909001e..616a527bfe 100644 --- a/modules/ml/test/test_mltests2.cpp +++ b/modules/ml/test/test_mltests2.cpp @@ -721,5 +721,68 @@ void CV_MLBaseTest::load( const char* filename ) CV_Error( CV_StsNotImplemented, "invalid stat model name"); } + + +TEST(TrainDataGet, layout_ROW_SAMPLE) // Details: #12236 +{ + cv::Mat test = cv::Mat::ones(150, 30, CV_32FC1) * 2; + test.col(3) += Scalar::all(3); + cv::Mat labels = cv::Mat::ones(150, 3, CV_32SC1) * 5; + labels.col(1) += 1; + cv::Ptr train_data = cv::ml::TrainData::create(test, cv::ml::ROW_SAMPLE, labels); + train_data->setTrainTestSplitRatio(0.9); + + Mat tidx = train_data->getTestSampleIdx(); + EXPECT_EQ((size_t)15, tidx.total()); + + Mat tresp = train_data->getTestResponses(); + EXPECT_EQ(15, tresp.rows); + EXPECT_EQ(labels.cols, tresp.cols); + EXPECT_EQ(5, tresp.at(0, 0)) << tresp; + EXPECT_EQ(6, tresp.at(0, 1)) << tresp; + EXPECT_EQ(6, tresp.at(14, 1)) << tresp; + EXPECT_EQ(5, tresp.at(14, 2)) << tresp; + + Mat tsamples = train_data->getTestSamples(); + EXPECT_EQ(15, tsamples.rows); + EXPECT_EQ(test.cols, tsamples.cols); + EXPECT_EQ(2, tsamples.at(0, 0)) << tsamples; + EXPECT_EQ(5, tsamples.at(0, 3)) << tsamples; + EXPECT_EQ(2, tsamples.at(14, test.cols - 1)) << tsamples; + EXPECT_EQ(5, tsamples.at(14, 3)) << tsamples; +} + +TEST(TrainDataGet, layout_COL_SAMPLE) // Details: #12236 +{ + cv::Mat test = cv::Mat::ones(30, 150, CV_32FC1) * 3; + test.row(3) += Scalar::all(3); + cv::Mat labels = cv::Mat::ones(3, 150, CV_32SC1) * 5; + labels.row(1) += 1; + cv::Ptr train_data = cv::ml::TrainData::create(test, cv::ml::COL_SAMPLE, labels); + train_data->setTrainTestSplitRatio(0.9); + + Mat tidx = train_data->getTestSampleIdx(); + EXPECT_EQ((size_t)15, tidx.total()); + + Mat tresp = train_data->getTestResponses(); // always row-based, transposed + EXPECT_EQ(15, tresp.rows); + EXPECT_EQ(labels.rows, tresp.cols); + EXPECT_EQ(5, tresp.at(0, 0)) << tresp; + EXPECT_EQ(6, tresp.at(0, 1)) << tresp; + EXPECT_EQ(6, tresp.at(14, 1)) << tresp; + EXPECT_EQ(5, tresp.at(14, 2)) << tresp; + + + Mat tsamples = train_data->getTestSamples(); + EXPECT_EQ(15, tsamples.cols); + EXPECT_EQ(test.rows, tsamples.rows); + EXPECT_EQ(3, tsamples.at(0, 0)) << tsamples; + EXPECT_EQ(6, tsamples.at(3, 0)) << tsamples; + EXPECT_EQ(6, tsamples.at(3, 14)) << tsamples; + EXPECT_EQ(3, tsamples.at(test.rows - 1, 14)) << tsamples; +} + + + } // namespace /* End of file. */