Code for the paper Bayesian Low-Rank Adaptation for Large Language Models.
See the explanatory blog post and documentation for more information.
pip install bayesian-lora
We provide a comprehensive example in examples/example_usage.py, running through the main methods using Phi-2 on ARC-E.
Note that running this requires a local installation with a few extra dependencies. Run:
git clone /~https://github.com/MaximeRobeyns/bayesian_lora
cd bayesian_lora
pip install -e ".[examples]"
and then
python ./examples/example_usage.py
The main functions this library provides are for calculating Kronecker factors, the marginal likelihood, and the posterior predictive distribution. We show how to use these in the examples below.
First, wrap your model call in a function that takes a batch from your data loader, and returns the relevant logits. For a CausalLM from HuggingFace:
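from typing import Any
import torch as t
import torch.nn as nn

def fwd_call(model: nn.Module, batch_prompts: Any) -> t.Tensor:
    # Assumes `tokenizer`/`device` exist and the tokenizer pads on the left.
    inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt").to(device)
    outputs = model(**inputs)
    logits = outputs.logits[:, -1]  # Get the last token logits
    return logits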
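Taking outputs.logits[:, -1] selects each prompt's final token only if every sequence in the batch ends at the last position, which holds for left-padded batches. Below is a minimal sketch of a hypothetical helper, last_token_logits (not part of this library), that instead uses the attention_mask returned by a HuggingFace tokenizer and therefore also handles right padding. Inside fwd_call it would be called as last_token_logits(outputs.logits, inputs["attention_mask"]).

```python
import torch as t


def last_token_logits(logits: t.Tensor, attention_mask: t.Tensor) -> t.Tensor:
    """Hypothetical helper: logits at the last non-padding position of each sequence.

    logits:         (batch, seq_len, vocab) output of a causal language model
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    """
    positions = t.arange(attention_mask.size(1), device=attention_mask.device)
    # The last real token sits at the largest position where the mask is 1.
    last_idx = (positions * attention_mask).argmax(dim=1)
    batch_idx = t.arange(logits.size(0), device=logits.device)
    return logits[batch_idx, last_idx]  # (batch, vocab)


logits = t.arange(2 * 4 * 3, dtype=t.float32).reshape(2, 4, 3)
left_padded = t.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
right_padded = t.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
print(last_token_logits(logits, left_padded))   # rows logits[0, 3] and logits[1, 3]
print(last_token_logits(logits, right_padded))  # rows logits[0, 1] and logits[1, 2]
```

With left padding the helper agrees with logits[:, -1]; with right padding it picks the last real token, where [:, -1] would return a padding position.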
You can now call our calculate_kronecker_factors function:
from bayesian_lora import calculate_kronecker_factors

factors = calculate_kronecker_factors(
    model,             # Your model (not necessarily PEFT)
    fwd_call,          # Model call wrapper, defined above
    train_loader,      # Your training data loader
    cfg.n_kfac,        # (Optional) rank to use for the Kronecker factors
    cfg.lr_threshold,  # (Optional) threshold for low-rank approximation
    ["lora"],          # (Optional) modules to target; defaults to all modules
    use_tqdm=True,     # (Optional) use tqdm for progress bar
)
In the above, the ["lora"] argument contains a case-insensitive list of keywords to identify modules to target. Since we're working with a LoRA model, we choose "lora" to target LoRA modules, for instance layers.0.q_proj.lora_A.
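The keyword matching described above can be sketched with PyTorch's named_modules. This is an illustration under assumptions, not the library's code: find_target_modules and the toy model are hypothetical, and restricting the result to nn.Linear layers (the layer type K-FAC factors are defined for) is an assumption of this sketch.

```python
import torch.nn as nn


def find_target_modules(model: nn.Module, keywords: list[str]) -> list[str]:
    """Hypothetical helper: full names of nn.Linear modules matching any keyword.

    Matching is a case-insensitive substring test on each module's full name.
    """
    lowered = [k.lower() for k in keywords]
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(k in name.lower() for k in lowered)
    ]


def build_toy_model() -> nn.Module:
    """Toy model whose module names mimic a LoRA-adapted attention projection."""
    q_proj = nn.Module()
    q_proj.base_layer = nn.Linear(8, 8)
    q_proj.lora_A = nn.Linear(8, 2, bias=False)
    q_proj.lora_B = nn.Linear(2, 8, bias=False)
    layer = nn.Module()
    layer.q_proj = q_proj
    model = nn.Module()
    model.layers = nn.ModuleList([layer])
    return model


print(find_target_modules(build_toy_model(), ["LoRA"]))
# ['layers.0.q_proj.lora_A', 'layers.0.q_proj.lora_B']
```

Because the test is a substring match on the full dotted name, a broader keyword such as "q_proj" would also pick up the frozen base layer, so it pays to choose keywords that only occur in the adapter names.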
The factors are a dictionary whose keys are the full names of the targeted
modules and whose values are tuples of two tensors: the first is the
(possibly low-rank) Kronecker factor corresponding to the input activations,
and the second is the (possibly low-rank) factor corresponding to the output
gradients.
See the K-FAC docs for more detail.
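The two-factor structure follows from K-FAC. The sketch below uses plain PyTorch with a single linear layer and the textbook K-FAC definitions; the library's exact scaling, gradient sampling and low-rank scheme (including how lr_threshold is used) may differ, and kfac_factors and low_rank are hypothetical helpers. For one example, the curvature contribution of a linear layer's weights factorises exactly into a Kronecker product of an activation term and an output-gradient term. K-FAC extends this to a batch by averaging each term separately, and a truncated eigendecomposition is one standard way to obtain a low-rank factor.

```python
import torch as t

t.manual_seed(0)
d_in, d_out = 5, 3

# Single example: for a linear layer y = W a, the gradient of the loss w.r.t. W
# is the outer product of the output gradient g and the input activation a.
a = t.randn(d_in, dtype=t.float64)
g = t.randn(d_out, dtype=t.float64)
vec = t.outer(g, a).flatten()  # row-major flatten of the (d_out, d_in) gradient
# Its curvature contribution factorises exactly as a Kronecker product.
exact_one = t.outer(vec, vec)
kron_one = t.kron(t.outer(g, g), t.outer(a, a))


def kfac_factors(acts: t.Tensor, grads: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
    """K-FAC factors from activations (N, d_in) and output gradients (N, d_out).

    Returned as (activation factor, output-gradient factor), the same order as
    the tuples described above.
    """
    n = acts.size(0)
    return acts.T @ acts / n, grads.T @ grads / n


def low_rank(factor: t.Tensor, k: int) -> tuple[t.Tensor, t.Tensor]:
    """Rank-k approximation factor ~= U diag(s) U^T from the top-k eigenpairs."""
    evals, evecs = t.linalg.eigh(factor)  # eigenvalues in ascending order
    return evecs[:, -k:], evals[-k:]


# Over a batch, K-FAC replaces E[(g g^T) kron (a a^T)] with E[g g^T] kron E[a a^T].
acts = t.randn(2, d_in, dtype=t.float64)
grads = t.randn(2, d_out, dtype=t.float64)
A, G = kfac_factors(acts, grads)
kfac_approx = t.kron(G, A)  # ordering matches the row-major flatten used above
# With only 2 examples A has rank 2, so a rank-2 approximation is exact here.
U, s = low_rank(A, k=2)
A_rec = U @ t.diag(s) @ U.T
```

Storing U and s instead of the full factor is what makes the low-rank form cheap: for a layer with d_in inputs, a rank-k factor needs d_in·k + k numbers rather than d_in².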
We provide a function called model_evidence, which returns the evidence, or marginal likelihood.
from bayesian_lora import model_evidence

evidence = model_evidence(
    model,           # Your model
    log_likelihood,  # A Tensor with the model's log likelihood on some eval dataset
    factors,         # Kronecker factors, as calculated above
    n_lora,          # rank used in the LoRA adapters
    n_kfac,          # rank used in the Kronecker factors
    prior_var,       # prior variance hyperparameter, as a tensor
)
You can then use evidence as the loss in a normal training loop, presuming your parameters (e.g. prior_var) have gradients.
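To see why tuning prior_var through the evidence works, here is a sketch of the textbook Laplace approximation to the log marginal likelihood under an isotropic Gaussian prior N(0, σ²I), for a toy parameter vector θ with a dense Hessian H of the negative log likelihood: log p(D) ≈ log p(D | θ) − ‖θ‖²/(2σ²) − ½ log det(I + σ²H). This is not the library's model_evidence, which works with the LoRA weights and the (low-rank) Kronecker factors, and laplace_log_evidence is a hypothetical name. In this sketch the function returns a log evidence to be maximised, so the training loss is its negative; check the sign convention of model_evidence before reusing the pattern.

```python
import torch as t


def laplace_log_evidence(
    log_lik: t.Tensor, theta: t.Tensor, hessian: t.Tensor, prior_var: t.Tensor
) -> t.Tensor:
    """Hypothetical helper: Laplace approximation to log p(D), prior N(0, prior_var * I).

    log_lik:   scalar log p(D | theta) at the trained parameters
    theta:     (d,) trained parameters
    hessian:   (d, d) positive semi-definite Hessian of the negative log likelihood
    prior_var: scalar tensor; may require grad so that it can be tuned
    """
    eye = t.eye(theta.numel(), dtype=hessian.dtype)
    _, logdet = t.linalg.slogdet(eye + prior_var * hessian)
    return log_lik - theta.pow(2).sum() / (2 * prior_var) - 0.5 * logdet


t.manual_seed(0)
d = 5
B = t.randn(d, d, dtype=t.float64)
H = B @ B.T  # toy curvature, positive semi-definite
theta = t.full((d,), 3.0, dtype=t.float64)
log_lik = t.tensor(-10.0, dtype=t.float64)

# Optimise log(prior_var) so that the variance stays positive.
log_prior_var = t.zeros((), dtype=t.float64, requires_grad=True)
opt = t.optim.Adam([log_prior_var], lr=0.05)
initial = laplace_log_evidence(log_lik, theta, H, log_prior_var.exp()).item()
for _ in range(300):
    opt.zero_grad()
    loss = -laplace_log_evidence(log_lik, theta, H, log_prior_var.exp())
    loss.backward()
    opt.step()
final = laplace_log_evidence(log_lik, theta, H, log_prior_var.exp()).item()
print(f"prior_var={log_prior_var.exp().item():.3f}, log evidence {initial:.3f} -> {final:.3f}")
```

The log det(I + σ²H) form is algebraically the same as the more familiar −(d/2) log σ² − ½ log det(H + I/σ²), but it avoids dividing by a small variance.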
To get the parameters of the Gaussian over the logits, use the jacobian_mean and variance functions.
from bayesian_lora import jacobian_mean, variance

with t.no_grad():
    for batch in validation_loader:
        prompts, classes = batch
        batch_inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device)

        # Predict the output logit locations
        # target_ids is a tensor containing the indices of the target tokens
        # e.g. [354, 355, 356].
        jacobian, f_mu = jacobian_mean(
            model, batch_inputs, target_ids
        )

        # Predict the output logit variances
        f_var = variance(
            batch_inputs,     # inputs
            jacobian,         # the Jacobian dictionary, obtained above
            factors,          # Kronecker factors, as calculated above
            prior_var,        # prior variance hyperparameter, as a tensor
            len(target_ids),  # number of classes (one target token per class)
            n_lora,           # rank of the LoRA adapters
            n_kfac,           # rank of the Kronecker factors
            device,           # device to use
        )

        # Now use the parameters to e.g. sample logits from the Gaussian
        # predictive, parametrised by f_mu, f_var. Assumes f_mu has shape
        # (batch, n_classes) and f_var has shape (batch, n_classes, n_classes).
        L = t.linalg.cholesky(f_var)
        samples = 100_000
        f_mu = f_mu.expand(samples, *f_mu.shape)
        L = L.expand(samples, *L.shape)
        eps = t.randn_like(f_mu).unsqueeze(-1)  # column vectors for L @ eps
        # Average the class probabilities over the sampled logits
        probs = (f_mu.unsqueeze(-1) + L @ eps).squeeze(-1).softmax(-1).mean(0)
The above is a minimal example; see this section of the documentation for more detail.
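The Jacobian in the API suggests a linearised Laplace predictive: the selected logits are treated as Gaussian with mean f_mu and covariance J Σ Jᵀ, where J is the Jacobian of those logits with respect to the adapter parameters and Σ is the parameter posterior covariance. The small dense sketch below works under that assumption. The library builds the covariance from the Kronecker factors rather than an explicit Σ, and linearised_predictive and sample_logits are hypothetical helpers. Averaging softmax probabilities over samples is one common way to turn the Gaussian into class probabilities.

```python
import torch as t


def linearised_predictive(
    jacobian: t.Tensor, f_mu: t.Tensor, posterior_cov: t.Tensor
) -> tuple[t.Tensor, t.Tensor]:
    """Hypothetical helper: Gaussian over the logits, N(f_mu, J Sigma J^T).

    jacobian:      (batch, n_classes, d) Jacobian of the logits w.r.t. the parameters
    f_mu:          (batch, n_classes) logits at the trained parameters
    posterior_cov: (d, d) Gaussian posterior covariance Sigma over the parameters
    """
    f_var = jacobian @ posterior_cov @ jacobian.transpose(-1, -2)
    return f_mu, f_var


def sample_logits(f_mu: t.Tensor, f_var: t.Tensor, samples: int) -> t.Tensor:
    """Draw (samples, batch, n_classes) logits from N(f_mu, f_var) via Cholesky."""
    L = t.linalg.cholesky(f_var)  # (batch, C, C) with f_var = L L^T
    eps = t.randn(samples, *f_mu.shape, 1, dtype=f_mu.dtype)  # (S, batch, C, 1)
    return (f_mu.unsqueeze(-1) + L @ eps).squeeze(-1)


t.manual_seed(0)
batch, n_classes, d = 2, 3, 6
J = t.randn(batch, n_classes, d, dtype=t.float64)
M = t.randn(d, d, dtype=t.float64)
Sigma = M @ M.T / d + 0.1 * t.eye(d, dtype=t.float64)  # positive definite
mu, var = linearised_predictive(J, t.randn(batch, n_classes, dtype=t.float64), Sigma)
f = sample_logits(mu, var, samples=200_000)
probs = f.softmax(-1).mean(0)  # Monte Carlo predictive class probabilities
```

The empirical mean and covariance of the samples should match f_mu and f_var up to Monte Carlo error, which makes this a quick way to check tensor shapes before running the loop on a real model.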
This library is intentionally very small and hackable. It has two main files and three dependencies (torch, tqdm and jaxtyping).
main.py contains methods specific to the paper, and kfac.py contains relatively portable K-FAC methods.
Feel free to directly copy the code into your projects and hack on it.