Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call Eigen logistic function for inv_logit instead of custom implementation #3155

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from

Conversation

jachymb
Copy link

@jachymb jachymb commented Feb 27, 2025

My first time trying to contribute. I tried to resolve an old TODO - see #3154

Summary

Call Eigen logistic function for inv_logit instead of custom implementation.

I did not touch opencl/kernels/device_functions/inv_logit.hpp - not sure if it's necessary or how opencl works here.

Tests

Original tests pass. No new tests added.

Side Effects

I am not aware of any. There may perhaps be some numerical stability/precision differences, I didn't look into that.

Release notes

Call Eigen logistic function for inv_logit instead of custom implementation.

Checklist

  • Copyright holder: Me.

    The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
    - Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
    - Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

  • the basic tests are passing

    • unit tests pass (to run, use: ./runTests.py test/unit)
    • header checks pass, (make test-headers)
    • dependencies checks pass, (make test-math-dependencies)
    • docs build, (make doxygen)
    • code passes the built in C++ standards checks (make cpplint)
  • the code is written in idiomatic C++ and changes are documented in the doxygen

  • the new changes are tested

@WardBrian
Copy link
Member

WardBrian commented Feb 27, 2025

Hi @jachymb — I believe this TODO was suggesting we use the .logistic method on matrices in a specific overload for them, not for the implementation for double.

Right now, the prim version of this function just has two overloads, one for doubles and one for all containers that uses apply_scalar (which is, essentially, a for loop)

To complete the TODO, there would be 3 overloads:

double

prim matrices (using .logistic)

std::vectors (using apply_scalar like it currently would)

The hardest part of doing this is generally writing the template bounds so that the correct version is always selected. If you’re running into trouble with that, we can try to help!

@jachymb
Copy link
Author

jachymb commented Feb 27, 2025

Thanks for the commentary!

How about something like this replacing the current vectorized version then?

template <typename T, require_eigen_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
  return x.array().logistic();
}

template <typename T, require_std_vector_t<T>* = nullptr, require_not_eigen_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
  return apply_scalar_unary<inv_logit_fun, T>::apply(x);
}

it seems to work when I add tests for std::vector<double> and stan::math::matrix_d

does that look reasonable?

@WardBrian
Copy link
Member

That looks more-or-less correct. Because we want the autodiff overloads (which live in a different file) to also not intersect with these, I think you need to add a require_not_var_matrix_t to the eigen overload.

I think you might also still need to keep the require_all_not_nonscalar_prim_or_rev_kernel_expression_t on one of those, but I'm also not super familiar with the openCL details, so it's probably best we ask @SteveBronder on that one

Comment on lines 84 to 79
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
template <typename T, require_eigen_t<T>* = nullptr>
inline auto inv_logit(const T& x) {
return apply_scalar_unary<inv_logit_fun, T>::apply(x);
return x.array().logistic();
}
Copy link
Collaborator

@SteveBronder SteveBronder Feb 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So right now this function works over all containers that are not var<Matrix> or var<MatrixCL> types. If we want a version for just doubles I would add three signatures. Two in this file and one in stan/math/rev/fun/inv_logit.hpp

// Version that works on Eigen types which have an arithmetic value type
template <typename T, require_eigen_vt<std::is_arithmetic, T> * = nullptr>
inline auto inv_logit(T&& x);

// Version that works on std::vector types and does apply scalar unary
template <typename T, require_std_vector_t<T> * = nullptr>
inline auto inv_logit(T&& x);

Then in stan/math/rev/fun/inv_logit.hpp to add the signature

// Version that works with Eigen types with inner `var` value types
template <typename T, require_eigen_vt<is_var, T> * = nullptr>
inline auto inv_logit(T&& x);

You can look at the other function in rev/fun/inv_logit.hpp to see how to write functions with our autodiff scheme. We also have a guide for adding new functions at the link below

https://mc-stan.org/math/md_doxygen_2contributor__help__pages_2getting__started.html

@andrjohns
Copy link
Collaborator

We already have several functions using Eigen's vectorised calls that could be a good guide on approaching this

For example:

@jachymb
Copy link
Author

jachymb commented Feb 28, 2025

Added the signatures suggested by SteveBronder and WardBrian, but to be honest, at this point I don't feel very confident I actually know what I'm doing :( I think I may need to backtrack a few steps to study the code & theroy a bit more. If this isn't it, feel free to close the PR or reuse any of my code in some other way.

@SteveBronder
Copy link
Collaborator

This code is 90% of the way there! I'll look at this on Monday

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants