-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor the SMC kernels #279
Conversation
A few thoughts:
Tentative design for SMC with MCMC stepsimport jax
import blackjax
logprob_fn: Callable
mcmc_init: Callable
mcmc_step: Callable
num_mcmc_steps: int
mcmc_parameters: Dict
def update_particle(rng_key, position):
def one_step(state, rng_key):
# This can contain *anything* not
# just a MCMC kernel
#
# We can even do adaptation here if we
# plug-in the states correctly between here and SMC
# particles would need to be `state`
# and we'd need to pass info
state, _ = mcmc_step(rng_key, state, logprob_fn, **mcmc_parameters)
return state, state
keys = jax.random.split(rng_key, num_mcmc_steps)
state = mcmc_init(position, logprob_fn)
last_state, states = jax.lax.scan(one_step, state, keys)
# can be a while loop 🔁 for adaptive schemes
return last_state.position
# Waste-free version 🚯
# return states.position
# Can be nicely combined with progressive HMC sampling
# This could also be a SMC step 🙃
update = jax.vmap(update_particle)
weigh = jax.vmap(logprob_fn)
resample = blackjax.smc.resampling.stratified
new_particles, info = smc.step(
rng_key,
particles,
update,
weigh,
resample,
) And How general is this design?
|
Codecov Report
@@ Coverage Diff @@
## main #279 +/- ##
==========================================
- Coverage 99.16% 99.16% -0.01%
==========================================
Files 48 48
Lines 1923 1919 -4
==========================================
- Hits 1907 1903 -4
Misses 16 16
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
242102b
to
79c4730
Compare
@AdrienCorenflos could you please take a look? (especially Left in this PR:
|
a46d762
to
412661b
Compare
I am about done with this. The new interface allows to work with the Waste-Free version of SMC, one just has to design the appropriate "update" function and pass the desired number of samples extracted during resampling. I added a test that demonstrates this. I am waiting for #441 to decide whether we move all the content of |
This is ready for review. We'll do the repo re-organisation in a separate PR so it is easily reversible. |
The number of particles we resample in the SMC step is currently equal to the number of weights passed to the resampling function. In this PR we allow the caller to ask for a different number of particles. This allows to build Waste-Free SMC kernels by asking to resample M < N particles and build the update function so that it returns N particles.
ac9f75d
to
c7e1f5e
Compare
There are a few things that I find unsatisfactory with the SMC base kernel:
kernel_factory
, which prevents the users from passing different parameters at different iterations. It forces adaptation to happen inside the SMC kernel, which is something we don't want. We should stick to the sample, then update the parameters paradigm sketched in Refactor the adaptation kernels #276I think this decomposition is what was conceptually missing to properly integrate #117
vmap
, but it should be possible to usepmap
as welljnum_mcmc_iterations
,mcmc_iter
, etc. naming is inconsistent.The implementation of the SMC base kernel should be based on the formalism exposed in this book.