-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement metric scaling #733
Conversation
@ismael-mendoza this is what you need right? |
I see... I dont love it but I also dont hate it LOL |
Thanks @AdrienCorenflos I think this will work great, happy to incoporate this into #731 once it's merged! |
I am a bit unhappy with one design choice: for the function scale to mean the same thing for all metrics I had to treat inputs differently: this comes from the fact that euclidean metrics take an inverse mass_matrix, while Riemannian take the mass matrix... I'm not sure what the rationale was behind this, but we will need to rethink it down the line. Maybe Riemannian should also take the inverse mass matrix? Either way, not a problem for this PR. |
A likely over-engineered solution here is to create a new covariance matrix class that has both original and inverse form, kind of like what TFP does for Gaussian Process /~https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel.py |
I changed the api a bit -> you now only give scale(position, vector, inv) rather than something complicated. This is because I remembered we are using JAX and I shouldn't reinvent the wheel: if people want several scalings they should use vmap and that's it. This will still only call mass_matrix (and cholesky etc) only once. |
@junpenglao any comments or go? |
Let me take another look tomorrow |
Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
That should be good now @junpenglao |
Great stuff, thank you! |
* Plotting BlackJAX with BlackJAX * Plotting BlackJAX with BlackJAX * Proposed implementation for metric scaling * Add tests and fix some small typing issues raised by pre-commit. * Fix remaining failing tests * pre-commit run * The original implementation was using upper cholesky, I was using lower. * Fixing a bunch of tests * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao <junpenglao@gmail.com> * Update blackjax/mcmc/metrics.py Co-authored-by: Junpeng Lao <junpenglao@gmail.com> * Merged comments from Junpeng * Merged comments from Junpeng --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Thank you for opening a PR!
This is a change that relates to the PR #731.
The idea is to allow metrics to scale a vector (possibly momentum, possibly otherwise) by the cholesky of the mass matrix or its inverse.
DRAFT PR so far, taking comments on interface and tentative implementation. Nothing is tested yet.
A comment: for efficiency reasons (in the context of Riemannian metrics), I have allowed the scale function to take a sequence of elements to rescale, so as to avoid recomputing the mass matrix all the time. There may be a cleaner way to do this but I can't see it.