diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/activation.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/activation.hpp
new file mode 100644
index 0000000000..f73c1d40e8
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/csl/cudnn/activation.hpp
@@ -0,0 +1,80 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
+#define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP
+
+#include "cudnn.hpp"
+
+#include <cudnn.h>
+
+namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
+
+    class ActivationDescriptor {
+    public:
+        enum class ActivationType {
+            IDENTITY,
+            RELU,
+            CLIPPED_RELU,
+            TANH,
+            SIGMOID,
+            ELU
+        };
+
+        ActivationDescriptor() noexcept : descriptor{ nullptr } { }
+        ActivationDescriptor(const ActivationDescriptor&) = delete;
+        ActivationDescriptor(ActivationDescriptor&& other) noexcept
+            : descriptor{ other.descriptor } {
+            other.descriptor = nullptr;
+        }
+
+        /* `relu_ceiling_or_elu_alpha`:
+         * - `alpha` coefficient in ELU activation
+         * - `ceiling` for CLIPPED_RELU activation
+         */
+        ActivationDescriptor(ActivationType type, double relu_ceiling_or_elu_alpha = 0.0) {
+            CUDA4DNN_CHECK_CUDNN(cudnnCreateActivationDescriptor(&descriptor));
+            try {
+                const auto mode = [type] {
+                    switch (type) {
+                    case ActivationType::IDENTITY: return CUDNN_ACTIVATION_IDENTITY;
+                    case ActivationType::RELU: return CUDNN_ACTIVATION_RELU;
+                    case ActivationType::CLIPPED_RELU: return CUDNN_ACTIVATION_CLIPPED_RELU;
+                    case ActivationType::SIGMOID: return CUDNN_ACTIVATION_SIGMOID;
+                    case ActivationType::TANH: return CUDNN_ACTIVATION_TANH;
+                    case ActivationType::ELU: return CUDNN_ACTIVATION_ELU;
+                    }
+                    CV_Assert(0);
+                    return CUDNN_ACTIVATION_IDENTITY;
+                } ();
+
+                CUDA4DNN_CHECK_CUDNN(cudnnSetActivationDescriptor(descriptor, mode, CUDNN_NOT_PROPAGATE_NAN, relu_ceiling_or_elu_alpha));
+            } catch (...) {
+                /* cudnnDestroyActivationDescriptor will not fail for a valid descriptor object */
+                CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
+                throw;
+            }
+        }
+
+        ~ActivationDescriptor() noexcept {
+            if (descriptor != nullptr) {
+                /* cudnnDestroyActivationDescriptor will not fail for a valid descriptor object */
+                CUDA4DNN_CHECK_CUDNN(cudnnDestroyActivationDescriptor(descriptor));
+            }
+        }
+
+        ActivationDescriptor& operator=(const ActivationDescriptor&) = delete;
+        ActivationDescriptor& operator=(ActivationDescriptor&& other) noexcept {
+            descriptor = other.descriptor;
+            other.descriptor = nullptr;
+            return *this;
+        }
+
+        cudnnActivationDescriptor_t get() const noexcept { return descriptor; }
+
+    private:
+        cudnnActivationDescriptor_t descriptor;
+    };
+
+}}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
+
+#endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_ACTIVATION_HPP */
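For illustration, a minimal usage sketch of the RAII wrapper added above. The enclosing function `sketch` is hypothetical and not part of the patch; only the descriptor API shown is.

    #include <utility>

    #include "activation.hpp"

    using cv::dnn::cuda4dnn::csl::cudnn::ActivationDescriptor;

    void sketch() {
        /* for CLIPPED_RELU, the second argument is the ceiling: y = min(max(x, 0), 6) */
        ActivationDescriptor relu6(ActivationDescriptor::ActivationType::CLIPPED_RELU, 6.0);

        /* ownership moves; `relu6` is left holding nullptr, so its destructor is a no-op */
        ActivationDescriptor owner = std::move(relu6);

        /* non-owning raw handle to pass to cuDNN calls */
        cudnnActivationDescriptor_t raw = owner.get();
        (void)raw;
    }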
diff --git a/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
index d8ff498185..46463b6538 100644
--- a/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
+++ b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp
@@ -6,6 +6,7 @@
 #define OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP
 
 #include "cudnn.hpp"
+#include "activation.hpp"
 
 #include "../pointer.hpp"
 #include "../workspace.hpp"
@@ -405,6 +406,93 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cudnn {
         );
     }
 
+    /** @brief performs convolution, bias addition and activation simultaneously
+     *
+     * dstValue = act(alpha * conv(input) + bias)
+     *
+     * @tparam T convolution element type (must be `half` or `float`)
+     *
+     * @param handle valid cuDNN handle
+     * @param alpha convolution scale factor
+     * @param convDesc convolution description
+     * @param convAlgo algorithm to use for convolution
+     * @param workspace workspace memory which meets the requirements of \p convAlgo
+     * @param filterDesc filter descriptor
+     * @param[in] filterPtr pointer to device memory containing the filters
+     * @param inputDesc tensor descriptor describing the input
+     * @param[in] inputPtr pointer to input tensor in device memory
+     * @param biasDesc tensor descriptor describing the bias
+     * @param[in] biasPtr pointer to bias tensor in device memory
+     * @param actDesc activation descriptor
+     * @param outputDesc tensor descriptor describing the output
+     * @param[out] outputPtr pointer to output tensor in device memory
+     *
+     * Exception Guarantee: Basic
+     */
+    template <class T>
+    void convolve_with_bias_activation(
+        const Handle& handle,
+        T alpha,
+        const ConvolutionDescriptor<T>& convDesc,
+        const ConvolutionAlgorithm<T>& convAlgo,
+        WorkspaceInstance workspace,
+        const FilterDescriptor<T>& filterDesc,
+        DevicePtr<const T> filterPtr,
+        const TensorDescriptor<T>& inputDesc,
+        DevicePtr<const T> inputPtr,
+        const TensorDescriptor<T>& biasDesc,
+        DevicePtr<const T> biasPtr,
+        const ActivationDescriptor& actDesc,
+        const TensorDescriptor<T>& outputDesc,
+        DevicePtr<T> outputPtr)
+    {
+        CV_Assert(handle);
+
+        /* cuDNN computes act(alpha * conv(input) + alpha2 * z + bias); we pass the output
+         * tensor as `z` with `alpha2 = 0` to obtain act(alpha * conv(input) + bias)
+         */
+        T alpha2 = 0.0;
+        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
+            handle.get(),
+            &alpha, inputDesc.get(), inputPtr.get(),
+            filterDesc.get(), filterPtr.get(),
+            convDesc.get(), convAlgo.get(),
+            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
+            &alpha2, outputDesc.get(), outputPtr.get(),
+            biasDesc.get(), biasPtr.get(),
+            actDesc.get(),
+            outputDesc.get(), outputPtr.get()));
+    }
+
+    template <> inline
+    void convolve_with_bias_activation(
+        const Handle& handle,
+        half alpha,
+        const ConvolutionDescriptor<half>& convDesc,
+        const ConvolutionAlgorithm<half>& convAlgo,
+        WorkspaceInstance workspace,
+        const FilterDescriptor<half>& filterDesc,
+        DevicePtr<const half> filterPtr,
+        const TensorDescriptor<half>& inputDesc,
+        DevicePtr<const half> inputPtr,
+        const TensorDescriptor<half>& biasDesc,
+        DevicePtr<const half> biasPtr,
+        const ActivationDescriptor& actDesc,
+        const TensorDescriptor<half>& outputDesc,
+        DevicePtr<half> outputPtr)
+    {
+        CV_Assert(handle);
+
+        /* we specialize for fp16 as cuDNN expects the scaling factors to be `float` */
+        float alpha_ = alpha, alpha2 = 0.0;
+        CUDA4DNN_CHECK_CUDNN(cudnnConvolutionBiasActivationForward(
+            handle.get(),
+            &alpha_, inputDesc.get(), inputPtr.get(),
+            filterDesc.get(), filterPtr.get(),
+            convDesc.get(), convAlgo.get(),
+            static_cast<void*>(workspace.get()), workspace.size_in_bytes(),
+            &alpha2, outputDesc.get(), outputPtr.get(),
+            biasDesc.get(), biasPtr.get(),
+            actDesc.get(),
+            outputDesc.get(), outputPtr.get()));
+    }
+
 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cudnn */
 
 #endif /* OPENCV_DNN_CUDA4DNN_CSL_CUDNN_CONVOLUTION_HPP */
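As a semantic reference for the fused call above: illustrative host code, not part of the patch, assuming NCHW layout, a per-output-channel bias, and the plain ReLU that this patch actually fuses.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    /* naive reference: y[n][c][i] = max(alpha * conv_out[n][c][i] + bias[c], 0) */
    void fused_bias_relu_reference(std::vector<float>& conv_out, const std::vector<float>& bias,
                                   float alpha, std::size_t N, std::size_t C, std::size_t HW)
    {
        for (std::size_t n = 0; n < N; n++)
            for (std::size_t c = 0; c < C; c++)
                for (std::size_t i = 0; i < HW; i++)
                {
                    float& y = conv_out[(n * C + c) * HW + i];
                    y = std::max(alpha * y + bias[c], 0.0f);
                }
    }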
diff --git a/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
index fc29d9b121..efea967650 100644
--- a/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
+++ b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
@@ -135,17 +135,23 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
         using FilterDescriptor = cudnn::FilterDescriptor<T>;
         using ConvolutionDescriptor = cudnn::ConvolutionDescriptor<T>;
         using ConvolutionAlgorithm = cudnn::ConvolutionAlgorithm<T>;
+        using ActivationDescriptor = cudnn::ActivationDescriptor;
 
     public:
+        using ActivationType = ActivationDescriptor::ActivationType;
+
         struct params_type {
+            /* convolution */
             std::vector<std::size_t> input_shape;
             std::vector<std::size_t> filter_shape;
-
             std::vector<std::size_t> padding;
             std::vector<std::size_t> stride;
             std::vector<std::size_t> dilation;
-
             std::size_t groups;
+
+            /* bias and activation (only RELU supported) */
+            std::vector<std::size_t> bias_shape;
+            ActivationType activation_type; /* MUST be IDENTITY if there is no bias; MUST be RELU if there is a bias */
         };
 
         Convolution() = default;
@@ -158,6 +164,14 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             filterDesc = FilterDescriptor(params.filter_shape);
             convDesc = ConvolutionDescriptor(params.padding, params.stride, params.dilation, params.groups);
 
+            if (!params.bias_shape.empty()) {
+                CV_Assert(params.activation_type == ActivationType::RELU);
+                biasTensorDesc = TensorDescriptor(params.bias_shape);
+                activationDesc = ActivationDescriptor(params.activation_type, 0.0);
+            } else {
+                CV_Assert(params.activation_type == ActivationType::IDENTITY);
+            }
+
             std::vector<int> output_dims;
             getConvolutionForwardOutputDim(convDesc, filterDesc, inputTensorDesc, output_dims);
             outputTensorDesc = TensorDescriptor(output_dims);
@@ -182,12 +196,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             );
         }
 
+        void convolve_with_bias_activation(TensorSpan<T> output, TensorView<T> input, TensorView<T> filters, TensorView<T> bias, WorkspaceInstance scratchpad) {
+            cudnn::convolve_with_bias_activation<T>(
+                cudnnHandle,
+                1.0, convDesc, algo, scratchpad,
+                filterDesc, filters.get(),
+                inputTensorDesc, input.get(),
+                biasTensorDesc, bias.get(),
+                activationDesc,
+                outputTensorDesc, output.get()
+            );
+        }
+
     private:
         cudnn::Handle cudnnHandle;
         TensorDescriptor inputTensorDesc, outputTensorDesc;
         FilterDescriptor filterDesc;
         ConvolutionDescriptor convDesc;
         ConvolutionAlgorithm algo;
+        TensorDescriptor biasTensorDesc;
+        ActivationDescriptor activationDesc;
     };
 
     template <class T>
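A sketch of how a caller is expected to fill `params_type` for the fused path. The shapes are illustrative; the `float` instantiation and the handle/workspace plumbing are assumptions, not shown in the patch:

    using Convolution = cv::dnn::cuda4dnn::csl::Convolution<float>;

    Convolution::params_type params;
    params.input_shape  = { 1, 3, 224, 224 };  /* NCHW */
    params.filter_shape = { 64, 3, 7, 7 };     /* KCHW: 64 output feature maps */
    params.padding      = { 3, 3 };
    params.stride       = { 2, 2 };
    params.dilation     = { 1, 1 };
    params.groups       = 1;

    /* fused path: the bias must broadcast as 1xKx1x1 and the activation must be RELU;
     * leaving bias_shape empty with activation_type == IDENTITY selects the plain path
     */
    params.bias_shape      = { 1, 64, 1, 1 };
    params.activation_type = Convolution::ActivationType::RELU;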
diff --git a/modules/dnn/src/cuda4dnn/primitives/convolution.hpp b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
index 72a84deed7..0a0050bd85 100644
--- a/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
+++ b/modules/dnn/src/cuda4dnn/primitives/convolution.hpp
@@ -204,6 +204,21 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             params.dilation = dilations;
             params.groups = config.groups;
 
+            /* check if we can perform fused convolution using cuDNN */
+            params.activation_type = csl::Convolution<T>::ActivationType::IDENTITY;
+            fusion_location = InternalFusionLocation::NATIVE;
+            if (!biasTensor.empty() &&
+                biasTensor.size() == output_feature_maps && /* cuDNN requirement */
+                config.activation_type == ConvolutionConfiguration::ActivationType::RELU &&
+                config.relu_negative_slope == 0.0)
+            {
+                fusion_location = InternalFusionLocation::CUDNN;
+                auto bias_shape = std::vector<std::size_t>(rank, 1);
+                bias_shape[1] = output_feature_maps;
+                params.bias_shape = bias_shape;
+                params.activation_type = csl::Convolution<T>::ActivationType::RELU;
+            }
+
             convoluter = csl::Convolution<T>(cudnnHandle, params);
 
             activation = config.activation_type;
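The eligibility test above is deliberately narrow: cuDNN's CUDNN_ACTIVATION_RELU computes max(x, 0) and has no slope parameter, so leaky ReLU (a non-zero negative slope) must stay on the NATIVE path. A condensed, hypothetical restatement of the test, for clarity only:

    #include <cstddef>

    /* mirrors the constructor logic above; not part of the patch */
    bool can_fuse_with_cudnn(bool has_bias, std::size_t bias_size, std::size_t output_feature_maps,
                             bool is_relu, float relu_negative_slope)
    {
        return has_bias &&
               bias_size == output_feature_maps && /* bias must broadcast as 1xKx1x1 */
               is_relu &&
               relu_negative_slope == 0.0f;        /* CUDNN_ACTIVATION_RELU is max(x, 0) */
    }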
@@ -216,7 +231,8 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                 activation = ConvolutionConfiguration::ActivationType::IDENTITY;
 
             csl::WorkspaceBuilder builder;
-            if (!transformed_shape.empty()) {
+            if (!transformed_shape.empty())
+            {
                 auto& shape = transformed_shape;
                 auto sz = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int>());
                 builder.require<T>(sz);
@@ -248,65 +264,72 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
             auto output = output_wrapper->getSpan();
 
-            convoluter.convolve(output, input, filtersTensor, allocator.get_instance());
-            if (!biasTensor.empty())
+            if (fusion_location == InternalFusionLocation::CUDNN)
             {
-                std::size_t inner_size = output.size_range(2, output.rank());
-                switch(activation)
-                {
-                case ConvolutionConfiguration::ActivationType::IDENTITY:
-                    kernels::biasN<T>(stream, output, output, inner_size, biasTensor);
-                    break;
-                case ConvolutionConfiguration::ActivationType::RELU:
-                    kernels::biasN_relu_inplace<T>(stream, output, inner_size, biasTensor, relu_negative_slope);
-                    break;
-                case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
-                    kernels::biasN_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, crelu_floor, crelu_ceil);
-                    break;
-                case ConvolutionConfiguration::ActivationType::POWER:
-                    kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp);
-                    break;
-                case ConvolutionConfiguration::ActivationType::TANH:
-                    kernels::biasN_tanh_inplace<T>(stream, output, inner_size, biasTensor);
-                    break;
-                case ConvolutionConfiguration::ActivationType::SIGMOID:
-                    kernels::biasN_sigmoid_inplace<T>(stream, output, inner_size, biasTensor);
-                    break;
-                case ConvolutionConfiguration::ActivationType::SWISH:
-                    kernels::biasN_swish_inplace<T>(stream, output, inner_size, biasTensor);
-                    break;
-                case ConvolutionConfiguration::ActivationType::MISH:
-                    kernels::biasN_mish_inplace<T>(stream, output, inner_size, biasTensor);
-                    break;
-                }
+                convoluter.convolve_with_bias_activation(output, input, filtersTensor, biasTensor, allocator.get_instance());
             }
             else
             {
-                switch(activation)
+                convoluter.convolve(output, input, filtersTensor, allocator.get_instance());
+                if (!biasTensor.empty())
                 {
-                case ConvolutionConfiguration::ActivationType::IDENTITY:
-                    break;
-                case ConvolutionConfiguration::ActivationType::RELU:
-                    kernels::relu<T>(stream, output, output, relu_negative_slope);
-                    break;
-                case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
-                    kernels::clipped_relu<T>(stream, output, output, crelu_floor, crelu_ceil);
-                    break;
-                case ConvolutionConfiguration::ActivationType::POWER:
-                    kernels::power<T>(stream, output, output, power_exp, 1.0, 0.0);
-                    break;
-                case ConvolutionConfiguration::ActivationType::TANH:
-                    kernels::tanh<T>(stream, output, output);
-                    break;
-                case ConvolutionConfiguration::ActivationType::SIGMOID:
-                    kernels::sigmoid<T>(stream, output, output);
-                    break;
-                case ConvolutionConfiguration::ActivationType::SWISH:
-                    kernels::swish<T>(stream, output, output);
-                    break;
-                case ConvolutionConfiguration::ActivationType::MISH:
-                    kernels::mish<T>(stream, output, output);
-                    break;
+                    std::size_t inner_size = output.size_range(2, output.rank());
+                    switch(activation)
+                    {
+                    case ConvolutionConfiguration::ActivationType::IDENTITY:
+                        kernels::biasN<T>(stream, output, output, inner_size, biasTensor);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::RELU:
+                        kernels::biasN_relu_inplace<T>(stream, output, inner_size, biasTensor, relu_negative_slope);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
+                        kernels::biasN_clipped_relu_inplace<T>(stream, output, inner_size, biasTensor, crelu_floor, crelu_ceil);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::POWER:
+                        kernels::biasN_power_inplace<T>(stream, output, inner_size, biasTensor, power_exp);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::TANH:
+                        kernels::biasN_tanh_inplace<T>(stream, output, inner_size, biasTensor);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::SIGMOID:
+                        kernels::biasN_sigmoid_inplace<T>(stream, output, inner_size, biasTensor);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::SWISH:
+                        kernels::biasN_swish_inplace<T>(stream, output, inner_size, biasTensor);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::MISH:
+                        kernels::biasN_mish_inplace<T>(stream, output, inner_size, biasTensor);
+                        break;
+                    }
+                }
+                else
+                {
+                    switch(activation)
+                    {
+                    case ConvolutionConfiguration::ActivationType::IDENTITY:
+                        break;
+                    case ConvolutionConfiguration::ActivationType::RELU:
+                        kernels::relu<T>(stream, output, output, relu_negative_slope);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::CLIPPED_RELU:
+                        kernels::clipped_relu<T>(stream, output, output, crelu_floor, crelu_ceil);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::POWER:
+                        kernels::power<T>(stream, output, output, power_exp, 1.0, 0.0);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::TANH:
+                        kernels::tanh<T>(stream, output, output);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::SIGMOID:
+                        kernels::sigmoid<T>(stream, output, output);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::SWISH:
+                        kernels::swish<T>(stream, output, output);
+                        break;
+                    case ConvolutionConfiguration::ActivationType::MISH:
+                        kernels::mish<T>(stream, output, output);
+                        break;
+                    }
                 }
             }
         }
@@ -326,6 +349,11 @@ namespace cv { namespace dnn { namespace cuda4dnn {
 
         ConvolutionConfiguration::ActivationType activation;
         float relu_negative_slope, crelu_floor, crelu_ceil, power_exp;
+
+        enum class InternalFusionLocation {
+            CUDNN,
+            NATIVE
+        } fusion_location;
     };
 
 }}} /* namespace cv::dnn::cuda4dnn */
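To summarize the new dispatch in the forward pass, an illustrative condensation; the lambdas stand in for the real tensor and workspace arguments, and all names here are hypothetical:

    #include <functional>

    enum class FusionLocation { CUDNN, NATIVE };

    /* illustrative control flow of the forward pass after this patch */
    void forward_dispatch(FusionLocation fusion_location, bool has_bias,
                          const std::function<void()>& fused_conv_bias_relu,
                          const std::function<void()>& conv,
                          const std::function<void()>& bias_activation_epilogue,
                          const std::function<void()>& activation_epilogue)
    {
        if (fusion_location == FusionLocation::CUDNN) {
            fused_conv_bias_relu();          /* conv + bias + ReLU in a single cuDNN call */
        } else {
            conv();                          /* plain cuDNN convolution */
            if (has_bias)
                bias_activation_epilogue();  /* fused biasN_* kernels */
            else
                activation_epilogue();       /* standalone activation kernels */
        }
    }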