diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md index 6419c4ed4067..2df18c286ba7 100644 --- a/docs/api/python/ndarray/ndarray.md +++ b/docs/api/python/ndarray/ndarray.md @@ -659,6 +659,7 @@ The `ndarray` package provides several classes: relu sigmoid erf + erfinv ``` ### More diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md index 9eba2618065b..0fc2aa7c6cf2 100644 --- a/docs/api/python/symbol/symbol.md +++ b/docs/api/python/symbol/symbol.md @@ -659,6 +659,7 @@ Composite multiple symbols into a new one by an operator. relu sigmoid erf + erfinv ``` ### More diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h new file mode 100644 index 000000000000..8d718ade6562 --- /dev/null +++ b/src/operator/contrib/erfinv-inl.h @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2014 Indiana University + * All rights reserved. + * Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci., + * Indiana University, Bloomington, IN + * This software is licensed under the New BSD license: + * Redistribution and use in source and binary forms, + * with or without modification, are permitted provided + * that the following conditions are met: + * Redistributions of source code must retain the above + * copyright notice, this list of conditions and the + * following disclaimer. + * Redistributions in binary form must reproduce the + * above copyright notice, this list of conditions and + * the following disclaimer in the documentation and/or + * other materials provided with the distribution. + * Neither the name of Indiana University nor + * the names of its contributors may be used to endorse + * or promote products derived from this software without + * specific prior written permission. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND + * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL + * THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF + * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER + * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +/* + * The next function is taken from + * https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. + * Output was modified to be inf or -inf when input is 1 or -1. + */ +#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ +#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ + +#define _USE_MATH_DEFINES + +#include +#include +#include "math.h" + +namespace mxnet { +namespace op { +namespace mshadow_op { + +/*! \brief inverse gauss error function */ +struct erfinv : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType v) { + /* Function to calculate inverse error function. Rational approximation + is used to generate an initial approximation, which is then improved to + full accuracy by two steps of Newton's method. Code is a direct + translation of the erfinv m file in matlab version 2.0. + Author: Gary L. 
Pavlis, Indiana University + Date: February 1996 + */ + const double central_range = 0.7; + double y = static_cast(v); + double y_fab = std::fabs(y); + /*working variables */ + double x = 0.0; + double z, num, dem; + /* coefficients in rational expansion */ + double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; + double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + double d[2]={ 3.543889200, 1.637067800}; + if (y_fab > 1.0) { + /* This needs IEEE constant*/ + return DType(std::numeric_limits::quiet_NaN()); + } else if (y_fab == 1.0) { + return DType((std::copysign(1.0, y))*std::numeric_limits::infinity()); + } else if (y_fab <= central_range) { + z = y*y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); + x = y*num/dem; + } else { + z = std::sqrt(-std::log((1.0-y_fab)/2.0)); + num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; + dem = (d[1]*z + d[0])*z + 1.0; + x = (std::copysign(1.0, y))*num/dem; + } + /* Two steps of Newton-Raphson correction */ + x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x)); + x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x)); + + return DType(x); + } +}; + +} // namespace mshadow_op +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 0b20a02634c3..f56436b8fa0c 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -31,6 +31,7 @@ #include "math_functions-inl.h" #include "special_functions-inl.h" #include "./operator_tune.h" +#include "./contrib/erfinv-inl.h" #ifdef __CUDACC__ #include @@ -169,6 +170,8 @@ struct softrelu : public mxnet_op::tunable { MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a)); +MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * math::exp(math::sqr(erfinv::Map(a)))); + MXNET_UNARY_MATH_OP(erf_grad, 2.0 / 
math::sqrt(PI) * math::exp(-(a * a))); MXNET_SIMPLE_UNARY_MATH_OP(erf); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 2018e80cb48b..56d35b23b369 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -234,9 +234,11 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad); // NOLINT() -IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erf); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erf_grad); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erfinv); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erfinv_grad); // NOLINT() +IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh); // NOLINT() IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad); // NOLINT() diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index c0d420f9599b..d0079b545dd8 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -916,6 +916,22 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_erf) .set_attr("FCompute", ElemwiseBinaryOp::Compute>); +// erfinv +MXNET_OPERATOR_REGISTER_UNARY(erfinv) +.describe(R"code(Returns element-wise inverse gauss error function of the input. 
+ +Example:: + + erfinv([0, 0.5, -1.]) = [0., 0.4769, -inf] + +)code" ADD_FILELINE) +.set_attr("FCompute", UnaryOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_erfinv"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv) +.set_attr("FCompute", + ElemwiseBinaryOp::Compute>); + // rcbrt MXNET_OPERATOR_REGISTER_UNARY(rcbrt) .describe(R"code(Returns element-wise inverse cube-root value of the input. diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu index 14f2be02ab1a..642cb0e6e48b 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cu +++ b/src/operator/tensor/elemwise_unary_op_basic.cu @@ -62,6 +62,14 @@ NNVM_REGISTER_OP(_backward_erf) .set_attr("FCompute", ElemwiseBinaryOp::Compute>); +// erfinv +NNVM_REGISTER_OP(erfinv) +.set_attr("FCompute", UnaryOp::Compute); + +NNVM_REGISTER_OP(_backward_erfinv) +.set_attr("FCompute", + ElemwiseBinaryOp::Compute>); + // copy NNVM_REGISTER_OP(_copy) .set_attr("FCompute", UnaryOp::IdentityCompute) diff --git a/tests/nightly/apache_rat_license_check/rat-excludes b/tests/nightly/apache_rat_license_check/rat-excludes index 5969f01a3225..782ef40b7e35 100755 --- a/tests/nightly/apache_rat_license_check/rat-excludes +++ b/tests/nightly/apache_rat_license_check/rat-excludes @@ -35,6 +35,7 @@ _mask.pyx coco.py base.pyi special_functions-inl.h +erfinv-inl.h im2col.cuh im2col.h pool.h @@ -49,4 +50,4 @@ deformable_im2col.h REQUIRE include/* .*.iml -.*.json.ref \ No newline at end of file +.*.json.ref diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index cb19fd869d30..cda801c25bbb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3500,7 +3500,11 @@ def test_special_functions_using_scipy(): # erf mathematical_core("erf", lambda x: mx.sym.erf(x), lambda x: scipy_special.erf(x), - lambda x: 2.0 / 
math.sqrt(math.pi) * np.exp(-(x ** 2)), 0.5, 0.5) + + # erfinv + mathematical_core("erfinv", lambda x: mx.sym.erfinv(x), lambda x: scipy_special.erfinv(x), + lambda x: 0.5 * math.sqrt(math.pi) * np.exp(scipy_special.erfinv(x) ** 2), 0.5, 0.5) def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., grad_init=2.): diff --git a/tools/license_header.py b/tools/license_header.py index 199d56c7ee35..11cc92839993 100755 --- a/tools/license_header.py +++ b/tools/license_header.py @@ -84,6 +84,7 @@ 'src/operator/nn/im2col.cuh', # Licenses in headers + 'src/operator/contrib/erfinv-inl.h', 'docs/_static/searchtools_custom.js', 'docs/_static/js/clipboard.js', 'docs/_static/js/clipboard.min.js',