Skip to content

Commit

Permalink
Documentation clean up (#572)
Browse files Browse the repository at this point in the history
* documentation clean up

* formatting

* remove tensorflow-cpu
  • Loading branch information
junpenglao authored Sep 27, 2023
1 parent 8beec4d commit 51cf08a
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 245 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ import blackjax

observed = np.random.normal(10, 20, size=1_000)
def logdensity_fn(x):
logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
return jnp.sum(logpdf)
logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
Expand Down Expand Up @@ -136,11 +136,11 @@ To cite this repository:

```
@software{blackjax2020github,
author = {Lao, Junpeng and Louf, R\'emi},
author = {Cabezas, Alberto, Lao, Junpeng, and Louf, R\'emi},
title = {{B}lackjax: A sampling library for {JAX}},
url = {http://github.com/blackjax-devs/blackjax},
version = {<insert current release tag>},
year = {2020},
year = {2023},
}
```
In the above bibtex entry, names are in alphabetical order, the version number should be the last tag on the `main` branch.
Expand Down
36 changes: 22 additions & 14 deletions docs/examples/howto_custom_gradients.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down Expand Up @@ -37,8 +37,16 @@ $$

And define the function $f$ as $f(x) = -min_y g(x, y)$ which we can be implemented as:

```{code-cell} python
```{code-cell} ipython3
:tags: [remove-output]
import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

```{code-cell} ipython3
import jax.numpy as jnp
from jax.scipy.optimize import minimize
Expand All @@ -56,15 +64,14 @@ def f(x, p):
return -res.fun, res.x[0]
```


Note the we also return the value of $y$ where the minimum of $g$ is achieved (this will be useful later).


### Trying to differentate the function with `jax.grad`

The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:

```{code-cell} python
```{code-cell} ipython3
# We only want the gradient with respect to `x`
try:
jax.grad(f, has_aux=True)(0.5, 3)
Expand Down Expand Up @@ -105,7 +112,7 @@ i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum

We can thus now tell JAX to compute the derivative of the function using the argmin using `jax.custom_vjp`

```{code-cell} python
```{code-cell} ipython3
from functools import partial
Expand All @@ -128,7 +135,7 @@ def f_jac_vec_prod(p, primals, tangents):

Which now outputs a value:

```{code-cell} python
```{code-cell} ipython3
jax.grad(f_with_gradient)(0.31415, 3)
```

Expand All @@ -145,7 +152,7 @@ $$

Which is obviously differentiable. We implement it:

```{code-cell} python
```{code-cell} ipython3
def true_f(x, p):
q = 1 / (1 - 1 / p)
out = jnp.abs(x) ** q
Expand All @@ -156,8 +163,9 @@ print(jax.grad(true_f)(0.31415, 3))

And compare the gradient of this function with the custom gradient defined above:

```{code-cell} python
```{code-cell} ipython3
:tags: [hide-input]
print(f"Gradient of closed-form f: {jax.grad(true_f)(0.31415, 3)}")
print(f"Custom gradient based on argmin: {jax.grad(f_with_gradient)(0.31415, 3)}")
```
Expand All @@ -171,7 +179,7 @@ They give close enough values! In other words, it suffices to know that the valu

Let us now demonstrate that we can use `f_with_gradients` with Blackjax. We define a toy log-density function and use a gradient-based sampler:

```{code-cell} python
```{code-cell} ipython3
import blackjax
Expand All @@ -181,13 +189,13 @@ def logdensity_fn(y):
logdensity += jax.scipy.stats.norm.logpdf(x)
return logdensity
hmc = blackjax.hmc(logdensity_fn,1e-3, jnp.ones(1), 10)
hmc = blackjax.hmc(logdensity_fn,1e-2, jnp.ones(1), 20)
state = hmc.init(1.)
rng_key = jax.random.key(0)
new_state, info = hmc.step(rng_key, state)
rng_key, step_key = jax.random.split(rng_key)
new_state, info = hmc.step(step_key, state)
```

```{code-cell} python
new_state
```{code-cell} ipython3
state, new_state
```
31 changes: 15 additions & 16 deletions docs/examples/howto_metropolis_within_gibbs.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.14.7
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand All @@ -16,14 +16,16 @@ kernelspec:
Gibbs sampling is an MCMC technique where sampling from a joint probability distribution $\newcommand{\xx}{\boldsymbol{x}}\newcommand{\yy}{\boldsymbol{y}}p(\xx, \yy)$ is achieved by alternately sampling from $\xx \sim p(\xx \mid \yy)$ and $\yy \sim p(\yy \mid \xx)$. Ideally these conditional distributions can be sampled from analytically. In general however they must each be updated using any MCMC kernel appropriate to the conditional distribution at hand. This technique is referred to as Metropolis-within-Gibbs (MWG) sampling. The idea can be applied to an arbitrary number of blocks of variables $p(\xx_1, \ldots, \xx_n)$. For simplicity in this notebook we focus on a two-block example.

```{code-cell} ipython3
:tags: [remove-output]
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import blackjax
import pandas as pd
import seaborn as sns
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

## The Model
Expand Down Expand Up @@ -200,19 +202,17 @@ def sampling_loop(rng_key, initial_state, parameters, num_samples):

```{code-cell} ipython3
%%time
rng_key = jax.random.key(0)
positions = sampling_loop(rng_key, initial_state, parameters, 10_000)
rng_key, sample_key = jax.random.split(rng_key)
positions = sampling_loop(sample_key, initial_state, parameters, 10_000)
```

```{code-cell} ipython3
plt_data = pd.DataFrame({
"x1": positions["x"][:,0],
"x2": positions["x"][:,1],
"y1": positions["y"][:,0],
"y2": positions["y"][:,1]
})
sns.pairplot(plt_data, kind="hist")
import matplotlib.pyplot as plt
import arviz as az
idata = az.from_dict(posterior={k: v[None, ...] for k, v in positions.items()})
az.plot_pair(idata, kind='hexbin', marginals=True)
plt.tight_layout();
```

## General MWG Kernel
Expand Down Expand Up @@ -305,9 +305,8 @@ def sampling_loop_general(rng_key, initial_state, logdensity_fn, step_fn, init,

```{code-cell} ipython3
%%time
rng_key = jax.random.key(0)
positions_general = sampling_loop_general(
rng_key=rng_key,
rng_key=sample_key, # reuse PRNG key from above
initial_state=initial_state,
logdensity_fn=logdensity,
step_fn={
Expand All @@ -326,7 +325,7 @@ positions_general = sampling_loop_general(
### Check Result

```{code-cell} ipython3
{k: jnp.max(jnp.abs(positions[k] - positions_general[k])) for k in initial_state.keys()}
jax.tree_map(lambda x, y: jnp.max(jnp.abs(x-y)), positions, positions_general)
```

## Developer Notes
Expand Down
37 changes: 22 additions & 15 deletions docs/examples/howto_other_frameworks.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.13.1
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand All @@ -20,11 +20,19 @@ Nevertheless, you may have a good reason to use a function that is incompatible

In this example we will show you how this can be done using JAX's experimental `host_callback` API, and hint at a faster solution.

```{code-cell} ipython3
:tags: [remove-output]
import jax
from datetime import date
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

## Aesara model compiled to Numba

The following example builds a logdensity function with [Aesara](/~https://github.com/aesara-devs/aesara), compiles it with [Numba](https://numba.pydata.org/) and uses Blackjax to sample from the posterior distribution of the model.

```{code-cell} python
```{code-cell} ipython3
import aesara.tensor as at
import numpy as np
Expand All @@ -41,7 +49,7 @@ Y_rv = N_rv[I_rv]

We can sample from the prior predictive distribution to make sure the model is correctly implemented:

```{code-cell} python
```{code-cell} ipython3
import aesara
sampling_fn = aesara.function((), Y_rv)
Expand All @@ -51,7 +59,7 @@ print(sampling_fn())

We do not care about the posterior distribution of the indicator variable `I_rv` so we marginalize it out, and subsequently build the logdensity's graph:

```{code-cell} python
```{code-cell} ipython3
from aeppl import joint_logprob
y_vv = Y_rv.clone()
Expand All @@ -69,14 +77,14 @@ total_logdensity = at.logsumexp(at.log(weights) + logdensity)

We are now ready to compile the logdensity to Numba:

```{code-cell} python
```{code-cell} ipython3
logdensity_fn = aesara.function((y_vv,), total_logdensity, mode="NUMBA")
logdensity_fn(1.)
```

As is we cannot use these functions within jit-compiled functions written with JAX, or apply `jax.grad` to get the function's gradients:

```{code-cell} python
```{code-cell} ipython3
try:
jax.jit(logdensity_fn)(1.)
except Exception:
Expand All @@ -90,7 +98,7 @@ except Exception:

Indeed, a function written with Numba is incompatible with JAX's primitives. Luckily Aesara can build the model's gradient graph and compile it to Numba as well:

```{code-cell} python
```{code-cell} ipython3
total_logdensity_grad = at.grad(total_logdensity, y_vv)
logdensity_grad_fn = aesara.function((y_vv,), total_logdensity_grad, mode="NUMBA")
logdensity_grad_fn(1.)
Expand All @@ -100,8 +108,7 @@ logdensity_grad_fn(1.)

In order to be able to call `logdensity_fn` within JAX, we need to define a function that will call it via JAX's `host_callback`. Yet, this wrapper function is not differentiable with JAX, and so we will also need to define this functions' `custom_vjp`, and use `host_callback` to call the gradient-computing function as well:

```{code-cell} python
import jax
```{code-cell} ipython3
import jax.experimental.host_callback as hcb
@jax.custom_vjp
Expand All @@ -122,24 +129,24 @@ numba_logpdf.defvjp(vjp_fwd, vjp_bwd)

And we can now call the function from a jitted function and apply `jax.grad` without JAX complaining:

```{code-cell} python
```{code-cell} ipython3
:tags: [remove-stderr]
jax.jit(numba_logpdf)(1.), jax.grad(numba_logpdf)(1.)
```

And use Blackjax's NUTS sampler to sample from the model's posterior distribution:

```{code-cell} python
```{code-cell} ipython3
import blackjax
inverse_mass_matrix = np.ones(1)
step_size=1e-3
nuts = blackjax.nuts(numba_logpdf, step_size, inverse_mass_matrix)
init = nuts.init(0.)
rng_key = jax.random.key(0)
state, info = nuts.step(rng_key, init)
rng_key, init_key = jax.random.split(rng_key)
state, info = nuts.step(init_key, init)
for _ in range(10):
rng_key, nuts_key = jax.random.split(rng_key)
Expand All @@ -150,15 +157,15 @@ print(state)

If you run this on your machine you will notice that this runs quite slowly compared to a pure-JAX equivalent, that's because `host_callback` implied a lot of back-and-forth with Python. To see this let's compare execution times between *pure Numba on the one hand*:

```{code-cell} python
```{code-cell} ipython3
%%time
for _ in range(100_000):
logdensity_fn(100)
```

And *JAX on the other hand, with 100 times less iterations*:

```{code-cell} python
```{code-cell} ipython3
%%time
for _ in range(1_000):
numba_logpdf(100.)
Expand Down
Loading

1 comment on commit 51cf08a

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Python Benchmark with pytest-benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 51cf08a Previous: 655c36b Ratio
tests/test_benchmarks.py::test_regression_hmc 0.03595568031505422 iter/sec (stddev: 0.4001022870366988) 0.0793071786328184 iter/sec (stddev: 0.4758808541120901) 2.21

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.