Skip to content

Commit

Permalink
consistent progress bar usage
Browse files Browse the repository at this point in the history
  • Loading branch information
krukah committed Oct 16, 2024
1 parent 5d12edd commit 020d83c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 35 deletions.
7 changes: 1 addition & 6 deletions src/cards/observation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,7 @@ impl From<(Hand, Hand)> for Observation {
impl From<Street> for Observation {
fn from(street: Street) -> Self {
let mut deck = Deck::new();
let n = match street {
Street::Pref => 0,
Street::Flop => 3,
Street::Turn => 4,
Street::Rive => 5,
};
let n = street.n_observed();
let public = (0..n)
.map(|_| deck.draw())
.map(u64::from)
Expand Down
2 changes: 1 addition & 1 deletion src/clustering/centroid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct Centroid {
}

impl Centroid {
pub fn rotate(&mut self) {
pub fn reset(&mut self) {
self.last.destroy();
std::mem::swap(&mut self.last, &mut self.next);
}
Expand Down
79 changes: 51 additions & 28 deletions src/clustering/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl Layer {
// add the abstraction-less PreFlop Observations
// or include a Abstraction::PreFlop(Hole) variant
// to make sure we cover the full set of Observations
// this might be better off to do in Explorer::load() perhaps
// or we could add one more .inner().save() call with
// special Preflop logic to not actually do any clustering
// i.e. k = 169, t = 0
}

/// start with the River layer. everything is empty because we
Expand Down Expand Up @@ -125,16 +129,16 @@ impl Layer {
self.street,
self.street.prev()
);
// pretty progress bar
let n = self.street.prev().n_isomorphisms() as u64;
let tick = std::time::Duration::from_secs(5);
// pretty progress
let n = self.street.prev().n_isomorphisms();
let tick = std::time::Duration::from_secs(1);
let style = "[{elapsed}] {spinner:.green} {wide_bar:.green} ETA {eta}";
let style = ProgressStyle::with_template(style).unwrap();
let progress = ProgressBar::new(n);
let progress = ProgressBar::new(n as u64);
progress.set_style(style);
progress.enable_steady_tick(tick);
//
ObservationSpace(
let points = ObservationSpace(
Observation::exhaust(self.street.prev())
.filter(|o| Isomorphism::is_canonical(o))
.map(|o| Isomorphism::from(o)) // isomorphism translation
Expand All @@ -143,53 +147,73 @@ impl Layer {
.map(|inner| (inner, self.lookup.projection(&inner)))
.inspect(|_| progress.inc(1))
.collect::<BTreeMap<Isomorphism, Histogram>>(),
)
);
//
progress.finish();
points
}

/// initializes the centroids for k-means clustering using the k-means++ algorithm
/// 1. choose 1st centroid randomly from the dataset
/// 2. choose nth centroid with probability proportional to squared distance of nearest neighbors
/// 3. collect histograms and label with arbitrary (random) `Abstraction`s
///
/// if this becomes a bottleneck with contention,
/// consider partitioning dataset or using lock-free data structures.
fn initial(&mut self) {
log::info!("initializing kmeans {}", self.street);
// pretty progress
let n = self.k() - 1;
let tick = std::time::Duration::from_secs(1);
let style = "[{elapsed}] {spinner:.green} {wide_bar:.green} ETA {eta}";
let style = ProgressStyle::with_template(style).unwrap();
let progress = ProgressBar::new(n as u64);
progress.set_style(style);
progress.enable_steady_tick(tick);
//
let ref mut rng = rand::rngs::StdRng::seed_from_u64(self.street as u64);
let histogram = self.sample_uniform(rng);
self.kmeans.expand(histogram);
while self.k() > self.l() {
log::info!("add initial {}", self.l());
let histogram = self.sample_outlier(rng);
self.kmeans.expand(histogram);
progress.inc(1);
}
//
progress.finish();
}
/// for however many iterations we want,
/// 1. assign each `Observation` to the nearest `Centroid`
/// 2. update each `Centroid` by averaging the `Observation`s assigned to it
///
/// if this becomes a bottleneck with contention,
/// consider partitioning dataset or using lock-free data structures.
fn cluster(&mut self) {
log::info!("clustering kmeans {}", self.street);
for i in 0..self.t() {
log::info!("computing abstractions {} {}", self.street, i);
// pretty progress
assert!(self.points.0.len() == self.street.n_isomorphisms());
let n = self.t() * self.points.0.len();
let tick = std::time::Duration::from_secs(1);
let style = "[{elapsed}] {spinner:.green} {wide_bar:.green} ETA {eta}";
let style = ProgressStyle::with_template(style).unwrap();
let progress = ProgressBar::new(n as u64);
progress.set_style(style);
progress.enable_steady_tick(tick);
//
for _ in 0..self.t() {
let abstractions = self
.points
.0
.par_iter()
.map(|(_, h)| self.nearest_neighbor(h))
.inspect(|_| progress.inc(1))
.collect::<Vec<Abstraction>>();
log::info!("assigning abstractions {} {}", self.street, i);
for ((o, h), a) in self.points.0.iter_mut().zip(abstractions.iter()) {
self.lookup.assign(a, o);
self.kmeans.absorb(a, h);
for ((observation, histogram), abstraction) in
self.points.0.iter_mut().zip(abstractions.iter())
{
self.lookup.assign(abstraction, observation);
self.kmeans.absorb(abstraction, histogram);
}
log::info!("resetting abstractions {} {}", self.street, i);
for (_, centroid) in self.kmeans.0.iter_mut() {
centroid.rotate();
centroid.reset();
}
}
//
progress.finish();
}

/// the first Centroid is uniformly random across all `Observation` `Histogram`s
Expand Down Expand Up @@ -249,15 +273,15 @@ impl Layer {
/// hyperparameter: how many centroids to learn
fn k(&self) -> usize {
match self.street {
Street::Turn => 128,
Street::Flop => 128,
_ => unreachable!("how did you get here"),
Street::Turn => 64,
Street::Flop => 64,
_ => unreachable!("no other abstractable streets"),
}
}
/// hyperparameter: how many iterations to run kmeans
fn t(&self) -> usize {
match self.street {
_ => 100,
_ => 10,
}
}
/// length of current kmeans centroids
Expand All @@ -267,9 +291,8 @@ impl Layer {

/// save the current layer's `Metric` and `Abstractor` to disk
fn save(self) -> Self {
let path = format!("{}.abstraction.pgcopy", self.street);
self.metric.save(path.clone());
self.lookup.save(path.clone());
self.metric.save(format!("{}", self.street));
self.lookup.save(format!("{}", self.street));
self
}
}

0 comments on commit 020d83c

Please sign in to comment.