Merge pull request #8116 from mrquorr:master

This commit is contained in:
Vadim Pisarevsky
2017-03-02 11:07:23 +00:00
3 changed files with 123 additions and 0 deletions
+67
View File
@@ -349,6 +349,60 @@ public:
}
}
void getVotes( InputArray input, OutputArray output, int flags ) const
{
CV_Assert( !roots.empty() );
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
Mat samples = input.getMat(), results;
int i, j, nsamples = samples.rows;
int predictType = flags & PREDICT_MASK;
if( predictType == PREDICT_AUTO )
{
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
PREDICT_SUM : PREDICT_MAX_VOTE;
}
if( predictType == PREDICT_SUM )
{
output.create(nsamples, ntrees, CV_32F);
results = output.getMat();
for( i = 0; i < nsamples; i++ )
{
for( j = 0; j < ntrees; j++ )
{
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
results.at<float> (i, j) = val;
}
}
} else
{
vector<int> votes;
output.create(nsamples+1, nclasses, CV_32S);
results = output.getMat();
for ( j = 0; j < nclasses; j++)
{
results.at<int> (0, j) = classLabels[j];
}
for( i = 0; i < nsamples; i++ )
{
votes.clear();
for( j = 0; j < ntrees; j++ )
{
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
votes.push_back(val);
}
for ( j = 0; j < nclasses; j++)
{
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
}
}
}
}
RTreeParams rparams;
double oobError;
vector<float> varImportance;
@@ -401,6 +455,11 @@ public:
impl.read(fn);
}
void getVotes_( InputArray samples, OutputArray results, int flags ) const
{
impl.getVotes(samples, results, flags);
}
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
int getVarCount() const { return impl.getVarCount(); }
@@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
return Algorithm::load<RTrees>(filepath, nodeName);
}
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
{
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
if(!this_)
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
return this_->getVotes_(input, output, flags);
}
}}
// End of file.