Refactor adaptation to return parameters and extra info
We currently only return the last state, the values of the parameters, and
the adapted kernel. However, the full chain and the intermediate adaptation
states can be useful when debugging inference.

In addition, adaptation currently returns a `kernel` whose parameters have
already been set. This is, however, a bit too high-level for Blackjax and
can make vmap-ing adaptation difficult.

Finally, MEADS is currently only implemented as an adaptation scheme for
GHMC, so we rename it to reflect this.

In this PR we make `window_adaptation`, `meads_adaptation` and
`pathfinder_adaptation` return this extra information, and no longer return
the kernel directly.
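
A minimal sketch of the resulting usage pattern, assuming the post-refactor API shown in the diff below (the toy log-density, `num_chains`, and the initial positions are illustrative):

```python
import jax
import jax.numpy as jnp
import blackjax

# Toy standard-normal target; any log-density function works here.
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)
initial_position = jnp.zeros(2)
warmup_key, sample_key = jax.random.split(jax.random.PRNGKey(0))

# `run` now returns (AdaptationResults, AdaptationInfo); the kernel is built
# explicitly from the adapted parameters instead of being returned directly.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
(state, parameters), info = warmup.run(warmup_key, initial_position, num_steps=1000)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
state, step_info = kernel(sample_key, state)

# Because only pytrees are returned, adaptation itself can be vmap-ed over chains.
num_chains = 4
chain_keys = jax.random.split(warmup_key, num_chains)
initial_positions = jnp.zeros((num_chains, 2))
states, tuned_parameters = jax.vmap(
    lambda key, position: warmup.run(key, position, num_steps=1000)[0]
)(chain_keys, initial_positions)
```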
rlouf committed Jan 17, 2023
1 parent 888d273 commit c0f9687
Showing 19 changed files with 97 additions and 87 deletions.
4 changes: 2 additions & 2 deletions blackjax/__init__.py
@@ -10,7 +10,7 @@
hmc,
irmh,
mala,
meads,
meads_adaptation,
meanfield_vi,
mgrad_gaussian,
nuts,
@@ -38,11 +38,11 @@
"irmh",
"elliptical_slice",
"ghmc",
"meads",
"sgld", # stochastic gradient mcmc
"sghmc",
"csgld",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"pathfinder_adaptation",
"adaptive_tempered_smc", # smc
"tempered_smc",
4 changes: 2 additions & 2 deletions blackjax/adaptation/__init__.py
@@ -1,3 +1,3 @@
from . import meads, pathfinder_adaptation, window_adaptation
from . import meads_adaptation, pathfinder_adaptation, window_adaptation

__all__ = ["meads", "window_adaptation", "pathfinder_adaptation"]
__all__ = ["meads_adaptation", "window_adaptation", "pathfinder_adaptation"]
File renamed without changes.
102 changes: 49 additions & 53 deletions blackjax/kernels.py
@@ -680,10 +680,15 @@ def step_fn(

class AdaptationResults(NamedTuple):
state: PyTree
kernel: Callable
parameters: dict


class AdaptationInfo(NamedTuple):
state: NamedTuple
info: NamedTuple
adaptation_state: NamedTuple


def window_adaptation(
algorithm: Union[hmc, nuts],
logdensity_fn: Callable,
@@ -731,9 +736,9 @@ def window_adaptation(
"""

step_fn = algorithm.kernel()
mcmc_step = algorithm.kernel()

init, update, final = adaptation.window_adaptation.base(
adapt_init, adapt_step, adapt_final = adaptation.window_adaptation.base(
is_mass_matrix_diagonal,
target_acceptance_rate=target_acceptance_rate,
)
@@ -742,15 +747,15 @@ def one_step(carry, xs):
_, rng_key, adaptation_stage = xs
state, adaptation_state = carry

new_state, info = step_fn(
new_state, info = mcmc_step(
rng_key,
state,
logdensity_fn,
adaptation_state.step_size,
adaptation_state.inverse_mass_matrix,
**extra_parameters,
)
new_adaptation_state = update(
new_adaptation_state = adapt_step(
adaptation_state,
adaptation_stage,
new_state.position,
@@ -759,13 +764,13 @@

return (
(new_state, new_adaptation_state),
(new_state, info, new_adaptation_state),
AdaptationInfo(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):

init_state = algorithm.init(position, logdensity_fn)
init_adaptation_state = init(position, initial_step_size)
init_adaptation_state = adapt_init(position, initial_step_size)

if progress_bar:
print("Running window adaptation")
@@ -775,29 +780,32 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 1000):

keys = jax.random.split(rng_key, num_steps)
schedule = adaptation.window_adaptation.schedule(num_steps)
last_state, adaptation_chain = jax.lax.scan(
last_state, info = jax.lax.scan(
one_step_,
(init_state, init_adaptation_state),
(jnp.arange(num_steps), keys, schedule),
)
last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = final(last_warmup_state)
step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
parameters = {
"step_size": step_size,
"inverse_mass_matrix": inverse_mass_matrix,
**extra_parameters,
}

def kernel(rng_key, state):
return step_fn(rng_key, state, logdensity_fn, **parameters)

return AdaptationResults(last_chain_state, kernel, parameters)
return (
AdaptationResults(
last_chain_state,
parameters,
),
info,
)

return AdaptationAlgorithm(run)


def meads(
def meads_adaptation(
logdensity_fn: Callable,
num_chains: int,
) -> AdaptationAlgorithm:
@@ -837,33 +845,32 @@ def meads(
"""

step_fn = ghmc.kernel()
ghmc_step = ghmc.kernel()

init, update = adaptation.meads.base()
adapt_init, adapt_update = adaptation.meads_adaptation.base()

batch_init = jax.vmap(lambda r, p: ghmc.init(r, p, logdensity_fn))

def one_step(carry, rng_key):
states, adaptation_state = carry

def kernel(rng_key, state):
return step_fn(
rng_key,
state,
logdensity_fn,
adaptation_state.step_size,
adaptation_state.position_sigma,
adaptation_state.alpha,
adaptation_state.delta,
)

keys = jax.random.split(rng_key, num_chains)
new_states, info = jax.vmap(kernel)(keys, states)
new_adaptation_state = update(
new_states, info = jax.vmap(
ghmc_step, in_axes=(0, 0, None, None, None, None, None)
)(
keys,
states,
logdensity_fn,
adaptation_state.step_size,
adaptation_state.position_sigma,
adaptation_state.alpha,
adaptation_state.delta,
)
new_adaptation_state = adapt_update(
adaptation_state, new_states.position, new_states.logdensity_grad
)

return (new_states, new_adaptation_state), (
return (new_states, new_adaptation_state), AdaptationInfo(
new_states,
info,
new_adaptation_state,
@@ -875,10 +882,10 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):

rng_keys = jax.random.split(key_init, num_chains)
init_states = batch_init(rng_keys, positions)
init_adaptation_state = init(positions, init_states.logdensity_grad)
init_adaptation_state = adapt_init(positions, init_states.logdensity_grad)

keys = jax.random.split(key_adapt, num_steps)
(last_states, last_adaptation_state), _ = jax.lax.scan(
(last_states, last_adaptation_state), info = jax.lax.scan(
one_step, (init_states, init_adaptation_state), keys
)

@@ -889,15 +896,7 @@ def run(rng_key: PRNGKey, positions: PyTree, num_steps: int = 1000):
"delta": last_adaptation_state.delta,
}

def kernel(rng_key, state):
return step_fn(
rng_key,
state,
logdensity_fn,
**parameters,
)

return AdaptationResults(last_states, kernel, parameters)
return AdaptationResults(last_states, parameters), info

return AdaptationAlgorithm(run) # type: ignore[arg-type]

@@ -1326,28 +1325,28 @@ def pathfinder_adaptation(
"""

step_fn = algorithm.kernel()
mcmc_step = algorithm.kernel()

init, update, final = adaptation.pathfinder_adaptation.base(
adapt_init, adapt_update, adapt_final = adaptation.pathfinder_adaptation.base(
target_acceptance_rate,
)

def one_step(carry, rng_key):
state, adaptation_state = carry
new_state, info = step_fn(
new_state, info = mcmc_step(
rng_key,
state,
logdensity_fn,
adaptation_state.step_size,
adaptation_state.inverse_mass_matrix,
**extra_parameters,
)
new_adaptation_state = update(
new_adaptation_state = adapt_update(
adaptation_state, new_state.position, info.acceptance_rate
)
return (
(new_state, new_adaptation_state),
(new_state, info, new_adaptation_state.ss_state),
AdaptationInfo(new_state, info, new_adaptation_state),
)

def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
@@ -1357,7 +1356,7 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
pathfinder_state, _ = vi.pathfinder.approximate(
init_key, logdensity_fn, position
)
init_warmup_state = init(
init_warmup_state = adapt_init(
pathfinder_state.alpha,
pathfinder_state.beta,
pathfinder_state.gamma,
@@ -1368,24 +1367,21 @@ def run(rng_key: PRNGKey, position: PyTree, num_steps: int = 400):
init_state = algorithm.init(init_position, logdensity_fn)

keys = jax.random.split(rng_key, num_steps)
last_state, warmup_chain = jax.lax.scan(
last_state, info = jax.lax.scan(
one_step,
(init_state, init_warmup_state),
keys,
)
last_chain_state, last_warmup_state = last_state

step_size, inverse_mass_matrix = final(last_warmup_state)
step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
parameters = {
"step_size": step_size,
"inverse_mass_matrix": inverse_mass_matrix,
**extra_parameters,
}

def kernel(rng_key, state):
return step_fn(rng_key, state, logdensity_fn, **parameters)

return AdaptationResults(last_chain_state, kernel, parameters)
return AdaptationResults(last_chain_state, parameters), info

return AdaptationAlgorithm(run)

2 changes: 1 addition & 1 deletion docs/adaptation.rst
@@ -20,4 +20,4 @@ MEADS

.. currentmodule:: blackjax

.. autoclass:: blackjax.meads
.. autoclass:: blackjax.meads_adaptation
3 changes: 2 additions & 1 deletion docs/examples/GP_EllipticalSliceSampler.md
@@ -135,7 +135,8 @@ n_iter = 2000
logdensity_fn = lambda f: loglikelihood_fn(f) - 0.5 * jnp.dot(f @ invSigma, f)
warmup = window_adaptation(nuts, logdensity_fn, n_warm, target_acceptance_rate=0.8)
key_warm, key_sample = jrnd.split(jrnd.PRNGKey(0))
state, kernel, _ = warmup.run(key_warm, f)
(state, parameters), _ = warmup.run(key_warm, f)
kernel = nuts(logdensity_fn, **parameters).step
states, _ = inference_loop(key_sample, state, kernel, n_iter)
```

3 changes: 2 additions & 1 deletion docs/examples/Introduction.md
@@ -166,14 +166,15 @@ The adaptation algorithm takes a function that returns a transition kernel given
%%time
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
state, kernel, _ = warmup.run(rng_key, initial_position, num_steps=1000)
(state, parameters), _ = warmup.run(rng_key, initial_position, num_steps=1000)
```

We can use the obtained parameters to define a new kernel. Note that we do not have to use the same kernel that was used for the adaptation:

```{code-cell} python
%%time
kernel = blackjax.nuts(logdensity_fn, **parameters)
states = inference_loop(rng_key, kernel, state, 1_000)
loc_samples = states.position["loc"].block_until_ready()
2 changes: 1 addition & 1 deletion docs/examples/Pathfinder.md
@@ -266,7 +266,7 @@ This scheme is implemented in `blackjax.kernel.pathfinder_adaptation` function:

```{code-cell} python
adapt = blackjax.kernels.pathfinder_adaptation(blackjax.nuts, logdensity_fn)
state, kernel, info = adapt.run(rng_key, w0, 400)
(state, parameters), info = adapt.run(rng_key, w0, 400)
```

## Some Caveats
5 changes: 3 additions & 2 deletions docs/examples/RegimeSwitchingModel.md
@@ -177,8 +177,9 @@ dist.initialize_model(kinit, n_chain)
```{code-cell} python
tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
warmup = blackjax.meads(dist.logdensity_fn, n_chain)
init_state, kernel, _ = warmup.run(k_warm, dist.init_params, n_warm)
warmup = blackjax.meads_adaptation(dist.logdensity_fn, n_chain)
(init_state, parameters), _ = warmup.run(k_warm, dist.init_params, n_warm)
kernel = blackjax.ghmc(dist.logdensity_fn, **parameters).step
def one_chain(k_sam, init_state):
state, info = inference_loop(k_sam, init_state, kernel, n_iter)
6 changes: 3 additions & 3 deletions docs/examples/SparseLogisticRegression.md
@@ -109,10 +109,10 @@ dist.initialize_model(kinit, n_chain)
tic1 = pd.Timestamp.now()
k_warm, k_sample = jrnd.split(ksam)
warmup = blackjax.meads(dist.logdensity_fn, n_chain)
adaptation_results = warmup.run(k_warm, dist.init_params, n_warm)
warmup = blackjax.meads_adaptation(dist.logdensity_fn, n_chain)
adaptation_results, _ = warmup.run(k_warm, dist.init_params, n_warm)
init_state = adaptation_results.state
kernel = adaptation_results.kernel
kernel = blackjax.ghmc(dist.logdensity_fn, **adaptation_results.parameters).step
def one_chain(k_sam, init_state):
6 changes: 3 additions & 3 deletions docs/examples/change_of_variable_hmc.md
@@ -301,7 +301,7 @@ init_params = jax.vmap(init_param_fn)(keys)
@jax.vmap
def call_warmup(seed, param):
initial_states, _, tuned_params = warmup.run(seed, param, 1000)
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
initial_states, tuned_params = jax.jit(call_warmup)(keys, init_params)
@@ -468,7 +468,7 @@ init_params = jax.vmap(init_param_fn)(keys)
@jax.vmap
def call_warmup(seed, param):
initial_states, _, tuned_params = warmup.run(seed, param, 1000)
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
initial_states, tuned_params = call_warmup(keys, init_params)
@@ -565,7 +565,7 @@ keys = jax.random.split(warmup_key, n_chains)
@jax.vmap
def call_warmup(seed, param):
initial_states, _, tuned_params = warmup.run(seed, param, 1000)
(initial_states, tuned_params), _ = warmup.run(seed, param, 1000)
return initial_states, tuned_params
initial_states, tuned_params = call_warmup(keys, init_params)
3 changes: 2 additions & 1 deletion docs/examples/howto_use_aesara.md
@@ -166,7 +166,8 @@ n_adapt = 3000
n_samples = 1000
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
state, kernel, _ = adapt.run(rng_key, init_position, n_adapt)
(state, parameters), _ = adapt.run(rng_key, init_position, n_adapt)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
states, infos = inference_loop(
rng_key, kernel, state, n_samples
3 changes: 2 additions & 1 deletion docs/examples/howto_use_numpyro.md
@@ -94,7 +94,8 @@ num_warmup = 2000
adapt = blackjax.window_adaptation(
blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
)
last_state, kernel, _ = adapt.run(rng_key, initial_position, num_warmup)
(last_state, parameters), _ = adapt.run(rng_key, initial_position, num_warmup)
kernel = blackjax.nuts(logdensity_fn, **parameters).step
```

Let us now perform inference with the tuned kernel: