From a8ff4a469cdcf14b7478168695df2f5fcd6dc5b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 7 Oct 2022 17:54:09 +0200 Subject: [PATCH] Refactor the base SMC kernel Base SMC is neatly divided in 3 steps: - particle update - particle weighting - resampling --- blackjax/smc/base.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 2e9289369..d649a5442 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -16,7 +16,7 @@ import jax import jax.numpy as jnp -from blackjax.types import PyTree +from blackjax.types import PRNGKey, PyTree class SMCInfo(NamedTuple): @@ -40,6 +40,41 @@ class SMCInfo(NamedTuple): log_likelihood_increment: float +def base( + rng_key: PRNGKey, + particles: PyTree, + update: Callable, + weigh: Callable, + resample: Callable, +): + """General SMC sampling step. + + rng_key + particles + update + weigh + resample + """ + + updating_key, resampling_key = jax.random.split(rng_key, 2) + + num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] + keys = jax.random.split(updating_key, num_particles) + particles, update_info = update(keys, particles) + + weights = weigh(particles) + weights, logp_increments = normalize(weights) + # Here normalize the weights and compute log_increments + + resampling_idx = resample(weights, resampling_key) + particles = jax.tree_map(lambda x: x[resampling_idx], particles) + + # class TemperedSMCInfo(NamedTuple): + # lambda: float + # smc_info: SMCInfo + return particles, SMCInfo(weights, resampling_idx, logp_increments, update_info) + + def kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable,