diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 8194be6a1204..ab53e7733066 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -45,10 +45,12 @@ namespace mshadow_op {
 __constant__ const float PI = 3.14159265358979323846;
 __constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
 __constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
+__constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
 #else
 const float PI = 3.14159265358979323846;
 const float SELU_ALPHA = 1.6732632423543772848170429916717;
 const float SELU_LAMBDA = 1.0507009873554804934193349852946;
+const float SQRT_2 = 1.4142135623730950488016887242096;
 using std::isnan;
 #endif
 using std::enable_if;
@@ -173,11 +175,11 @@ MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));
 MXNET_SIMPLE_UNARY_MATH_OP(erf);
 
 MXNET_UNARY_MATH_OP(gelu,
-                    DType(0.5f * float(a) * (1.0f + math::erf(float(a) / math::sqrt(2.0f)))));
+                    DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2))));
 
 MXNET_BINARY_MATH_OP_NC(gelu_grad,
-                        DType(float(b) / float(a) +
-                              0.5f * float(a) * erf_grad::Map(float(a) / math::sqrt(2.0f)) / math::sqrt(2.0f)));
+                        DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) +
+                                      static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2)));
 
 MXNET_SIMPLE_UNARY_MATH_OP(exp);
 