ml: add checks of empty train data
This commit is contained in:
@@ -920,6 +920,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
const int MAX_ITER = 1000;
|
||||
const double DEFAULT_EPSILON = FLT_EPSILON;
|
||||
|
||||
@@ -955,6 +956,7 @@ public:
|
||||
}
|
||||
int train_anneal(const Ptr<TrainData>& trainData)
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
SimulatedAnnealingANN_MLP s(*this, trainData);
|
||||
trained = true; // Enable call to CalcError
|
||||
int iter = simulatedAnnealingSolver(s, params.initialT, params.finalT, params.coolingRatio, params.itePerStep, NULL, params.rEnergy);
|
||||
|
||||
@@ -88,6 +88,7 @@ public:
|
||||
|
||||
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
DTreesImpl::startTraining(trainData, flags);
|
||||
sumResult.assign(w->sidx.size(), 0.);
|
||||
|
||||
@@ -184,6 +185,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
|
||||
vector<int> sidx = w->sidx;
|
||||
@@ -482,6 +484,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
return impl.train(trainData, flags);
|
||||
}
|
||||
|
||||
|
||||
@@ -112,6 +112,7 @@ public:
|
||||
|
||||
bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
Mat samples = data->getTrainSamples(), labels;
|
||||
return trainEM(samples, labels, noArray(), noArray());
|
||||
}
|
||||
|
||||
@@ -59,9 +59,10 @@ bool StatModel::empty() const { return !isTrained(); }
|
||||
|
||||
int StatModel::getVarCount() const { return 0; }
|
||||
|
||||
bool StatModel::train( const Ptr<TrainData>&, int )
|
||||
bool StatModel::train(const Ptr<TrainData>& trainData, int )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
CV_Error(CV_StsNotImplemented, "");
|
||||
return false;
|
||||
}
|
||||
@@ -69,6 +70,7 @@ bool StatModel::train( const Ptr<TrainData>&, int )
|
||||
bool StatModel::train( InputArray samples, int layout, InputArray responses )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!samples.empty());
|
||||
return train(TrainData::create(samples, layout, responses));
|
||||
}
|
||||
|
||||
@@ -134,6 +136,7 @@ public:
|
||||
float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray _resp) const
|
||||
{
|
||||
CV_TRACE_FUNCTION_SKIP_NESTED();
|
||||
CV_Assert(!data.empty());
|
||||
Mat samples = data->getSamples();
|
||||
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
|
||||
Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights();
|
||||
|
||||
@@ -73,6 +73,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int flags )
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
|
||||
Mat new_responses;
|
||||
data->getTrainResponses().convertTo(new_responses, CV_32F);
|
||||
@@ -494,6 +495,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
return impl->train(data, flags);
|
||||
}
|
||||
|
||||
|
||||
@@ -142,12 +142,10 @@ Ptr<LogisticRegression> LogisticRegression::load(const String& filepath, const S
|
||||
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
{
|
||||
CV_TRACE_FUNCTION_SKIP_NESTED();
|
||||
CV_Assert(!trainData.empty());
|
||||
|
||||
// return value
|
||||
bool ok = false;
|
||||
|
||||
if (trainData.empty()) {
|
||||
return false;
|
||||
}
|
||||
clear();
|
||||
Mat _data_i = trainData->getSamples();
|
||||
Mat _labels_i = trainData->getResponses();
|
||||
|
||||
@@ -54,6 +54,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
const float min_variation = FLT_EPSILON;
|
||||
Mat responses = trainData->getNormCatResponses();
|
||||
Mat __cls_labels = trainData->getClassLabels();
|
||||
|
||||
@@ -111,6 +111,7 @@ public:
|
||||
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
DTreesImpl::startTraining(trainData, flags);
|
||||
int nvars = w->data->getNVars();
|
||||
int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
|
||||
@@ -133,6 +134,7 @@ public:
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
|
||||
rparams.termCrit.maxCount : 10000;
|
||||
@@ -463,6 +465,7 @@ public:
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
if (impl.getCVFolds() != 0)
|
||||
CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
|
||||
return impl.train(trainData, flags);
|
||||
|
||||
@@ -1613,6 +1613,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
|
||||
checkParams();
|
||||
@@ -1739,6 +1740,7 @@ public:
|
||||
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
|
||||
bool balanced ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
checkParams();
|
||||
|
||||
int svmType = params.svmType;
|
||||
|
||||
@@ -230,6 +230,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
CV_Assert( isClassifier() ); //toDo: consider
|
||||
|
||||
|
||||
@@ -98,6 +98,7 @@ DTrees::Split::Split()
|
||||
|
||||
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
|
||||
{
|
||||
CV_Assert(!_data.empty());
|
||||
data = _data;
|
||||
vector<int> subsampleIdx;
|
||||
Mat sidx0 = _data->getTrainSampleIdx();
|
||||
@@ -136,6 +137,7 @@ void DTreesImpl::clear()
|
||||
|
||||
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
w = makePtr<WorkData>(data);
|
||||
|
||||
@@ -223,6 +225,7 @@ void DTreesImpl::endTraining()
|
||||
|
||||
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
bool ok = addTree( w->sidx ) >= 0;
|
||||
w.release();
|
||||
|
||||
Reference in New Issue
Block a user