diff --git a/modules/ml/src/kdtree.cpp b/modules/ml/src/kdtree.cpp
index 1ab8400936..a80e12964a 100644
--- a/modules/ml/src/kdtree.cpp
+++ b/modules/ml/src/kdtree.cpp
@@ -101,7 +101,7 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
     int i0 = a, i1 = (a+b)/2, i2 = b;
     float v0 = vals[ofs[i0]], v1 = vals[ofs[i1]], v2 = vals[ofs[i2]];
     int ip = v0 < v1 ? (v1 < v2 ? i1 : v0 < v2 ? i2 : i0) :
-        v0 < v2 ? i0 : (v1 < v2 ? i2 : i1);
+        v0 < v2 ? (v1 == v0 ? i2 : i0): (v1 < v2 ? i2 : i1);
     float pivot = vals[ofs[ip]];
     std::swap(ofs[ip], ofs[i2]);
 
@@ -131,7 +131,6 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
         CV_Assert(vals[ofs[k]] >= pivot);
         more += vals[ofs[k]] > pivot;
     }
-    CV_Assert(std::abs(more - less) <= 1);
 
     return vals[ofs[middle]];
 }
diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp
index ca23d0f4d6..3d8f9b5d2e 100644
--- a/modules/ml/src/knearest.cpp
+++ b/modules/ml/src/knearest.cpp
@@ -381,36 +381,23 @@ public:
         Mat res, nr, d;
         if( _results.needed() )
         {
-            _results.create(testcount, 1, CV_32F);
             res = _results.getMat();
         }
         if( _neighborResponses.needed() )
         {
-            _neighborResponses.create(testcount, k, CV_32F);
             nr = _neighborResponses.getMat();
         }
         if( _dists.needed() )
         {
-            _dists.create(testcount, k, CV_32F);
             d = _dists.getMat();
         }
 
         for (int i=0; i<testcount; i++)
         {
             Mat _res, _nr, _d;
-            if (res.rows>i)
-            {
-                _res = res.row(i);
-            }
-            if (nr.rows>i)
-            {
-                _nr = nr.row(i);
-            }
-            if (d.rows>i)
-            {
-                _d = d.row(i);
-            }
             tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
+            res.push_back(_res.t());
+            _results.assign(res);
         }
 
         return result; // currently always 0
diff --git a/modules/ml/test/test_knearest.cpp b/modules/ml/test/test_knearest.cpp
index 49e6b0d12a..80baed9626 100644
--- a/modules/ml/test/test_knearest.cpp
+++ b/modules/ml/test/test_knearest.cpp
@@ -37,18 +37,31 @@ TEST(ML_KNearest, accuracy)
         EXPECT_LE(err, 0.01f);
     }
     {
-        // TODO: broken
-#if 0
         SCOPED_TRACE("KDTree");
-        Mat bestLabels;
+        Mat neighborIndexes;
         float err = 1000;
         Ptr<KNearest> knn = KNearest::create();
         knn->setAlgorithmType(KNearest::KDTREE);
         knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
-        knn->findNearest(testData, 4, bestLabels);
+        knn->findNearest(testData, 4, neighborIndexes);
+        Mat bestLabels;
+        // The output of the KDTree are the neighbor indexes, not actual class labels
+        // so we need to do some extra work to get actual predictions
+        for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){
+            vector<float> labels;
+            for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) {
+                labels.push_back(trainLabels.at<float>(neighborIndexes.row(row_num).at<float>(0, index) , 0));
+            }
+            // computing the mode of the output class predictions to determine overall prediction
+            std::vector<int> histogram(3,0);
+            for( int i=0; i<3; ++i )
+                ++histogram[ static_cast<int>(labels[i]) ];
+            int bestLabel = static_cast<int>(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin());
+            bestLabels.push_back(bestLabel);
+        }
+        bestLabels.convertTo(bestLabels, testLabels.type());
         EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
         EXPECT_LE(err, 0.01f);
-#endif
     }
 }
 
@@ -74,4 +87,26 @@ TEST(ML_KNearest, regression_12347)
     EXPECT_EQ(2, zBestLabels.at<float>(1,0));
 }
 
+TEST(ML_KNearest, bug_11877)
+{
+    Mat trainData = (Mat_<float>(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4);
+    Mat trainLabels = (Mat_<float>(5,1) << 0, 0, 1, 1, 1);
+
+    Ptr<KNearest> knnKdt = KNearest::create();
+    knnKdt->setAlgorithmType(KNearest::KDTREE);
+    knnKdt->setIsClassifier(true);
+
+    knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
+
+    Mat testData = (Mat_<float>(2,2) << 3.1, 3.1, 4, 4.1);
+    Mat testLabels = (Mat_<float>(2,1) << 0, 1);
+    Mat result;
+
+    knnKdt->findNearest(testData, 1, result);
+
+    EXPECT_EQ(1, int(result.at<float>(0, 0)));
+    EXPECT_EQ(2, int(result.at<float>(1, 0)));
+    EXPECT_EQ(0, trainLabels.at<float>(result.at<float>(0, 0), 0));
+}
+
 }} // namespace