diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index d9d6151c06bf..ab53e7733066 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -45,14 +45,12 @@ namespace mshadow_op { __constant__ const float PI = 3.14159265358979323846; __constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717; __constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946; -__constant__ const float GELU_CUBIC_CONSTANT = 0.044715; -__constant__ const float GELU_ROOT_2_OVER_PI = 0.7978845608028654; +__constant__ const float SQRT_2 = 1.4142135623730950488016887242096; #else const float PI = 3.14159265358979323846; const float SELU_ALPHA = 1.6732632423543772848170429916717; const float SELU_LAMBDA = 1.0507009873554804934193349852946; -const float GELU_CUBIC_CONSTANT = 0.044715; -const float GELU_ROOT_2_OVER_PI = 0.7978845608028654; +const float SQRT_2 = 1.4142135623730950488016887242096; using std::isnan; #endif using std::enable_if; @@ -131,21 +129,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a))); MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a))); -#define MXNET_GELU_GX(a) \ - a * (DType(1.0f) + DType(GELU_CUBIC_CONSTANT) * a * a) - -#define MXNET_GELU_GX_GRAD(a) \ - (DType(1.0f) + DType(3.0f * GELU_CUBIC_CONSTANT) * a * a) - -#define MXNET_GELU_TANH(a) \ - math::tanh(DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX(a)) - -MXNET_UNARY_MATH_OP(gelu, DType(0.5f) * a * (DType(1.0f) + MXNET_GELU_TANH(a))); - -MXNET_BINARY_MATH_OP_NC(gelu_grad, - b / a + b * (DType(1.0f) - MXNET_GELU_TANH(a)) * - DType(GELU_ROOT_2_OVER_PI) * MXNET_GELU_GX_GRAD(a)); - MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) * (a > DType(0) ? 
a : DType(math::id(SELU_ALPHA) * math::expm1(a)))); @@ -191,6 +174,13 @@ MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a))); MXNET_SIMPLE_UNARY_MATH_OP(erf); +MXNET_UNARY_MATH_OP(gelu, + DType(0.5f * static_cast<float>(a) * (1.0f + math::erf(static_cast<float>(a) / SQRT_2)))); + +MXNET_BINARY_MATH_OP_NC(gelu_grad, + DType(0.5f * (1.0f + math::erf(static_cast<float>(a) / SQRT_2) + + static_cast<float>(a) * erf_grad::Map(static_cast<float>(a) / SQRT_2) / SQRT_2))); + MXNET_SIMPLE_UNARY_MATH_OP(exp); MXNET_SIMPLE_UNARY_MATH_OP(expm1); @@ -355,7 +345,6 @@ MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0)); MXNET_UNARY_MATH_OP(square_root, math::sqrt(a)); MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a)); - MXNET_UNARY_MATH_OP(reciprocal_square_root, 1.0f / math::sqrt(a)); MXNET_UNARY_MATH_OP(reciprocal_square_root_grad, -0.5f / (math::sqrt(a) * math::id(a))); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index eb10b3b2751e..ddcc881939ad 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -886,7 +886,7 @@ def fgelu_grad(grad, x, y): y = mx.sym.LeakyReLU(data=x, act_type="gelu") for dtype in [np.float16, np.float32, np.float64]: xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype) - eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4) + eps, rtol, atol = (7.5e-4, 2e-2, 1e-3) if dtype is np.float16 else (1e-4, 1e-3, 1e-5) if dtype is np.float16: xa /= 10.0 xa[abs(xa) < eps] = 0.01