ml: add checks of empty train data

This commit is contained in:
Alexander Alekhin
2019-09-22 11:11:08 +00:00
parent eabbe38001
commit fef7fc343e
13 changed files with 33 additions and 11 deletions
+2
View File
@@ -920,6 +920,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_Assert(!trainData.empty());
const int MAX_ITER = 1000;
const double DEFAULT_EPSILON = FLT_EPSILON;
@@ -955,6 +956,7 @@ public:
}
int train_anneal(const Ptr<TrainData>& trainData)
{
CV_Assert(!trainData.empty());
SimulatedAnnealingANN_MLP s(*this, trainData);
trained = true; // Enable call to CalcError
int iter = simulatedAnnealingSolver(s, params.initialT, params.finalT, params.coolingRatio, params.itePerStep, NULL, params.rEnergy);
+3
View File
@@ -88,6 +88,7 @@ public:
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_Assert(!trainData.empty());
DTreesImpl::startTraining(trainData, flags);
sumResult.assign(w->sidx.size(), 0.);
@@ -184,6 +185,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_Assert(!trainData.empty());
startTraining(trainData, flags);
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
vector<int> sidx = w->sidx;
@@ -482,6 +484,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_Assert(!trainData.empty());
return impl.train(trainData, flags);
}
+1
View File
@@ -112,6 +112,7 @@ public:
bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
{
CV_Assert(!data.empty());
Mat samples = data->getTrainSamples(), labels;
return trainEM(samples, labels, noArray(), noArray());
}
+4 -1
View File
@@ -59,9 +59,10 @@ bool StatModel::empty() const { return !isTrained(); }
int StatModel::getVarCount() const { return 0; }
bool StatModel::train( const Ptr<TrainData>&, int )
bool StatModel::train(const Ptr<TrainData>& trainData, int )
{
CV_TRACE_FUNCTION();
CV_Assert(!trainData.empty());
CV_Error(CV_StsNotImplemented, "");
return false;
}
@@ -69,6 +70,7 @@ bool StatModel::train( const Ptr<TrainData>&, int )
bool StatModel::train( InputArray samples, int layout, InputArray responses )
{
CV_TRACE_FUNCTION();
CV_Assert(!samples.empty());
return train(TrainData::create(samples, layout, responses));
}
@@ -134,6 +136,7 @@ public:
float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray _resp) const
{
CV_TRACE_FUNCTION_SKIP_NESTED();
CV_Assert(!data.empty());
Mat samples = data->getSamples();
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights();
+2
View File
@@ -73,6 +73,7 @@ public:
bool train( const Ptr<TrainData>& data, int flags )
{
CV_Assert(!data.empty());
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
Mat new_responses;
data->getTrainResponses().convertTo(new_responses, CV_32F);
@@ -494,6 +495,7 @@ public:
bool train( const Ptr<TrainData>& data, int flags ) CV_OVERRIDE
{
CV_Assert(!data.empty());
return impl->train(data, flags);
}
+2 -4
View File
@@ -142,12 +142,10 @@ Ptr<LogisticRegression> LogisticRegression::load(const String& filepath, const S
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{
CV_TRACE_FUNCTION_SKIP_NESTED();
CV_Assert(!trainData.empty());
// return value
bool ok = false;
if (trainData.empty()) {
return false;
}
clear();
Mat _data_i = trainData->getSamples();
Mat _labels_i = trainData->getResponses();
+1
View File
@@ -54,6 +54,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_Assert(!trainData.empty());
const float min_variation = FLT_EPSILON;
Mat responses = trainData->getNormCatResponses();
Mat __cls_labels = trainData->getClassLabels();
+3
View File
@@ -111,6 +111,7 @@ public:
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_Assert(!trainData.empty());
DTreesImpl::startTraining(trainData, flags);
int nvars = w->data->getNVars();
int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
@@ -133,6 +134,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_Assert(!trainData.empty());
startTraining(trainData, flags);
int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
rparams.termCrit.maxCount : 10000;
@@ -463,6 +465,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_Assert(!trainData.empty());
if (impl.getCVFolds() != 0)
CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
return impl.train(trainData, flags);
+2
View File
@@ -1613,6 +1613,7 @@ public:
bool train( const Ptr<TrainData>& data, int ) CV_OVERRIDE
{
CV_Assert(!data.empty());
clear();
checkParams();
@@ -1739,6 +1740,7 @@ public:
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
bool balanced ) CV_OVERRIDE
{
CV_Assert(!data.empty());
checkParams();
int svmType = params.svmType;
+1
View File
@@ -230,6 +230,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
CV_Assert(!data.empty());
clear();
CV_Assert( isClassifier() ); //toDo: consider
+3
View File
@@ -98,6 +98,7 @@ DTrees::Split::Split()
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
CV_Assert(!_data.empty());
data = _data;
vector<int> subsampleIdx;
Mat sidx0 = _data->getTrainSampleIdx();
@@ -136,6 +137,7 @@ void DTreesImpl::clear()
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
CV_Assert(!data.empty());
clear();
w = makePtr<WorkData>(data);
@@ -223,6 +225,7 @@ void DTreesImpl::endTraining()
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
CV_Assert(!trainData.empty());
startTraining(trainData, flags);
bool ok = addTree( w->sidx ) >= 0;
w.release();