Add tests for stateful kernel functionality
This commit is contained in:
parent
a3d6994afa
commit
bf54a370e5
@ -432,7 +432,7 @@ try:
|
|||||||
with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat())
|
with self.assertRaises(Exception): create_op([cv.GMat, int], [cv.GMat]).on(cv.GMat())
|
||||||
|
|
||||||
|
|
||||||
def test_stateful_kernel(self):
|
def test_state_in_class(self):
|
||||||
@cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int])
|
@cv.gapi.op('custom.sum', in_types=[cv.GArray.Int], out_types=[cv.GOpaque.Int])
|
||||||
class GSum:
|
class GSum:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
96
modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
Normal file
96
modules/gapi/misc/python/test/test_gapi_stateful_kernel.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2 as cv
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tests_common import NewOpenCVTests
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
if sys.version_info[:2] < (3, 0):
|
||||||
|
raise unittest.SkipTest('Python 2.x is not supported')
|
||||||
|
|
||||||
|
|
||||||
|
class CounterState:
|
||||||
|
def __init__(self):
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
@cv.gapi.op('stateful_counter',
|
||||||
|
in_types=[cv.GOpaque.Int],
|
||||||
|
out_types=[cv.GOpaque.Int])
|
||||||
|
class GStatefulCounter:
|
||||||
|
"""Accumulate state counter on every call"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def outMeta(desc):
|
||||||
|
return cv.empty_gopaque_desc()
|
||||||
|
|
||||||
|
|
||||||
|
@cv.gapi.kernel(GStatefulCounter)
|
||||||
|
class GStatefulCounterImpl:
|
||||||
|
"""Implementation for GStatefulCounter operation."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup(desc):
|
||||||
|
return CounterState()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(value, state):
|
||||||
|
state.counter += value
|
||||||
|
return state.counter
|
||||||
|
|
||||||
|
|
||||||
|
class gapi_sample_pipelines(NewOpenCVTests):
|
||||||
|
def test_stateful_kernel_single_instance(self):
|
||||||
|
g_in = cv.GOpaque.Int()
|
||||||
|
g_out = GStatefulCounter.on(g_in)
|
||||||
|
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out))
|
||||||
|
pkg = cv.gapi.kernels(GStatefulCounterImpl)
|
||||||
|
|
||||||
|
nums = [i for i in range(10)]
|
||||||
|
acc = 0
|
||||||
|
for v in nums:
|
||||||
|
acc = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
|
||||||
|
|
||||||
|
self.assertEqual(sum(nums), acc)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stateful_kernel_multiple_instances(self):
|
||||||
|
# NB: Every counter has his own independent state.
|
||||||
|
g_in = cv.GOpaque.Int()
|
||||||
|
g_out0 = GStatefulCounter.on(g_in)
|
||||||
|
g_out1 = GStatefulCounter.on(g_in)
|
||||||
|
comp = cv.GComputation(cv.GIn(g_in), cv.GOut(g_out0, g_out1))
|
||||||
|
pkg = cv.gapi.kernels(GStatefulCounterImpl)
|
||||||
|
|
||||||
|
nums = [i for i in range(10)]
|
||||||
|
acc0 = acc1 = 0
|
||||||
|
for v in nums:
|
||||||
|
acc0, acc1 = comp.apply(cv.gin(v), args=cv.gapi.compile_args(pkg))
|
||||||
|
|
||||||
|
ref = sum(nums)
|
||||||
|
self.assertEqual(ref, acc0)
|
||||||
|
self.assertEqual(ref, acc1)
|
||||||
|
|
||||||
|
|
||||||
|
except unittest.SkipTest as e:
|
||||||
|
|
||||||
|
message = str(e)
|
||||||
|
|
||||||
|
class TestSkip(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.skipTest('Skip tests: ' + message)
|
||||||
|
|
||||||
|
def test_skip():
|
||||||
|
pass
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
NewOpenCVTests.bootstrap()
|
||||||
Loading…
Reference in New Issue
Block a user