Merge pull request #8116 from mrquorr:master
This commit is contained in:
@@ -349,6 +349,60 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void getVotes( InputArray input, OutputArray output, int flags ) const
|
||||
{
|
||||
CV_Assert( !roots.empty() );
|
||||
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
|
||||
Mat samples = input.getMat(), results;
|
||||
int i, j, nsamples = samples.rows;
|
||||
|
||||
int predictType = flags & PREDICT_MASK;
|
||||
if( predictType == PREDICT_AUTO )
|
||||
{
|
||||
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
|
||||
PREDICT_SUM : PREDICT_MAX_VOTE;
|
||||
}
|
||||
|
||||
if( predictType == PREDICT_SUM )
|
||||
{
|
||||
output.create(nsamples, ntrees, CV_32F);
|
||||
results = output.getMat();
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
results.at<float> (i, j) = val;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
vector<int> votes;
|
||||
output.create(nsamples+1, nclasses, CV_32S);
|
||||
results = output.getMat();
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (0, j) = classLabels[j];
|
||||
}
|
||||
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
votes.clear();
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
votes.push_back(val);
|
||||
}
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RTreeParams rparams;
|
||||
double oobError;
|
||||
vector<float> varImportance;
|
||||
@@ -401,6 +455,11 @@ public:
|
||||
impl.read(fn);
|
||||
}
|
||||
|
||||
void getVotes_( InputArray samples, OutputArray results, int flags ) const
|
||||
{
|
||||
impl.getVotes(samples, results, flags);
|
||||
}
|
||||
|
||||
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
|
||||
int getVarCount() const { return impl.getVarCount(); }
|
||||
|
||||
@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
|
||||
return Algorithm::load<RTrees>(filepath, nodeName);
|
||||
}
|
||||
|
||||
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
||||
{
|
||||
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
|
||||
if(!this_)
|
||||
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
|
||||
return this_->getVotes_(input, output, flags);
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
// End of file.
|
||||
|
||||
Reference in New Issue
Block a user