finished for one sample

Finished with several samples support, need regression testing

Gave a more relevant name to function (getVotes)

Finished implicit implementation

Removed printf, finished regresion testing

Fixed conversion warning

Finished test for Rtrees

Fixed documentation

Initialized variable

Added doxygen documentation

Added parameter name
This commit is contained in:
mrquorr
2017-01-31 23:31:10 -06:00
parent ec47a0a6de
commit d8425d8881
3 changed files with 123 additions and 0 deletions
+67
View File
@@ -349,6 +349,60 @@ public:
}
}
void getVotes( InputArray input, OutputArray output, int flags ) const
{
CV_Assert( !roots.empty() );
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
Mat samples = input.getMat(), results;
int i, j, nsamples = samples.rows;
int predictType = flags & PREDICT_MASK;
if( predictType == PREDICT_AUTO )
{
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
PREDICT_SUM : PREDICT_MAX_VOTE;
}
if( predictType == PREDICT_SUM )
{
output.create(nsamples, ntrees, CV_32F);
results = output.getMat();
for( i = 0; i < nsamples; i++ )
{
for( j = 0; j < ntrees; j++ )
{
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
results.at<float> (i, j) = val;
}
}
} else
{
vector<int> votes;
output.create(nsamples+1, nclasses, CV_32S);
results = output.getMat();
for ( j = 0; j < nclasses; j++)
{
results.at<int> (0, j) = classLabels[j];
}
for( i = 0; i < nsamples; i++ )
{
votes.clear();
for( j = 0; j < ntrees; j++ )
{
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
votes.push_back(val);
}
for ( j = 0; j < nclasses; j++)
{
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
}
}
}
}
RTreeParams rparams;
double oobError;
vector<float> varImportance;
@@ -401,6 +455,11 @@ public:
impl.read(fn);
}
void getVotes_( InputArray samples, OutputArray results, int flags ) const
{
impl.getVotes(samples, results, flags);
}
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
int getVarCount() const { return impl.getVarCount(); }
@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return Algorithm::load<RTrees>(filepath, nodeName);
}
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
{
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
return this_->getVotes_(input, output, flags);
}
}}
// End of file.