Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Riemannian Manifold HMC #538

Merged
merged 9 commits into from
Dec 14, 2023
Merged

Adding Riemannian Manifold HMC #538

merged 9 commits into from
Dec 14, 2023

Conversation

dfm
Copy link
Contributor

@dfm dfm commented May 28, 2023

Close #283.

As suggested by @junpenglao over in dfm/rmhmc#1, I've started the skeleton of a RMHMC implementation.

It doesn't check all the boxes yet, and I'll need to keep working on it, but I wanted to open the draft since I'm about to go on leave and might not get to it soon :D

There are various open interface question so it's probably not worth reviewing in too much detail as yet!

Refs:


A few important guidelines and requirements before we can merge your PR:

  • ish If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler* I added/updated related examples

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

@junpenglao
Copy link
Member

Thanks @dfm!!! Super excited for this!

@codecov
Copy link

codecov bot commented May 29, 2023

Codecov Report

Attention: 5 lines in your changes are missing coverage. Please review.

Comparison is base (4058971) 99.05% compared to head (62f9d24) 98.90%.

Files Patch % Lines
blackjax/mcmc/metrics.py 89.79% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #538      +/-   ##
==========================================
- Coverage   99.05%   98.90%   -0.15%     
==========================================
  Files          58       59       +1     
  Lines        2632     2738     +106     
==========================================
+ Hits         2607     2708     +101     
- Misses         25       30       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@krzysztofrusek
Copy link

Nice work!
Will this support implicitly defined manifolds?

I ported the algorithm from A Family of MCMC Methods on Implicitly Defined Manifolds to python using Jax and blackjax api with intention to make it a PR. Now I can see that I can build on top of this PR.

@junpenglao is this something of interest for blackjax devs?

@junpenglao
Copy link
Member

@junpenglao is this something of interest for blackjax devs?

Yes! Especially nice that you already see a path of building it on top of this PR!

@dfm
Copy link
Contributor Author

dfm commented Jun 1, 2023

I think this is probably more or less ready to review now. Some open questions/to-dos that could be implemented here or in future PRs:

  1. Since the fixed point solver in the integrator can fail, it would be good to track that metadata (reference TODO here) as well as the number of iterations that were required, and the error on the solution. This would require some refactoring of the trajectory building code, and some thought about how to combine the integrator info for multiple integration steps per proposal.
  2. I've implemented everything here such that the Riemannian metric should also work with the NUTS sampler, but the turning criterion needs access to the positions at the left and right end of the trajectory, as well as the momenta. This is a slightly larger refactoring, so I'd be inclined to skip it here, so I've left the turning criterion unimplemented for the new metric.
  3. Another stretch goal: It should also be fairly straightforward to implement general metrics for MALA, which seems to currently only support the identity metric.

But either way, I'd love to hear any feedback about this implementation at this point, including from @krzysztofrusek to get a sense for if there would be any interface changes that would make your future PR more straightforward!

@dfm dfm marked this pull request as ready for review June 1, 2023 15:50
@dfm dfm changed the title [WIP]: Adding Riemannian Manifold HMC Adding Riemannian Manifold HMC Jun 1, 2023
@krzysztofrusek
Copy link

But either way, I'd love to hear any feedback about this implementation at this point, including from @krzysztofrusek to get a sense for if there would be any interface changes that would make your future PR more straightforward!

Great abstraction, it will be super simple for me to use blackjax parts.
I started with a direct port of Matlab code from the paper and now I am splitting it into parts like integrator, and proposal.

I think I can implement it within your interface.
All I need is to pass the Jacobian of the constraint function to a few parts. Metric abstraction is nice and will serve me as well.

The additional constrained function is needed in

  • the momentum generator, so it will stay in the tangent space
  • Logdet to account for projection
  • the Hamiltonian - this is tricky because it's neither kinetic nor potential energy. I could add constrain as a part of potential but it would lose its physical interpretation so I need to think about how to handle this case.

@dfm
Copy link
Contributor Author

dfm commented Jun 2, 2023

the Hamiltonian - this is tricky because it's neither kinetic nor potential energy. I could add constrain as a part of potential but it would lose its physical interpretation so I need to think about how to handle this case.

Good point! In general, I think it might actually be better to abstract all the way to Hamiltonian rather than just the metric because it many cases where you might want to use RMHMC it would actually be more efficient to evaluate the log density and mass matrix together! Let me poke at that a little bit more.

dfm added 6 commits June 7, 2023 15:36
moving RMHMC to a separate submodule

fixing parallel tests and improving kinetic energy interface

Moving explicit leapfrog step to end of implicit midpoint

lint

fix explicit update; include logdet in kinetic energy

lint
@krzysztofrusek
Copy link

the Hamiltonian - this is tricky because it's neither kinetic nor potential energy. I could add constrain as a part of potential but it would lose its physical interpretation so I need to think about how to handle this case.

Good point! In general, I think it might actually be better to abstract all the way to Hamiltonian rather than just the metric because it many cases where you might want to use RMHMC it would actually be more efficient to evaluate the log density and mass matrix together! Let me poke at that a little bit more.

While waiting for this PR to be reviewed, I drafter CHMC in a separate repo.
The api is similar to the one from blackjax but slightly simplified, and closer to the original MATLAB implementation by @mbrubake .

@junpenglao
Copy link
Member

Hi @dfm, what is the status of the PR currently?

@dfm
Copy link
Contributor Author

dfm commented Sep 18, 2023

@junpenglao — I've been on parental leave and I'm just catching up, so I don't have much to add here at the moment. I'm happy to revisit this implementation in the coming months or happy to hand it off to someone who is keen to take it over!

@junpenglao
Copy link
Member

Thanks Dan! Let me take a look and get back to you.

@junpenglao
Copy link
Member

@dfm I updated the PR to the main branch, and it is good to go now. Do you have other things you would like to add?

cc @krzysztofrusek

@junpenglao junpenglao merged commit f12fc38 into blackjax-devs:main Dec 14, 2023
5 of 7 checks passed
@junpenglao
Copy link
Member

Alright merging this as it is good to go - will follow up for feature improvement in Discussion and Issue.

@junpenglao
Copy link
Member

Thank you for the contribution @dfm !!!

@dfm
Copy link
Contributor Author

dfm commented Dec 14, 2023

Thanks @junpenglao! Sorry for disappearing!!

junpenglao added a commit that referenced this pull request Mar 12, 2024
* Adding initial implementation of RMHMC

moving RMHMC to a separate submodule

fixing parallel tests and improving kinetic energy interface

Moving explicit leapfrog step to end of implicit midpoint

lint

fix explicit update; include logdet in kinetic energy

lint

* implementing untested rmhmc turning criterion

* implementing Metric type

* adding test for integrating non-separable potential

* add energy check in non-separable test

* add test for riemannian metric

* Fix typing

* fix test

---------

Co-authored-by: Junpeng Lao <junpenglao@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Riemannian HMC
4 participants