parametrize batch size (#285)
nerdai authored Jan 22, 2025
1 parent 75eaf24 commit 72f2fe6
Showing 3 changed files with 148 additions and 4 deletions.
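In short: `create_candle_dataloaders` used to hard-code `batch_size = 8` internally; this commit lifts the value into a parameter so each call site can size batches to its hardware. A minimal sketch of the new call shape, assuming the crate is importable as `llms_from_scratch_rs` (check the repo's Cargo.toml for the actual name); the batch access pattern is copied from EG01 in the diff below:

```rust
// Crate name assumed for illustration; see the repo's Cargo.toml.
use llms_from_scratch_rs::listings::apdx_e::create_candle_dataloaders;

fn main() -> anyhow::Result<()> {
    // EG01/EG02/EG05 keep the previous default of 8; EG06 passes 2
    // because LoRA fine-tuning at batch size 8 OOMs on a 12GB GPU.
    let batch_size = 8_usize;
    let (train_loader, _val_loader, _test_loader) = create_candle_dataloaders(batch_size)?;

    // Same access pattern as EG01: grab the last training batch.
    let (input_batch, target_batch) = train_loader.batcher().last().unwrap()?;
    println!("inputs {:?} targets {:?}", input_batch.dims(), target_batch.dims());
    Ok(())
}
```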
1 change: 1 addition & 0 deletions src/bin/main.rs
@@ -117,6 +117,7 @@ static EXAMPLE_REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Example>>> = Laz
     m.insert("E.03", Box::new(examples::apdx_e::EG03));
     m.insert("E.04", Box::new(examples::apdx_e::EG04));
     m.insert("E.05", Box::new(examples::apdx_e::EG05));
+    m.insert("E.06", Box::new(examples::apdx_e::EG06));
     m
 });

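For context, `EXAMPLE_REGISTRY` is a lazily built map from example IDs to boxed `Example` trait objects; registering `E.06` here is what lets `cargo run example E.06` dispatch to the new code. A self-contained sketch of the same pattern (the trait here is a hypothetical stand-in with only an `id` method; the real trait has `description`, `page_source`, and `main`):

```rust
use std::collections::HashMap;
use std::sync::LazyLock;

// Hypothetical stand-in for the crate's `Example` trait.
trait Example: Send + Sync {
    fn id(&self) -> &'static str;
}

struct EG06;

impl Example for EG06 {
    fn id(&self) -> &'static str {
        "E.06"
    }
}

// Same shape as EXAMPLE_REGISTRY in src/bin/main.rs: built once, on
// first access, and keyed by the ID given on the command line.
static REGISTRY: LazyLock<HashMap<&'static str, Box<dyn Example>>> =
    LazyLock::new(|| {
        let mut m: HashMap<&'static str, Box<dyn Example>> = HashMap::new();
        m.insert("E.06", Box::new(EG06));
        m
    });

fn main() {
    // Roughly what `cargo run example E.06` resolves through.
    if let Some(eg) = REGISTRY.get("E.06") {
        println!("dispatching {}", eg.id());
    }
}
```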
149 changes: 146 additions & 3 deletions src/examples/apdx_e.rs
@@ -33,7 +33,8 @@ impl Example for EG01 {
     fn main(&self) -> Result<()> {
         use crate::listings::apdx_e::create_candle_dataloaders;
 
-        let (train_loader, val_loader, test_loader) = create_candle_dataloaders()?;
+        let batch_size = 8_usize;
+        let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?;
 
         // print last batch of train loader
         let (input_batch, target_batch) = train_loader.batcher().last().unwrap()?;
@@ -120,7 +121,8 @@ impl Example for EG02 {
         modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?;
 
         // calc classification accuracy
-        let (train_loader, val_loader, test_loader) = create_candle_dataloaders()?;
+        let batch_size = 8_usize;
+        let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?;
 
         // compute accuracies
         let num_batches = Some(10_usize);
@@ -331,7 +333,8 @@ impl Example for EG05 {
         let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?;
 
         // calc classification accuracy
-        let (train_loader, val_loader, test_loader) = create_candle_dataloaders()?;
+        let batch_size = 8_usize;
+        let (train_loader, val_loader, test_loader) = create_candle_dataloaders(batch_size)?;
 
         // compute accuracies
         let num_batches = Some(10_usize);
@@ -349,3 +352,143 @@ impl Example for EG05 {
         Ok(())
     }
 }
+
+/// # Fine-tuning a model with LoRA layers
+///
+/// NOTE: technically this is Listing 7.1 in the book, but we felt it was better
+/// as an Example.
+///
+/// #### Id
+/// E.06
+///
+/// #### Page
+/// This example starts on page 334
+///
+/// #### CLI command
+/// ```sh
+/// # without cuda
+/// cargo run example E.06
+///
+/// # with cuda
+/// cargo run --features cuda example E.06
+/// ```
+pub struct EG06;
+
+impl Example for EG06 {
+    fn description(&self) -> String {
+        "Fine-tuning a model with LoRA layers.".to_string()
+    }
+
+    fn page_source(&self) -> usize {
+        334_usize
+    }
+
+    fn main(&self) -> Result<()> {
+        use crate::listings::{
+            apdx_e::{create_candle_dataloaders, train_classifier_simple, GPTModelWithLoRA},
+            ch04::Config,
+            ch06::{
+                download_and_load_gpt2, modify_out_head_for_classification, plot_values,
+                HF_GPT2_MODEL_ID,
+            },
+        };
+        use candle_core::{DType, Device, Var};
+        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
+        use ndarray::linspace;
+        use std::path::Path;
+
+        // get gpt model with classification head
+        let mut cfg = Config::gpt2_124m();
+        cfg.qkv_bias = true;
+        let varmap = VarMap::new();
+        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
+        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
+
+        // modify to use classification head
+        let num_classes = 2_usize;
+        modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?;
+
+        // convert to LoRA model
+        let rank = 16_usize;
+        let alpha = 16_f64;
+        let model = GPTModelWithLoRA::from_gpt_model(model, rank, alpha, vb.pp("model"))?;
+
+        // data loaders
+        let batch_size = 2_usize; // Got OOM on my Tesla P100 (12GB) with 8_usize
+        let (train_loader, val_loader, _test_loader) = create_candle_dataloaders(batch_size)?;
+
+        // extract only LoRA weights as trainable params
+        let mut training_vars: Vec<Var> = vec![];
+        let tensor_data = varmap.data().lock().unwrap();
+        let var_names: Vec<&String> = tensor_data
+            .keys()
+            .filter(|k| k.contains("A") || k.contains("B"))
+            .collect();
+
+        println!("Training variables: {:?}\n", var_names);
+
+        for var_name in var_names.into_iter() {
+            let var = tensor_data.get(var_name).unwrap();
+            training_vars.push(var.clone());
+        }
+        drop(tensor_data);
+
+        // train model
+        let optimizer = AdamW::new(
+            training_vars,
+            ParamsAdamW {
+                lr: 5e-5,
+                weight_decay: 0.1,
+                ..Default::default()
+            },
+        )?;
+
+        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
+        let (train_loss, val_loss, train_accs, val_accs, num_examples) = train_classifier_simple(
+            &model,
+            &train_loader,
+            &val_loader,
+            optimizer,
+            vb.device(),
+            num_epochs,
+            eval_freq,
+            eval_iter,
+            None,
+        )?;
+
+        // save model
+        println!("Saving weights to `./clf.gptwithlora.checkpoint.safetensors`");
+        varmap.save("clf.gptwithlora.checkpoint.safetensors")?;
+
+        // prepare and save plots
+        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_loss.len()));
+        let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_loss.len()));
+        let label = "loss";
+        let save_path = Path::new(format!("plot_classification_gptwithlora_{label}.html").as_str())
+            .to_path_buf();
+        plot_values(
+            epochs_seen,
+            examples_seen,
+            train_loss,
+            val_loss,
+            label,
+            save_path,
+        )?;
+
+        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_accs.len()));
+        let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_accs.len()));
+        let label = "accuracy";
+        let save_path = Path::new(format!("plot_classification_gptwithlora_{label}.html").as_str())
+            .to_path_buf();
+        plot_values(
+            epochs_seen,
+            examples_seen,
+            train_accs,
+            val_accs,
+            label,
+            save_path,
+        )?;
+
+        Ok(())
+    }
+}
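One detail of EG06 worth calling out: the base GPT-2 weights are frozen implicitly, simply by handing `AdamW::new` only the `Var`s whose names contain "A" or "B" (the LoRA factor matrices). Gradients still flow through the frozen weights, but the optimizer never updates them. A minimal sketch of that selection step, reusing the same `VarMap::data()` calls as the diff (the `lora.A` / `lora.B` / `out_head.weight` names are hypothetical, for illustration):

```rust
use candle_core::{DType, Device, Var};
use candle_nn::{init, VarMap};

// Collect only the LoRA factor matrices: in this naming scheme every
// trainable name contains "A" or "B"; everything else stays frozen.
fn collect_lora_vars(varmap: &VarMap) -> Vec<Var> {
    let tensor_data = varmap.data().lock().unwrap(); // name -> Var table
    tensor_data
        .iter()
        .filter(|(name, _)| name.contains("A") || name.contains("B"))
        .map(|(_, var)| var.clone())
        .collect()
}

fn main() -> candle_core::Result<()> {
    let varmap = VarMap::new();
    let dev = Device::Cpu;
    // Hypothetical variable names, for illustration only.
    varmap.get((4, 2), "lora.A", init::ZERO, DType::F32, &dev)?;
    varmap.get((2, 4), "lora.B", init::ZERO, DType::F32, &dev)?;
    varmap.get((4, 4), "out_head.weight", init::ZERO, DType::F32, &dev)?;

    // Only the two LoRA factors would be handed to the optimizer.
    assert_eq!(collect_lora_vars(&varmap).len(), 2);
    Ok(())
}
```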
2 changes: 1 addition & 1 deletion src/listings/apdx_e.rs
@@ -96,11 +96,11 @@ pub fn create_candle_datasets() -> anyhow::Result<(SpamDataset, SpamDataset, Spa
 ///
 /// NOTE: This is merely EG 06.06
 pub fn create_candle_dataloaders(
+    batch_size: usize,
 ) -> anyhow::Result<(SpamDataLoader, SpamDataLoader, SpamDataLoader)> {
     let (train_dataset, val_dataset, test_dataset) = create_candle_datasets()?;
 
     // create loaders
-    let batch_size = 8_usize;
     let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
     let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
     let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);
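Finally, on the `rank`/`alpha` pair EG06 passes to `GPTModelWithLoRA::from_gpt_model`: a LoRA layer adds a low-rank update to each frozen weight matrix. Below is a toy numeric sketch of that update rule, assuming the common `alpha / rank` scaling from the original LoRA paper; this is not the crate's implementation, and its scaling convention may differ:

```rust
// Effective weight under LoRA: W_eff = W + (alpha / rank) * (A . B),
// where W stays frozen and only A (d_in x rank) and B (rank x d_out) train.
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let (n, k, m) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0_f32; m]; n];
    for i in 0..n {
        for p in 0..k {
            for j in 0..m {
                out[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    out
}

fn main() {
    // EG06 uses rank = 16, alpha = 16; rank = 1 here keeps the numbers small.
    let (rank, alpha) = (1_usize, 16_f32);
    let w = vec![vec![1.0, 0.0], vec![0.0, 1.0]]; // frozen 2x2 weight
    let a = vec![vec![0.5], vec![0.25]]; // d_in x rank
    let b = vec![vec![2.0, 0.0]]; // rank x d_out
    let scale = alpha / rank as f32;
    let delta = matmul(&a, &b);
    let w_eff: Vec<Vec<f32>> = w
        .iter()
        .zip(&delta)
        .map(|(wr, dr)| wr.iter().zip(dr).map(|(x, d)| x + scale * d).collect())
        .collect();
    println!("{w_eff:?}"); // [[17.0, 0.0], [8.0, 1.0]]
}
```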
