Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call Eigen logistic function for inv_logit instead of custom implementation #3155

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
53 changes: 40 additions & 13 deletions stan/math/prim/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,7 @@ namespace math {
* @return Inverse logit of argument.
*/
inline double inv_logit(double a) {
using std::exp;
if (a < 0) {
double exp_a = exp(a);
if (a < LOG_EPSILON) {
return exp_a;
}
return exp_a / (1 + exp_a);
}
return inv(1 + exp(-a));
return Eigen::internal::scalar_logistic_op<double>()(a);
}

/**
Expand All @@ -75,21 +67,56 @@ struct inv_logit_fun {
};

/**
* Vectorized version of inv_logit().
* Vectorized version of inv_logit() for matrices.
*
* @tparam T type of container
* @param x container
* @return Inverse logit applied to each value in x.
*/
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
typename T, require_eigen_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
return x.array().logistic();
}

/**
* Vectorized version of inv_logit() for std::vector.
*
* @tparam T type of container
* @param x container
* @return Inverse logit applied to each value in x.
*/
template <typename T, require_std_vector_t<T>* = nullptr,
require_not_eigen_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
}

// TODO(Tadej): Eigen is introducing their implementation logistic() of this
// in 3.4. Use that once we switch to Eigen 3.4
/**
* Vectorized version of inv_logit() for Eigen types with arithmetic value type.
*
* @tparam T type of Eigen expression
* @param x Eigen expression
* @return Inverse logit applied to each value in x.
*/
template <typename T, require_eigen_vt<std::is_arithmetic, T>* = nullptr>
inline auto inv_logit(T&& x) {
return std::forward<T>(x).array().logistic();
}

/**
* Vectorized version of inv_logit() for std::vector.
*
* @tparam T type of std::vector
* @param x std::vector
* @return Inverse logit applied to each value in x.
*/
template <typename T, require_std_vector_t<T>* = nullptr>
inline auto inv_logit(T&& x) {
return apply_scalar_unary<inv_logit_fun, T>::apply(std::forward<T>(x));
}

} // namespace math
} // namespace stan
Expand Down
22 changes: 22 additions & 0 deletions stan/math/rev/fun/inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,28 @@ inline auto inv_logit(const var_value<T>& a) {
});
}

/**
* The inverse logit function for Eigen expressions with var value type.
*
* See inv_logit() for the double-based version.
*
* The derivative of inverse logit is
*
* \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 -
* \mbox{logit}^{-1}(x))\f$.
*
* @tparam T type of Eigen expression
* @param x Eigen expression
* @return Inverse logit of argument.
*/
template <typename T, require_eigen_vt<is_var, T>* = nullptr>
inline auto inv_logit(T&& x) {
return make_callback_var(inv_logit(value_of(x)), [x](auto& vi) mutable {
value_of(x).adj().array()
+= vi.adj().array() * vi.val().array() * (1.0 - vi.val().array());
});
}

} // namespace math
} // namespace stan
#endif