Skip to content

Commit

Permalink
training run with parameters as in lib.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
krukah committed Dec 27, 2024
1 parent 73ec8d2 commit fb03ef5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 17 deletions.
6 changes: 2 additions & 4 deletions src/clustering/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,16 @@ impl Save for Lookup {
std::fs::metadata(format!("{}{}", street, Self::name())).is_ok()
}
fn make(street: Street) -> Self {
let n = street.n_isomorphisms();
let progress = crate::progress(n);
// abstractions for River are calculated once via obs.equity
// abstractions for Preflop are cequivalent to just enumerating isomorphisms
match street {
Street::Rive => IsomorphismIterator::from(Street::Rive)
.map(|iso| (iso, Abstraction::from(iso.0.equity())))
.inspect(|_| progress.inc(1))
.collect::<BTreeMap<_, _>>()
.into(),
Street::Pref => IsomorphismIterator::from(Street::Pref)
.enumerate()
.map(|(k, iso)| (iso, Abstraction::from((Street::Pref, k))))
.inspect(|_| progress.inc(1))
.collect::<BTreeMap<_, _>>()
.into(),
_ => panic!("lookup must be learned via layer for {street}"),
Expand Down
18 changes: 9 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ const N_RAISE: usize = 3;

/// sinkhorn optimal transport parameters
const SINKHORN_TEMPERATURE: Entropy = 0.125;
const SINKHORN_ITERATIONS: usize = 16;
const SINKHORN_TOLERANCE: Energy = 0.001;
const SINKHORN_ITERATIONS: usize = 32;
const SINKHORN_TOLERANCE: Energy = 0.005;

// kmeans clustering parameters
const KMEANS_FLOP_TRAINING_ITERATIONS: usize = 32;
const KMEANS_TURN_TRAINING_ITERATIONS: usize = 32;
const KMEANS_FLOP_CLUSTER_COUNT: usize = 24;
const KMEANS_TURN_CLUSTER_COUNT: usize = 16;
const KMEANS_EQTY_CLUSTER_COUNT: usize = 64;
const KMEANS_FLOP_TRAINING_ITERATIONS: usize = KMEANS_TURN_TRAINING_ITERATIONS;
const KMEANS_TURN_TRAINING_ITERATIONS: usize = KMEANS_TURN_CLUSTER_COUNT;
const KMEANS_FLOP_CLUSTER_COUNT: usize = 128;
const KMEANS_TURN_CLUSTER_COUNT: usize = 144;
const KMEANS_EQTY_CLUSTER_COUNT: usize = 101;

// mccfr parameters
const CFR_BATCH_SIZE: usize = 16;
const CFR_TREE_COUNT: usize = 1024; // WARNING THIS WILL NOT SOLVE ANYTHING
const CFR_BATCH_SIZE: usize = 256;
const CFR_TREE_COUNT: usize = 1_048_576;
const CFR_ITERATIONS: usize = CFR_TREE_COUNT / CFR_BATCH_SIZE;
const CFR_PRUNNING_PHASE: usize = 100_000_000 / CFR_BATCH_SIZE;
const CFR_DISCOUNT_PHASE: usize = 100_000 / CFR_BATCH_SIZE;
Expand Down
4 changes: 0 additions & 4 deletions src/mccfr/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,8 @@ mod tests {
/// arguments to the save function to write to a temporary name
/// and delete the file
fn persistence() {
let name = "test";
let file = format!("{}.profile.pgcopy", name);
let save = Profile::random();
save.save();
let load = Profile::load(Street::random());
std::fs::remove_file(file).unwrap();
assert!(std::iter::empty()
.chain(save.strategies.iter().zip(load.strategies.iter()))
.chain(load.strategies.iter().zip(save.strategies.iter()))
Expand Down

0 comments on commit fb03ef5

Please sign in to comment.