ml/python: fix digits samples(3.4)
This commit is contained in:
@@ -70,13 +70,8 @@ def deskew(img):
|
||||
img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
|
||||
return img
|
||||
|
||||
class StatModel(object):
|
||||
def load(self, fn):
|
||||
self.model.load(fn) # Known bug: https://github.com/opencv/opencv/issues/4969
|
||||
def save(self, fn):
|
||||
self.model.save(fn)
|
||||
|
||||
class KNearest(StatModel):
|
||||
class KNearest(object):
|
||||
def __init__(self, k = 3):
|
||||
self.k = k
|
||||
self.model = cv.ml.KNearest_create()
|
||||
@@ -88,7 +83,13 @@ class KNearest(StatModel):
|
||||
_retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
|
||||
return results.ravel()
|
||||
|
||||
class SVM(StatModel):
|
||||
def load(self, fn):
|
||||
self.model = cv.ml.KNearest_load(fn)
|
||||
|
||||
def save(self, fn):
|
||||
self.model.save(fn)
|
||||
|
||||
class SVM(object):
|
||||
def __init__(self, C = 1, gamma = 0.5):
|
||||
self.model = cv.ml.SVM_create()
|
||||
self.model.setGamma(gamma)
|
||||
@@ -102,6 +103,11 @@ class SVM(StatModel):
|
||||
def predict(self, samples):
|
||||
return self.model.predict(samples)[1].ravel()
|
||||
|
||||
def load(self, fn):
|
||||
self.model = cv.ml.SVM_load(fn)
|
||||
|
||||
def save(self, fn):
|
||||
self.model.save(fn)
|
||||
|
||||
def evaluate_model(model, digits, samples, labels):
|
||||
resp = model.predict(samples)
|
||||
|
||||
Reference in New Issue
Block a user