Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix #724

Merged
merged 4 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def step(iteration_state, weight_and_key):
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_avg), None
Expand Down Expand Up @@ -243,7 +242,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
if num_steps2 != 0.0:
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))
Expand All @@ -260,6 +259,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

jax.debug.print(
reubenharry marked this conversation as resolved.
Show resolved Hide resolved
"params {x}", x=MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
)
return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)

return L_step_size_adaptation
Expand Down Expand Up @@ -304,7 +306,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(next_state.position)
nonans = jnp.all(jnp.isfinite(p))
q, unravel_fn = ravel_pytree(next_state.momentum)
nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
state, step_size, kinetic_change = jax.tree_util.tree_map(
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(next_state, step_size_max, kinetic_change),
Expand Down
9 changes: 7 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info):
return SamplingAlgorithm(init_fn, update_fn), transform


def safediv(x, y):
junpenglao marked this conversation as resolved.
Show resolved Hide resolved
return jnp.where(x == 0.0, 0.0, x / y)


def incremental_value_update(
expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
Expand All @@ -302,8 +306,9 @@ def incremental_value_update(

total, average = incremental_val
average = tree_map(
lambda exp, av: (total * av + weight * exp)
/ (total + weight + zero_prevention),
lambda exp, av: safediv(
total * av + weight * exp, (total + weight + zero_prevention)
),
expectation,
average,
)
Expand Down
Loading