Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement metric scaling #733

Merged
merged 12 commits into from
Sep 16, 2024
Merged

Conversation

AdrienCorenflos
Copy link
Contributor

@AdrienCorenflos AdrienCorenflos commented Sep 6, 2024

Thank you for opening a PR!

This is a change that relates to the PR #731.
The idea is to allow metrics to scale a vector (possibly momentum, possibly otherwise) by the cholesky of the mass matrix or its inverse.

DRAFT PR so far, taking comments on interface and tentative implementation. Nothing is tested yet.

A comment: for efficiency reasons (in the context of Riemannian metrics), I have allowed the scale function to take a sequence of elements to rescale, so as to avoid recomputing the mass matrix all the time. There may be a cleaner way to do this but I can't see it.

@AdrienCorenflos
Copy link
Contributor Author

@ismael-mendoza this is what you need right?

blackjax/mcmc/metrics.py Outdated Show resolved Hide resolved
@junpenglao
Copy link
Member

A comment: for efficiency reasons (in the context of Riemannian metrics), I have allowed the scale function to take a sequence of elements to rescale, so as to avoid recomputing the mass matrix all the time. There may be a cleaner way to do this but I can't see it.

I see... I dont love it but I also dont hate it LOL

@ismael-mendoza
Copy link
Contributor

Thanks @AdrienCorenflos I think this will work great, happy to incoporate this into #731 once it's merged!

@AdrienCorenflos AdrienCorenflos marked this pull request as ready for review September 7, 2024 12:55
@AdrienCorenflos
Copy link
Contributor Author

I am a bit unhappy with one design choice: for the function scale to mean the same thing for all metrics I had to treat inputs differently: this comes from the fact that euclidean metrics take an inverse mass_matrix, while Riemannian take the mass matrix... I'm not sure what the rationale was behind this, but we will need to rethink it down the line. Maybe Riemannian should also take the inverse mass matrix? Either way, not a problem for this PR.

@junpenglao
Copy link
Member

A likely over-engineered solution here is to create a new covariance matrix class that has both original and inverse form, kind of like what TFP does for Gaussian Process /~https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel.py

@AdrienCorenflos
Copy link
Contributor Author

AdrienCorenflos commented Sep 9, 2024

I changed the api a bit -> you now only give scale(position, vector, inv) rather than something complicated. This is because I remembered we are using JAX and I shouldn't reinvent the wheel: if people want several scalings they should use vmap and that's it. This will still only call mass_matrix (and cholesky etc) only once.

@AdrienCorenflos
Copy link
Contributor Author

@junpenglao any comments or go?

@junpenglao
Copy link
Member

Let me take another look tomorrow

blackjax/mcmc/metrics.py Outdated Show resolved Hide resolved
blackjax/mcmc/metrics.py Outdated Show resolved Hide resolved
blackjax/mcmc/metrics.py Show resolved Hide resolved
blackjax/mcmc/metrics.py Show resolved Hide resolved
@junpenglao junpenglao self-assigned this Sep 12, 2024
AdrienCorenflos and others added 4 commits September 16, 2024 20:02
Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
@AdrienCorenflos
Copy link
Contributor Author

That should be good now @junpenglao
Sorry about the hold up @ismael-mendoza but I think this is for a good reason :D

@junpenglao
Copy link
Member

Great stuff, thank you!

@junpenglao junpenglao merged commit e1d816a into blackjax-devs:main Sep 16, 2024
5 checks passed
aphc14 pushed a commit to aphc14/blackjax that referenced this pull request Sep 17, 2024
* Plotting BlackJAX with BlackJAX

* Plotting BlackJAX with BlackJAX

* Proposed implementation for metric scaling

* Add tests and fix some small typing issues raised by pre-commit.

* Fix remaining failing tests

* pre-commit run

* The original implementation was using upper cholesky, I was using lower.

* Fixing a bunch of tests

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>

* Update blackjax/mcmc/metrics.py

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>

* Merged comments from Junpeng

* Merged comments from Junpeng

---------

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants