Skip to content

Commit

Permalink
sample-replace empty drifting centroid abstractions
Browse files Browse the repository at this point in the history
  • Loading branch information
krukah committed Oct 16, 2024
1 parent 09b24c2 commit 6f83b8b
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 45 deletions.
11 changes: 8 additions & 3 deletions src/clustering/abstraction.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::cards::hole::Hole;
use crate::Probability;
use std::hash::Hash;
use std::u64;
Expand All @@ -9,8 +10,9 @@ use std::u64;
/// - Other Streets: we use a u64 to represent the hash signature of the centroid Histogram over lower layers of abstraction.
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug, PartialOrd, Ord)]
pub enum Abstraction {
Random(u64),
Equity(i8),
Equity(i8), // river
Random(u64), // flop, turn
Pocket(Hole), // preflop
}

impl Abstraction {
Expand Down Expand Up @@ -54,6 +56,7 @@ impl From<Abstraction> for Probability {
match abstraction {
Abstraction::Equity(n) => Abstraction::floatize(n),
Abstraction::Random(_) => unreachable!("no cluster into probability"),
Abstraction::Pocket(_) => unreachable!("no preflop into probability"),
}
}
}
Expand All @@ -66,6 +69,7 @@ impl From<Abstraction> for u64 {
match a {
Abstraction::Random(n) => n,
Abstraction::Equity(_) => unreachable!("no equity into u64"),
Abstraction::Pocket(_) => unreachable!("no preflop into u64"),
}
}
}
Expand Down Expand Up @@ -93,7 +97,8 @@ impl std::fmt::Display for Abstraction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Random(n) => write!(f, "{:016x}", n),
Self::Equity(_) => unreachable!("don't log me"),
Self::Equity(n) => write!(f, "unreachable ? Equity({})", n),
Self::Pocket(h) => write!(f, "unreachable ? Pocket({})", h),
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/clustering/abstractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ impl Abstractor {
pub fn projection(&self, inner: &Isomorphism) -> Histogram {
let inner = Observation::from(*inner); // isomorphism translation
match inner.street() {
Street::Turn => inner.clone().into(),
Street::Turn => inner.clone().into(), // Histogram::from<Observation>
_ => inner
.children()
.map(|outer| Isomorphism::from(outer)) // isomorphism translation
.map(|ref outer| self.abstraction(outer))
.map(|outer| self.abstraction(&outer))
.collect::<Vec<Abstraction>>()
.into(),
.into(), // Histogram::from<Vec<Abstraction>>
}
}
/// lookup the pre-computed abstraction for the outer observation
Expand Down
18 changes: 13 additions & 5 deletions src/clustering/centroid.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
use crate::clustering::histogram::Histogram;

/// TODO this is now a full shallow wrapper around a Histogram
/// originally I thought we should separate the last and next Histograms
/// but then the mutation loop changed such that it's not necessary
///
/// `Centroid` is a wrapper around two histograms.
/// We use it to swap the current and next histograms
/// after each iteration of kmeans clustering.
pub struct Centroid {
last: Histogram,
next: Histogram,
// next: Histogram,
}

impl Centroid {
pub fn reset(&mut self) {
self.last.destroy();
std::mem::swap(&mut self.last, &mut self.next);
// std::mem::swap(&mut self.last, &mut self.next);
}
pub fn absorb(&mut self, h: &Histogram) {
self.next.absorb(h);
self.last.absorb(h);
// self.next.absorb(h);
}
pub fn reveal(&self) -> &Histogram {
pub fn histogram(&self) -> &Histogram {
&self.last
}
/// true when the wrapped `last` Histogram reports itself as empty.
pub fn is_empty(&self) -> bool {
self.last.is_empty()
}
}

impl From<Histogram> for Centroid {
fn from(h: Histogram) -> Self {
Self {
last: h,
next: Histogram::default(),
// next: Histogram::default(),
}
}
}
15 changes: 15 additions & 0 deletions src/clustering/datasets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,19 @@ impl AbstractionSpace {
.expect("abstraction generated during initialization")
.absorb(histogram);
}

/// list every Abstraction whose Centroid currently
/// reports `is_empty()`, in the map's iteration order.
pub fn orphans(&self) -> Vec<Abstraction> {
    self.0
        .iter()
        .filter_map(|(abstraction, centroid)| centroid.is_empty().then_some(*abstraction))
        .collect()
}

/// call `reset()` on every Centroid held in this space.
pub fn clear(&mut self) {
    self.0
        .iter_mut()
        .for_each(|(_, centroid)| centroid.reset());
}
}
24 changes: 12 additions & 12 deletions src/clustering/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ impl Histogram {
self.weights.keys().collect()
}

/// true when this Histogram has no entries in `weights`.
/// useful only for the k-means edge case of centroid drift,
/// where a centroid may end up with nothing absorbed into it.
pub fn is_empty(&self) -> bool {
self.weights.is_empty()
// NOTE(review): alternative check that was left disabled;
// presumably equivalent if `norm` always tracks the total weight — confirm.
// self.norm == 0
}

/// insert the Abstraction into our support,
/// incrementing its local weight,
/// incrementing our global norm.
Expand Down Expand Up @@ -70,10 +76,7 @@ impl Histogram {
/// Abstraction variants, so we expose this method to
/// infer the type of Abstraction contained by this Histogram.
pub fn peek(&self) -> &Abstraction {
self.weights
.keys()
.next()
.expect("non empty histogram, consistent abstraction variant")
self.weights.keys().next().expect("non empty histogram")
}

/// exhaustive calculation of all
Expand All @@ -83,11 +86,8 @@ impl Histogram {
/// ONLY WORKS FOR STREET::TURN
/// ONLY WORKS FOR STREET::TURN
pub fn equity(&self) -> Equity {
assert!(matches!(
self.weights.keys().next(),
Some(Abstraction::Equity(_))
));
self.posterior().iter().map(|(x, y)| x * y).sum()
assert!(matches!(self.peek(), Abstraction::Equity(_)));
self.distribution().iter().map(|(x, y)| x * y).sum()
}

/// this yields the posterior equity distribution
Expand All @@ -100,7 +100,7 @@ impl Histogram {
///
/// ONLY WORKS FOR STREET::TURN
/// ONLY WORKS FOR STREET::TURN
pub fn posterior(&self) -> Vec<(Equity, Probability)> {
pub fn distribution(&self) -> Vec<(Equity, Probability)> {
assert!(matches!(self.peek(), Abstraction::Equity(_)));
self.weights
.iter()
Expand All @@ -114,7 +114,7 @@ impl From<Observation> for Histogram {
fn from(ref turn: Observation) -> Self {
assert!(turn.street() == crate::cards::street::Street::Turn);
Self::from(
turn.children() //? iso
turn.children()
.map(|river| Abstraction::from(river.equity()))
.collect::<Vec<Abstraction>>(),
)
Expand All @@ -132,7 +132,7 @@ impl std::fmt::Display for Histogram {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// 1. interpret each key of the Histogram as probability
// 2. they should already be sorted bc BTreeMap
let ref distribution = self.posterior();
let ref distribution = self.distribution();
// 3. Create 32 bins for the x-axis
let n_x_bins = 32;
let ref mut bins = vec![0.0; n_x_bins];
Expand Down
61 changes: 39 additions & 22 deletions src/clustering/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ use std::collections::BTreeMap;
/// - CPU: O(N) for kmeans clustering
/// - RAM: O(N^2) for learned metric
/// - RAM: O(N) for learned centroids
const N_KMEANS_CENTROIDS: usize = 64;
const N_KMEANS_CENTROIDS: usize = 8;

/// number of kmeans iterations.
/// this controls the precision of the abstraction space.
///
/// - CPU: O(N) for kmeans clustering
const N_KMEANS_ITERATION: usize = 10;
const N_KMEANS_ITERATION: usize = 8;

/// Hierarchical K Means Learner.
/// this is decomposed into the necessary data structures
Expand Down Expand Up @@ -87,14 +87,18 @@ impl Layer {
/// - `points`: not used for inward projection. only used for clustering. and no clustering on River.
fn outer() -> Self {
Self {
street: Street::Rive,
metric: Metric::default(),
lookup: Abstractor::default(),
kmeans: AbstractionSpace::default(),
points: ObservationSpace::default(),
metric: Metric::default(),
street: Street::Rive,
}
}
/// hierarchically, recursively generate the inner layer
/// 0. initialize empty lookup table and kmeans centroids
/// 1. generate Street, Metric, and Points as a pure function of the outer Layer
/// 2. initialize kmeans centroids with weighted random Observation sampling (kmeans++ for faster convergence)
/// 3. cluster kmeans centroids
fn inner(&self) -> Self {
let mut layer = Self {
lookup: Abstractor::default(), // assigned during clustering
Expand Down Expand Up @@ -134,8 +138,8 @@ impl Layer {
for b in self.kmeans.0.keys() {
if a > b {
let index = Pair::from((a, b));
let x = self.kmeans.0.get(a).expect("pre-computed").reveal();
let y = self.kmeans.0.get(b).expect("pre-computed").reveal();
let x = self.kmeans.0.get(a).expect("pre-computed").histogram();
let y = self.kmeans.0.get(b).expect("pre-computed").histogram();
let distance = self.metric.emd(x, y) + self.metric.emd(y, x);
let distance = distance / 2.0;
metric.insert(index, distance);
Expand Down Expand Up @@ -167,6 +171,7 @@ impl Layer {
.inspect(|_| progress.inc(1))
.collect::<BTreeMap<Isomorphism, Histogram>>();
progress.finish();
log::info!("completed point projections {}", projection.len());
ObservationSpace(projection)
}

Expand All @@ -177,41 +182,53 @@ impl Layer {
fn initial_kmeans(&mut self) {
log::info!("initializing kmeans {}", self.street);
let progress = Self::progress(N_KMEANS_CENTROIDS - 1);
let ref mut rng = rand::rngs::StdRng::seed_from_u64(self.street as u64);
let histogram = self.sample_uniform(rng);
self.kmeans.expand(histogram);
let ref mut rng = rand::rngs::StdRng::seed_from_u64(self.street as u64 + 0xBAD);
let sample = self.sample_uniform(rng);
self.kmeans.expand(sample);
while self.kmeans.0.len() < N_KMEANS_CENTROIDS {
let histogram = self.sample_outlier(rng);
self.kmeans.expand(histogram);
let sample = self.sample_outlier(rng);
self.kmeans.expand(sample);
progress.inc(1);
}
progress.finish();
log::info!("completed kmeans initialization {}", self.kmeans.0.len());
}
/// for however many iterations we want,
/// 1. assign each `Observation` to the nearest `Centroid`
/// 2. update each `Centroid` by averaging the `Observation`s assigned to it
fn cluster_kmeans(&mut self) {
log::info!("clustering kmeans {}", self.street);
let ref mut rng = rand::rngs::StdRng::seed_from_u64(self.street as u64 + 0xADD);
let progress = Self::progress(N_KMEANS_ITERATION * self.points.0.len());
for _ in 0..N_KMEANS_ITERATION {
let abstractions = self
// calculate nearest neighbor Abstractions for each Observation
// each nearest neighbor calculation is O(k^2)
// there are k of them for each N observations
let neighbors = self
.points
.0
.par_iter()
.map(|(_, h)| self.nearest_neighbor(h))
.inspect(|_| progress.inc(1))
.collect::<Vec<Abstraction>>();
for ((observation, histogram), abstraction) in
self.points.0.iter_mut().zip(abstractions.iter())
{
self.lookup.assign(abstraction, observation);
self.kmeans.absorb(abstraction, histogram);
// clear centroids before absorption
// assign new neighbor Abstractions to each Observation
// absorb Histograms into each Centroid
self.kmeans.clear();
for ((o, h), a) in std::iter::zip(self.points.0.iter_mut(), neighbors.iter()) {
self.lookup.assign(a, o);
self.kmeans.absorb(a, h);
}
for (_, centroid) in self.kmeans.0.iter_mut() {
centroid.reset();
// centroid drift may make it such that some centroids are empty
// reinitialize empty centroids with random Observations if necessary
for ref a in self.kmeans.orphans() {
log::info!("reassinging drifting empty centroid {}", a);
let ref sample = self.sample_uniform(rng);
self.kmeans.absorb(a, sample);
}
}
progress.finish();
log::info!("completed kmeans clustering {}", self.kmeans.0.len());
}

/// the first Centroid is uniformly random across all `Observation` `Histogram`s
Expand Down Expand Up @@ -249,7 +266,7 @@ impl Layer {
self.kmeans
.0
.par_iter()
.map(|(_, centroid)| centroid.reveal())
.map(|(_, centroid)| centroid.histogram())
.map(|centroid| self.metric.emd(histogram, centroid))
.map(|min| min * min)
.min_by(|dx, dy| dx.partial_cmp(dy).unwrap())
Expand All @@ -260,7 +277,7 @@ impl Layer {
self.kmeans
.0
.par_iter()
.map(|(abs, centroid)| (abs, centroid.reveal()))
.map(|(abs, centroid)| (abs, centroid.histogram()))
.map(|(abs, centroid)| (abs, self.metric.emd(histogram, centroid)))
.min_by(|(_, dx), (_, dy)| dx.partial_cmp(dy).unwrap())
.expect("find nearest neighbor")
Expand All @@ -277,7 +294,7 @@ impl Layer {

fn progress(n: usize) -> indicatif::ProgressBar {
let tick = std::time::Duration::from_secs(1);
let style = "[{elapsed}] {spinner:.green} {wide_bar:.green} ETA {eta}";
let style = "[{elapsed}] {spinner} {wide_bar} ETA {eta}";
let style = indicatif::ProgressStyle::with_template(style).unwrap();
let progress = indicatif::ProgressBar::new(n as u64);
progress.set_style(style);
Expand Down
1 change: 1 addition & 0 deletions src/clustering/metric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ impl Metric {
match target.peek() {
Abstraction::Equity(_) => Self::difference(source, target),
Abstraction::Random(_) => self.wasserstein(source, target),
Abstraction::Pocket(_) => unreachable!("no preflop emd"),
}
}

Expand Down

0 comments on commit 6f83b8b

Please sign in to comment.