sbijax

Simulation-based inference in JAX

About

Sbijax is a Python library for neural simulation-based inference and approximate Bayesian computation using JAX. It implements recent methods, such as Simulated-annealing ABC, Surjective Neural Likelihood Estimation, Neural Approximate Sufficient Statistics or Consistency model posterior estimation, as well as methods to compute model diagnostics and for visualizing posterior distributions.

Caution

⚠️ As per the LICENSE file, there is no warranty whatsoever for this free software tool. If you discover bugs, please report them.

Examples

Sbijax implements a slim object-oriented API with functional elements stemming from JAX. All a user needs to define is a prior model, a simulator function and an inferential algorithm. For example, you can define a neural likelihood estimation method and generate posterior samples like this:

from jax import numpy as jnp, random as jr
from sbijax import NLE
from sbijax.nn import make_maf
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    prior = tfd.JointDistributionNamed(dict(
        theta=tfd.Normal(jnp.zeros(2), jnp.ones(2))
    ), batch_ndims=0)
    return prior

def simulator_fn(seed, theta):
    p = tfd.Normal(jnp.zeros_like(theta["theta"]), 0.1)
    y = theta["theta"] + p.sample(seed=seed)
    return y


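# Bundle the prior and simulator; the likelihood is modelled by a
# masked autoregressive flow over the 2-dimensional data.
fns = prior_fn, simulator_fn
model = NLE(fns, make_maf(2))

y_observed = jnp.array([-1.0, 1.0])
# Simulate training data, fit the flow, then sample the posterior
# conditioned on the observation.
data, _ = model.simulate_data(jr.PRNGKey(1))
params, _ = model.fit(jr.PRNGKey(2), data=data)
posterior, _ = model.sample_posterior(jr.PRNGKey(3), params, y_observed)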
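
Because the toy model above is Gaussian in both prior and simulator, its posterior has a closed form, which can serve as a rough sanity check for the samples an inferential algorithm returns. The sketch below is not part of sbijax: the helper name `gaussian_posterior` is hypothetical, and it assumes a single observation, a standard normal prior per dimension and a noise standard deviation of 0.1, as in the example.

```python
import math


def gaussian_posterior(y, prior_mean=0.0, prior_std=1.0, noise_std=0.1):
    """Closed-form posterior of theta for y = theta + noise (one dimension).

    Hypothetical helper for illustration; assumes a Normal prior on theta
    and Normal observation noise with known standard deviation.
    """
    prior_prec = 1.0 / prior_std**2
    noise_prec = 1.0 / noise_std**2
    # Precisions add; the mean is the precision-weighted average.
    post_var = 1.0 / (prior_prec + noise_prec)
    post_mean = post_var * (prior_prec * prior_mean + noise_prec * y)
    return post_mean, math.sqrt(post_var)


y_observed = [-1.0, 1.0]
for y in y_observed:
    mean, std = gaussian_posterior(y)
    print(f"y={y:+.1f}: posterior mean={mean:+.4f}, std={std:.4f}")
```

Under these assumptions the posterior mean is the observation shrunk by a factor of 100/101 and the posterior standard deviation is just below 0.1, so approximate posterior samples should concentrate close to `y_observed`.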

More self-contained examples can be found in examples.

Documentation

Documentation can be found here.

Installation

Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU or TPU, please follow these instructions.

To install from PyPI, just call the following on the command line:

pip install sbijax

To install the latest GitHub release, use:

pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
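
After installing, one way to confirm that JAX and sbijax are visible to the current Python environment is to query the installed package metadata. This is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the sketch only reports versions without checking that an accelerator backend works.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(["jax", "jaxlib", "sbijax"]).items():
        print(f"{name}: {version or 'not installed'}")
```

A `not installed` entry for `jax` or `jaxlib` suggests revisiting the JAX installation step before installing sbijax.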

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled good first issue.

In order to contribute:

  1. Clone sbijax and install hatch via pip install hatch,
  2. create a new branch locally with git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,
  3. implement your contribution and ideally a test case,
  4. test it by calling make tests, make lints and make format on the (Unix) command line,
  5. submit a PR 🙂

Acknowledgements

Note

📝 The API of the package is heavily inspired by the excellent PyTorch-based sbi package.

Author

Simon Dirmeier sfyrbnd @ pm me