Merge pull request #17349 from YashasSamaga:cuda4dnn-general-fixes
This commit is contained in:
commit
6b0fff72d9
@ -52,22 +52,30 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
}
|
||||
}
|
||||
|
||||
/** noncopyable cuBLAS smart handle
|
||||
/** non-copyable cuBLAS smart handle
|
||||
*
|
||||
* UniqueHandle is a smart non-sharable wrapper for cuBLAS handle which ensures that the handle
|
||||
* is destroyed after use. The handle can be associated with a CUDA stream by specifying the
|
||||
* stream during construction. By default, the handle is associated with the default stream.
|
||||
* is destroyed after use. The handle must always be associated with a non-default stream. The stream
|
||||
* must be specified during construction.
|
||||
*
|
||||
* Refer to the stream API documentation for the rationale behind forcing non-default streams.
|
||||
*/
|
||||
class UniqueHandle {
|
||||
public:
|
||||
UniqueHandle() { CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle)); }
|
||||
UniqueHandle() noexcept : handle{ nullptr } { }
|
||||
UniqueHandle(UniqueHandle&) = delete;
|
||||
UniqueHandle(UniqueHandle&& other) noexcept
|
||||
: stream(std::move(other.stream)), handle{ other.handle } {
|
||||
UniqueHandle(UniqueHandle&& other) noexcept {
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
}
|
||||
|
||||
/** creates a cuBLAS handle and associates it with the stream specified
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
UniqueHandle(Stream strm) : stream(std::move(strm)) {
|
||||
CV_Assert(stream);
|
||||
CUDA4DNN_CHECK_CUBLAS(cublasCreate(&handle));
|
||||
try {
|
||||
CUDA4DNN_CHECK_CUBLAS(cublasSetStream(handle, stream.get()));
|
||||
@ -79,7 +87,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
}
|
||||
|
||||
~UniqueHandle() noexcept {
|
||||
if (handle != nullptr) {
|
||||
if (handle) {
|
||||
/* cublasDestroy won't throw if a valid handle is passed */
|
||||
CUDA4DNN_CHECK_CUBLAS(cublasDestroy(handle));
|
||||
}
|
||||
@ -87,14 +95,24 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
|
||||
UniqueHandle& operator=(const UniqueHandle&) = delete;
|
||||
UniqueHandle& operator=(UniqueHandle&& other) noexcept {
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
CV_Assert(other);
|
||||
if (&other != this) {
|
||||
UniqueHandle(std::move(*this)); /* destroy current handle */
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** @brief returns the raw cuBLAS handle */
|
||||
cublasHandle_t get() const noexcept { return handle; }
|
||||
/** returns the raw cuBLAS handle */
|
||||
cublasHandle_t get() const noexcept {
|
||||
CV_Assert(handle);
|
||||
return handle;
|
||||
}
|
||||
|
||||
/** returns true if the handle is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(handle); }
|
||||
|
||||
private:
|
||||
Stream stream;
|
||||
@ -104,17 +122,21 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
/** @brief sharable cuBLAS smart handle
|
||||
*
|
||||
* Handle is a smart sharable wrapper for cuBLAS handle which ensures that the handle
|
||||
* is destroyed after all references to the handle are destroyed. The handle can be
|
||||
* associated with a CUDA stream by specifying the stream during construction. By default,
|
||||
* the handle is associated with the default stream.
|
||||
* is destroyed after all references to the handle are destroyed. The handle must always
|
||||
* be associated with a non-default stream. The stream must be specified during construction.
|
||||
*
|
||||
* @note Moving a Handle object to another invalidates the former
|
||||
*/
|
||||
class Handle {
|
||||
public:
|
||||
Handle() : handle(std::make_shared<UniqueHandle>()) { }
|
||||
Handle() = default;
|
||||
Handle(const Handle&) = default;
|
||||
Handle(Handle&&) = default;
|
||||
|
||||
/** creates a cuBLAS handle and associates it with the stream specified
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
Handle(Stream strm) : handle(std::make_shared<UniqueHandle>(std::move(strm))) { }
|
||||
|
||||
Handle& operator=(const Handle&) = default;
|
||||
@ -123,6 +145,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
/** returns true if the handle is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(handle); }
|
||||
|
||||
/** returns the raw cuBLAS handle */
|
||||
cublasHandle_t get() const noexcept {
|
||||
CV_Assert(handle);
|
||||
return handle->get();
|
||||
|
||||
@ -58,15 +58,11 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
*/
|
||||
class UniqueHandle {
|
||||
public:
|
||||
/** creates a cuDNN handle which executes in the default stream
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
UniqueHandle() { CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle)); }
|
||||
|
||||
UniqueHandle() noexcept : handle{ nullptr } { }
|
||||
UniqueHandle(UniqueHandle&) = delete;
|
||||
UniqueHandle(UniqueHandle&& other) noexcept
|
||||
: stream(std::move(other.stream)), handle{ other.handle } {
|
||||
UniqueHandle(UniqueHandle&& other) noexcept {
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
}
|
||||
|
||||
@ -75,6 +71,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
UniqueHandle(Stream strm) : stream(std::move(strm)) {
|
||||
CV_Assert(stream);
|
||||
CUDA4DNN_CHECK_CUDNN(cudnnCreate(&handle));
|
||||
try {
|
||||
CUDA4DNN_CHECK_CUDNN(cudnnSetStream(handle, stream.get()));
|
||||
@ -94,14 +91,24 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
|
||||
UniqueHandle& operator=(const UniqueHandle&) = delete;
|
||||
UniqueHandle& operator=(UniqueHandle&& other) noexcept {
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
CV_Assert(other);
|
||||
if (&other != this) {
|
||||
UniqueHandle(std::move(*this)); /* destroy current handle */
|
||||
stream = std::move(other.stream);
|
||||
handle = other.handle;
|
||||
other.handle = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** returns the raw cuDNN handle */
|
||||
cudnnHandle_t get() const noexcept { return handle; }
|
||||
cudnnHandle_t get() const noexcept {
|
||||
CV_Assert(handle);
|
||||
return handle;
|
||||
}
|
||||
|
||||
/** returns true if the handle is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(handle); }
|
||||
|
||||
private:
|
||||
Stream stream;
|
||||
@ -111,18 +118,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
/** @brief sharable cuDNN smart handle
|
||||
*
|
||||
* Handle is a smart sharable wrapper for cuDNN handle which ensures that the handle
|
||||
* is destroyed after all references to the handle are destroyed.
|
||||
* is destroyed after all references to the handle are destroyed. The handle must always
|
||||
* be associated with a non-default stream. The stream must be specified during construction.
|
||||
*
|
||||
* @note Moving a Handle object to another invalidates the former
|
||||
*/
|
||||
class Handle {
|
||||
public:
|
||||
/** creates a cuDNN handle which executes in the default stream
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
Handle() : handle(std::make_shared<UniqueHandle>()) { }
|
||||
|
||||
Handle() = default;
|
||||
Handle(const Handle&) = default;
|
||||
Handle(Handle&&) = default;
|
||||
|
||||
@ -138,6 +141,7 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
/** returns true if the handle is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(handle); }
|
||||
|
||||
/** returns the raw cuDNN handle */
|
||||
cudnnHandle_t get() const noexcept {
|
||||
CV_Assert(handle);
|
||||
return handle->get();
|
||||
|
||||
@ -18,11 +18,20 @@
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
|
||||
/** @brief noncopyable smart CUDA stream
|
||||
/** \file stream.hpp
|
||||
*
|
||||
* Default streams are not supported as they limit flexibility. All operations are always
|
||||
* carried out in non-default streams in the CUDA backend. The stream classes sacrifice
|
||||
* the ability to support default streams in exchange for better error detection. That is,
|
||||
* a default constructed stream represents no stream and any attempt to use it will throw an
|
||||
* exception.
|
||||
*/
|
||||
|
||||
/** @brief non-copyable smart CUDA stream
|
||||
*
|
||||
* UniqueStream is a smart non-sharable wrapper for CUDA stream handle which ensures that
|
||||
* the handle is destroyed after use. Unless explicitly specified by a constructor argument,
|
||||
* the stream object represents the default stream.
|
||||
* the stream object does not represent any stream by default.
|
||||
*/
|
||||
class UniqueStream {
|
||||
public:
|
||||
@ -33,14 +42,19 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
other.stream = 0;
|
||||
}
|
||||
|
||||
/** creates a non-default stream if `create` is true; otherwise, no stream is created */
|
||||
UniqueStream(bool create) : stream{ 0 } {
|
||||
if (create) {
|
||||
/* we create non-blocking streams to avoid interruptions from users using the default stream */
|
||||
CUDA4DNN_CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
}
|
||||
}
|
||||
|
||||
~UniqueStream() {
|
||||
try {
|
||||
/* cudaStreamDestroy does not throw if a valid stream is passed unless a previous
|
||||
* asynchronous operation errored.
|
||||
*/
|
||||
if (stream != 0)
|
||||
CUDA4DNN_CHECK_CUDA(cudaStreamDestroy(stream));
|
||||
} catch (const CUDAException& ex) {
|
||||
@ -54,16 +68,31 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
|
||||
UniqueStream& operator=(const UniqueStream&) = delete;
|
||||
UniqueStream& operator=(UniqueStream&& other) noexcept {
|
||||
stream = other.stream;
|
||||
other.stream = 0;
|
||||
CV_Assert(other);
|
||||
if (&other != this) {
|
||||
UniqueStream(std::move(*this)); /* destroy current stream */
|
||||
stream = other.stream;
|
||||
other.stream = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** returns the raw CUDA stream handle */
|
||||
cudaStream_t get() const noexcept { return stream; }
|
||||
cudaStream_t get() const noexcept {
|
||||
CV_Assert(stream);
|
||||
return stream;
|
||||
}
|
||||
|
||||
void synchronize() const { CUDA4DNN_CHECK_CUDA(cudaStreamSynchronize(stream)); }
|
||||
/** blocks the calling thread until all pending operations in the stream finish */
|
||||
void synchronize() const {
|
||||
CV_Assert(stream);
|
||||
CUDA4DNN_CHECK_CUDA(cudaStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
/** returns true if there are pending operations in the stream */
|
||||
bool busy() const {
|
||||
CV_Assert(stream);
|
||||
|
||||
auto status = cudaStreamQuery(stream);
|
||||
if (status == cudaErrorNotReady)
|
||||
return true;
|
||||
@ -71,6 +100,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
return false;
|
||||
}
|
||||
|
||||
/** returns true if the stream is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(stream); }
|
||||
|
||||
private:
|
||||
cudaStream_t stream;
|
||||
};
|
||||
@ -78,31 +110,42 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
/** @brief sharable smart CUDA stream
|
||||
*
|
||||
* Stream is a smart sharable wrapper for CUDA stream handle which ensures that
|
||||
* the handle is destroyed after use. Unless explicitly specified by a constructor argument,
|
||||
* the stream object represents the default stream.
|
||||
*
|
||||
* @note Moving a Stream object to another invalidates the former
|
||||
* the handle is destroyed after use. Unless explicitly specified in the constructor,
|
||||
* the stream object represents no stream.
|
||||
*/
|
||||
class Stream {
|
||||
public:
|
||||
Stream() : stream(std::make_shared<UniqueStream>()) { }
|
||||
Stream() { }
|
||||
Stream(const Stream&) = default;
|
||||
Stream(Stream&&) = default;
|
||||
|
||||
/** if \p create is `true`, a new stream will be created instead of the otherwise default stream */
|
||||
Stream(bool create) : stream(std::make_shared<UniqueStream>(create)) { }
|
||||
/** if \p create is `true`, a new stream will be created; otherwise, no stream is created */
|
||||
Stream(bool create) {
|
||||
if (create)
|
||||
stream = std::make_shared<UniqueStream>(create);
|
||||
}
|
||||
|
||||
Stream& operator=(const Stream&) = default;
|
||||
Stream& operator=(Stream&&) = default;
|
||||
|
||||
/** blocks the caller thread until all operations in the stream are complete */
|
||||
void synchronize() const { stream->synchronize(); }
|
||||
void synchronize() const {
|
||||
CV_Assert(stream);
|
||||
stream->synchronize();
|
||||
}
|
||||
|
||||
/** returns true if there are operations pending in the stream */
|
||||
bool busy() const { return stream->busy(); }
|
||||
bool busy() const {
|
||||
CV_Assert(stream);
|
||||
return stream->busy();
|
||||
}
|
||||
|
||||
/** returns true if the stream is valid */
|
||||
explicit operator bool() const noexcept { return static_cast<bool>(stream); }
|
||||
/** returns true if the object holds a valid stream */
|
||||
explicit operator bool() const noexcept {
|
||||
if (!stream)
|
||||
return false;
|
||||
return stream->operator bool();
|
||||
}
|
||||
|
||||
cudaStream_t get() const noexcept {
|
||||
CV_Assert(stream);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user