finished for one sample
Finished with several samples support, need regression testing Gave a more relevant name to function (getVotes) Finished implicit implementation Removed printf, finished regresion testing Fixed conversion warning Finished test for Rtrees Fixed documentation Initialized variable Added doxygen documentation Added parameter name
This commit is contained in:
@@ -349,6 +349,60 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void getVotes( InputArray input, OutputArray output, int flags ) const
|
||||
{
|
||||
CV_Assert( !roots.empty() );
|
||||
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
|
||||
Mat samples = input.getMat(), results;
|
||||
int i, j, nsamples = samples.rows;
|
||||
|
||||
int predictType = flags & PREDICT_MASK;
|
||||
if( predictType == PREDICT_AUTO )
|
||||
{
|
||||
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
|
||||
PREDICT_SUM : PREDICT_MAX_VOTE;
|
||||
}
|
||||
|
||||
if( predictType == PREDICT_SUM )
|
||||
{
|
||||
output.create(nsamples, ntrees, CV_32F);
|
||||
results = output.getMat();
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
results.at<float> (i, j) = val;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
vector<int> votes;
|
||||
output.create(nsamples+1, nclasses, CV_32S);
|
||||
results = output.getMat();
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (0, j) = classLabels[j];
|
||||
}
|
||||
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
votes.clear();
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
votes.push_back(val);
|
||||
}
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RTreeParams rparams;
|
||||
double oobError;
|
||||
vector<float> varImportance;
|
||||
@@ -401,6 +455,11 @@ public:
|
||||
impl.read(fn);
|
||||
}
|
||||
|
||||
void getVotes_( InputArray samples, OutputArray results, int flags ) const
|
||||
{
|
||||
impl.getVotes(samples, results, flags);
|
||||
}
|
||||
|
||||
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
|
||||
int getVarCount() const { return impl.getVarCount(); }
|
||||
|
||||
@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
|
||||
return Algorithm::load<RTrees>(filepath, nodeName);
|
||||
}
|
||||
|
||||
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
||||
{
|
||||
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
|
||||
if(!this_)
|
||||
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
|
||||
return this_->getVotes_(input, output, flags);
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
// End of file.
|
||||
|
||||
Reference in New Issue
Block a user