Merge pull request #18061 from danielenricocahall:fix-kd-tree
Fix KD Tree kNN Implementation * Make KDTree mode in kNN functional remove docs and revert change Make KDTree mode in kNN functional spacing Make KDTree mode in kNN functional fix window compilations warnings Make KDTree mode in kNN functional fix window compilations warnings Make KDTree mode in kNN functional casting Make KDTree mode in kNN functional formatting Make KDTree mode in kNN functional * test coding style
This commit is contained in:
parent
fa11b98800
commit
20b23da8e2
@ -101,7 +101,7 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
|
||||
int i0 = a, i1 = (a+b)/2, i2 = b;
|
||||
float v0 = vals[ofs[i0]], v1 = vals[ofs[i1]], v2 = vals[ofs[i2]];
|
||||
int ip = v0 < v1 ? (v1 < v2 ? i1 : v0 < v2 ? i2 : i0) :
|
||||
v0 < v2 ? i0 : (v1 < v2 ? i2 : i1);
|
||||
v0 < v2 ? (v1 == v0 ? i2 : i0): (v1 < v2 ? i2 : i1);
|
||||
float pivot = vals[ofs[ip]];
|
||||
std::swap(ofs[ip], ofs[i2]);
|
||||
|
||||
@ -131,7 +131,6 @@ medianPartition( size_t* ofs, int a, int b, const float* vals )
|
||||
CV_Assert(vals[ofs[k]] >= pivot);
|
||||
more += vals[ofs[k]] > pivot;
|
||||
}
|
||||
CV_Assert(std::abs(more - less) <= 1);
|
||||
|
||||
return vals[ofs[middle]];
|
||||
}
|
||||
|
||||
@ -381,36 +381,23 @@ public:
|
||||
Mat res, nr, d;
|
||||
if( _results.needed() )
|
||||
{
|
||||
_results.create(testcount, 1, CV_32F);
|
||||
res = _results.getMat();
|
||||
}
|
||||
if( _neighborResponses.needed() )
|
||||
{
|
||||
_neighborResponses.create(testcount, k, CV_32F);
|
||||
nr = _neighborResponses.getMat();
|
||||
}
|
||||
if( _dists.needed() )
|
||||
{
|
||||
_dists.create(testcount, k, CV_32F);
|
||||
d = _dists.getMat();
|
||||
}
|
||||
|
||||
for (int i=0; i<test_samples.rows; ++i)
|
||||
{
|
||||
Mat _res, _nr, _d;
|
||||
if (res.rows>i)
|
||||
{
|
||||
_res = res.row(i);
|
||||
}
|
||||
if (nr.rows>i)
|
||||
{
|
||||
_nr = nr.row(i);
|
||||
}
|
||||
if (d.rows>i)
|
||||
{
|
||||
_d = d.row(i);
|
||||
}
|
||||
tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
|
||||
res.push_back(_res.t());
|
||||
_results.assign(res);
|
||||
}
|
||||
|
||||
return result; // currently always 0
|
||||
|
||||
@ -37,18 +37,31 @@ TEST(ML_KNearest, accuracy)
|
||||
EXPECT_LE(err, 0.01f);
|
||||
}
|
||||
{
|
||||
// TODO: broken
|
||||
#if 0
|
||||
SCOPED_TRACE("KDTree");
|
||||
Mat bestLabels;
|
||||
Mat neighborIndexes;
|
||||
float err = 1000;
|
||||
Ptr<KNearest> knn = KNearest::create();
|
||||
knn->setAlgorithmType(KNearest::KDTREE);
|
||||
knn->train(trainData, ml::ROW_SAMPLE, trainLabels);
|
||||
knn->findNearest(testData, 4, bestLabels);
|
||||
knn->findNearest(testData, 4, neighborIndexes);
|
||||
Mat bestLabels;
|
||||
// The output of the KDTree are the neighbor indexes, not actual class labels
|
||||
// so we need to do some extra work to get actual predictions
|
||||
for(int row_num = 0; row_num < neighborIndexes.rows; ++row_num){
|
||||
vector<float> labels;
|
||||
for(int index = 0; index < neighborIndexes.row(row_num).cols; ++index) {
|
||||
labels.push_back(trainLabels.at<float>(neighborIndexes.row(row_num).at<int>(0, index) , 0));
|
||||
}
|
||||
// computing the mode of the output class predictions to determine overall prediction
|
||||
std::vector<int> histogram(3,0);
|
||||
for( int i=0; i<3; ++i )
|
||||
++histogram[ static_cast<int>(labels[i]) ];
|
||||
int bestLabel = static_cast<int>(std::max_element( histogram.begin(), histogram.end() ) - histogram.begin());
|
||||
bestLabels.push_back(bestLabel);
|
||||
}
|
||||
bestLabels.convertTo(bestLabels, testLabels.type());
|
||||
EXPECT_TRUE(calcErr( bestLabels, testLabels, sizes, err, true ));
|
||||
EXPECT_LE(err, 0.01f);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,4 +87,26 @@ TEST(ML_KNearest, regression_12347)
|
||||
EXPECT_EQ(2, zBestLabels.at<float>(1,0));
|
||||
}
|
||||
|
||||
TEST(ML_KNearest, bug_11877)
|
||||
{
|
||||
Mat trainData = (Mat_<float>(5,2) << 3, 3, 3, 3, 4, 4, 4, 4, 4, 4);
|
||||
Mat trainLabels = (Mat_<float>(5,1) << 0, 0, 1, 1, 1);
|
||||
|
||||
Ptr<KNearest> knnKdt = KNearest::create();
|
||||
knnKdt->setAlgorithmType(KNearest::KDTREE);
|
||||
knnKdt->setIsClassifier(true);
|
||||
|
||||
knnKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
|
||||
|
||||
Mat testData = (Mat_<float>(2,2) << 3.1, 3.1, 4, 4.1);
|
||||
Mat testLabels = (Mat_<int>(2,1) << 0, 1);
|
||||
Mat result;
|
||||
|
||||
knnKdt->findNearest(testData, 1, result);
|
||||
|
||||
EXPECT_EQ(1, int(result.at<int>(0, 0)));
|
||||
EXPECT_EQ(2, int(result.at<int>(1, 0)));
|
||||
EXPECT_EQ(0, trainLabels.at<int>(result.at<int>(0, 0), 0));
|
||||
}
|
||||
|
||||
}} // namespace
|
||||
|
||||
Loading…
Reference in New Issue
Block a user