use fp32 mish for fp16 mish

This commit is contained in:
YashasSamaga 2020-06-22 19:09:36 +05:30
parent ee5ff71d55
commit 6573b9ace0

View File

@ -57,11 +57,19 @@ struct mish_functor<float> {
auto n = e * e + 2 * e;
if (value <= -0.6f)
return value * fast_divide(n, n + 2);
return value - 2 * fast_divide(value, n + 2);
}
};
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template <>
struct mish_functor<__half> {
__device__ __half operator()(__half value) {
return mish_functor<float>()(value);
}
};
#endif
template <class T>
struct sigmoid_functor {
__device__ T operator()(T value) {