Skip to content

Commit

Permalink
Improve float_samplers implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Wright committed Nov 24, 2022
1 parent 3fbd250 commit 6097b11
Showing 1 changed file with 166 additions and 79 deletions.
245 changes: 166 additions & 79 deletions proptest/src/num/float_samplers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Alternative uniform float samplers because the ones provided by the rand crate are prone
//! to overflow. The samplers work by uniformly selecting from a set of equally spaced values in
//! the interval and the included bounds. Selection is slightly biased towards the bounds.
//! Alternative uniform float samplers.
//! These samplers are used over the ones from `rand` because the ones provided by the
//! rand crate are prone to overflow. In addition, these are 'high precision' samplers
//! that are more appropriate for test data.
//! The samplers work by splitting the range into equally sized intervals and selecting
//! an iterval at random. That interval is then itself split and a new interval is
//! selected at random. The process repeats until the interval only contains two
//! floating point values at the bounds. At that stage, one is selected at random and
//! returned.
pub(crate) use self::f32::F32U;
pub(crate) use self::f64::F64U;
Expand All @@ -19,7 +25,7 @@ macro_rules! float_sampler {
pub mod $typ {
use rand::prelude::*;
use rand::distributions::uniform::{
SampleBorrow, SampleUniform, Uniform, UniformSampler,
SampleBorrow, SampleUniform, UniformSampler,
};

#[must_use]
Expand Down Expand Up @@ -61,8 +67,10 @@ macro_rules! float_sampler {

#[derive(Clone, Copy, Debug)]
pub(crate) struct FloatUniform {
uniform: Uniform<$int_typ>,
values: SampleValueCollection,
low: $typ,
high: $typ,
intervals: IntervalCollection,
inclusive: bool,
}

impl UniformSampler for FloatUniform {
Expand All @@ -76,12 +84,11 @@ macro_rules! float_sampler {
{
let low = low.borrow().0;
let high = high.borrow().0;

let values = SampleValueCollection::new_inclusive(low, next_down(high));

FloatUniform {
uniform: Uniform::new(0, values.count),
values,
low,
high,
intervals: split_interval([low, high]),
inclusive: false,
}
}

Expand All @@ -93,51 +100,71 @@ macro_rules! float_sampler {
let low = low.borrow().0;
let high = high.borrow().0;

let values = SampleValueCollection::new_inclusive(low, high);

FloatUniform {
uniform: Uniform::new(0, values.count),
values,
low,
high,
intervals: split_interval([low, high]),
inclusive: true,
}
}

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
$wrapper(self.values.get(self.uniform.sample(rng)))
let mut intervals = self.intervals;
while intervals.count > 1 {
let new_interval = intervals.get(rng.gen_range(0..intervals.count));
intervals = split_interval(new_interval);
}
let last = intervals.get(0);
let result = *last.choose(rng).expect("Slice is not empty");

// These results could happen because the first split might
// overshoot one of the bounds. We could resample in this
// case but for testing data this is not a problem.
let clamped_result = if result < self.low {
debug_assert!(self.low - result < self.intervals.step);
self.low
} else if result > self.high{
debug_assert!(result - self.high < self.intervals.step);
self.high
} else {
result
};

if !self.inclusive && clamped_result == self.high {
return $wrapper(next_down(self.high));
};

$wrapper(clamped_result)
}
}

impl SampleUniform for $wrapper {
type Sampler = FloatUniform;
}

// Divides the range [low, high] into intervals of size epsilon * max(abs(low, high));
// Note that the one interval may extend out of the range.
#[derive(Clone, Copy, Debug)]
struct SampleValueCollection {
struct IntervalCollection {
start: $typ,
end: $typ,
step: $typ,
count: $int_typ,
}

// Values greater than MAX_PRECISE_INT may be rounded when converted to float.
const MAX_PRECISE_INT: $int_typ =
(2 as $int_typ).pow($typ::MANTISSA_DIGITS);

// The collection of sample values that may be generated by UniformF32U.
impl SampleValueCollection {
fn new_inclusive(low: $typ, high: $typ) -> Self {
fn split_interval([low, high]: [$typ; 2]) -> IntervalCollection {
assert!(low.is_finite(), "low finite");
assert!(high.is_finite(), "high finite");
assert!(high - low >= 0., "invalid range");
assert!(high - low > 0., "invalid range");

let min_abs = $typ::min(low.abs(), high.abs());
let max_abs = $typ::max(low.abs(), high.abs());

let gap = ulp(max_abs);

let (start, end, step) = if low.abs() < high.abs() {
(high, low, -gap)
let (start, step) = if low.abs() < high.abs() {
(high, -gap)
} else {
(low, high, gap)
(low, gap)
};

let min_gaps = min_abs / gap;
Expand All @@ -150,37 +177,54 @@ macro_rules! float_sampler {
let count = if low.signum() == high.signum() {
max_gaps as $int_typ - min_gaps.floor() as $int_typ
} else {
max_gaps as $int_typ + min_gaps.ceil() as $int_typ
} + 1;
// `step` is a power of two so `min_gaps` won't be rounded
// except possibly to 0.
if min_gaps == 0. && min_abs > 0. {
max_gaps as $int_typ + 1
} else {
max_gaps as $int_typ + min_gaps.ceil() as $int_typ
}
};

debug_assert!(count - 1 <= 2 * MAX_PRECISE_INT);

Self {
IntervalCollection {
start,
end,
step,
count,
}
}
}

fn get(&self, index: $int_typ) -> $typ {
assert!(index < self.count, "index out of bounds");

if index == self.count - 1 {
return self.end;
}
impl IntervalCollection {
fn get(&self, index: $int_typ) -> [$typ; 2] {
assert!(index < self.count, "index out of bounds");

// `index` might be greater that `MAX_PERCISE_INT` which means
// `index as $typ` could round to a different integer and
// `index as $typ + self.start` would have a rounding error.
// Fortunately, `index` will never be larger than `2 * MAX_PRECISE_INT`
// (as asserted above) so the expression below will be free of rounding.
((index / 2) as $typ).mul_add(
// `index` might be greater that `MAX_PERCISE_INT`
// which means `MAX_PRECIST_INT as $typ` would round
// to a different number. Fortunately, `index` will
// never be larger than `2 * MAX_PRECISE_INT` (as
// asserted above).
let x = ((index / 2) as $typ).mul_add(
2. * self.step,
(index % 2) as $typ * self.step + self.start,
)
);

let y = x + self.step;

if self.step > 0. {
[x, y]
} else {
[y, x]
}
}
}


// Values greater than MAX_PRECISE_INT may be rounded when converted to float.
const MAX_PRECISE_INT: $int_typ =
(2 as $int_typ).pow($typ::MANTISSA_DIGITS);

#[cfg(test)]
mod test {

Expand All @@ -204,7 +248,9 @@ macro_rules! float_sampler {
}

fn bounds() -> impl Strategy<Value = ($typ, $typ)> {
(finite(), finite()).prop_map(sort)
(finite(), finite())
.prop_filter("Bounds can't be equal", |(a, b)| a != b)
.prop_map(sort)
}

#[test]
Expand Down Expand Up @@ -291,13 +337,6 @@ macro_rules! float_sampler {
assert!(samples.any(|x| x != 0.));
}

#[test]
// We treat [-0., 0.] as [0., 0.] since the distance between -0. and 0. is 0.
fn zero_sample_values() {
let values = SampleValueCollection::new_inclusive(-0., 0.);
assert_eq!((values.count, values.get(0)), (1, 0.));
}

#[test]
fn max_precise_int_plus_one_is_rounded_down() {
assert_eq!(((MAX_PRECISE_INT + 1) as $typ) as $int_typ, MAX_PRECISE_INT);
Expand All @@ -324,45 +363,93 @@ macro_rules! float_sampler {
}

#[test]
fn single_value_interval(value: $typ) {
let values = SampleValueCollection::new_inclusive(value, value);
prop_assert_eq!((values.count, values.get(0)), (1, value));
fn indivisible_intervals_are_split_to_self(val in finite()) {
prop_assume!(val > $typ::MIN);
let prev = next_down(val);
let intervals = split_interval([prev, val]);
prop_assert_eq!(intervals.count, 1);
}

#[test]
fn incl_low_and_high_are_start_and_end((low, high) in bounds()) {
let values = SampleValueCollection::new_inclusive(low, high);

let count = values.count;

let bounds = (values.get(0), values.get(count - 1));
prop_assert_eq!(sort(bounds), (low, high));
}
fn split_intervals_are_the_same_size(
(low, high) in bounds(),
indices: [prop::sample::Index; 32]) {

#[test]
fn values_excluding_end_are_equally_spaced(
(low, high) in bounds(), indices: [prop::sample::Index; 32]) {
let values = SampleValueCollection::new_inclusive(low, high);
let intervals = split_interval([low, high]);

let size = (values.count - 1) as usize;
let size = (intervals.count - 1) as usize;
prop_assume!(size > 0);

let all_equal = indices.iter()
let mut it = indices.iter()
.map(|i| i.index(size) as $int_typ)
.map(|i| values.get(i + 1) - values.get(i))
.all(|g| g == values.step);
.map(|i| intervals.get(i))
.map(|[low, high]| high - low);

let interval_size = it.next().unwrap();
let all_equal = it.all(|g| g == interval_size);
prop_assert!(all_equal);
}

#[test]
fn end_gap_smaller_but_positive((low, high) in bounds()) {
let values = SampleValueCollection::new_inclusive(low, high);
fn split_intervals_are_consecutive(
(low, high) in bounds(),
indices: [prop::sample::Index; 32]) {

let intervals = split_interval([low, high]);

let size = (intervals.count - 1) as usize;
prop_assume!(size > 1);

let mut it = indices.iter()
.map(|i| i.index(size - 1) as $int_typ)
.map(|i| (intervals.get(i), intervals.get(i + 1)));

let n = values.count;
prop_assume!(n > 1);
let ascending = it.all(|([_, h1], [l2, _])| h1 == l2);
let descending = it.all(|([l1, _], [_, h2])| l1 == h2);

let gap = (values.get(n - 1) - values.get(n - 2)).abs();
prop_assert!(0. < gap && gap <= values.step.abs());
prop_assert!(ascending || descending);
}

#[test]
fn first_split_might_slightly_overshoot_one_bound((low, high) in bounds()) {
let intervals = split_interval([low, high]);
let start = intervals.get(0);
let end = intervals.get(intervals.count - 1);
let (low_interval, high_interval) = if start[0] < end[0] {
(start, end)
} else {
(end, start)
};

prop_assert!(
low == low_interval[0] && high_interval[0] < high && high <= high_interval[1] ||
low_interval[0] <= low && low < low_interval[1] && high == high_interval[1]);
}

#[test]
fn subsequent_splits_always_match_bounds(
(low, high) in bounds(),
index: prop::sample::Index) {
// This property is true because the distances of split intervals of
// are powers of two so the smaller one always divides the larger.

let intervals = split_interval([low, high]);
let size = (intervals.count - 1) as usize;

let interval = intervals.get(index.index(size) as $int_typ);
let small_intervals = split_interval(interval);

let start = small_intervals.get(0);
let end = small_intervals.get(small_intervals.count - 1);
let (low_interval, high_interval) = if start[0] < end[0] {
(start, end)
} else {
(end, start)
};

prop_assert!(
interval[0] == low_interval[0] &&
interval[1] == high_interval[1]);
}
}
}
Expand Down

0 comments on commit 6097b11

Please sign in to comment.