From 6097b11e8bcf4b3ebb16e856a3e2d163d716d649 Mon Sep 17 00:00:00 2001 From: Michael Wright Date: Thu, 24 Nov 2022 07:51:19 +0200 Subject: [PATCH] Improve float_samplers implementation --- proptest/src/num/float_samplers.rs | 245 +++++++++++++++++++---------- 1 file changed, 166 insertions(+), 79 deletions(-) diff --git a/proptest/src/num/float_samplers.rs b/proptest/src/num/float_samplers.rs index 39d59403..3b6180c7 100644 --- a/proptest/src/num/float_samplers.rs +++ b/proptest/src/num/float_samplers.rs @@ -7,9 +7,15 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Alternative uniform float samplers because the ones provided by the rand crate are prone -//! to overflow. The samplers work by uniformly selecting from a set of equally spaced values in -//! the interval and the included bounds. Selection is slightly biased towards the bounds. +//! Alternative uniform float samplers. +//! These samplers are used over the ones from `rand` because the ones provided by the +//! rand crate are prone to overflow. In addition, these are 'high precision' samplers +//! that are more appropriate for test data. +//! The samplers work by splitting the range into equally sized intervals and selecting +//! an iterval at random. That interval is then itself split and a new interval is +//! selected at random. The process repeats until the interval only contains two +//! floating point values at the bounds. At that stage, one is selected at random and +//! returned. pub(crate) use self::f32::F32U; pub(crate) use self::f64::F64U; @@ -19,7 +25,7 @@ macro_rules! float_sampler { pub mod $typ { use rand::prelude::*; use rand::distributions::uniform::{ - SampleBorrow, SampleUniform, Uniform, UniformSampler, + SampleBorrow, SampleUniform, UniformSampler, }; #[must_use] @@ -61,8 +67,10 @@ macro_rules! float_sampler { #[derive(Clone, Copy, Debug)] pub(crate) struct FloatUniform { - uniform: Uniform<$int_typ>, - values: SampleValueCollection, + low: $typ, + high: $typ, + intervals: IntervalCollection, + inclusive: bool, } impl UniformSampler for FloatUniform { @@ -76,12 +84,11 @@ macro_rules! float_sampler { { let low = low.borrow().0; let high = high.borrow().0; - - let values = SampleValueCollection::new_inclusive(low, next_down(high)); - FloatUniform { - uniform: Uniform::new(0, values.count), - values, + low, + high, + intervals: split_interval([low, high]), + inclusive: false, } } @@ -93,16 +100,41 @@ macro_rules! float_sampler { let low = low.borrow().0; let high = high.borrow().0; - let values = SampleValueCollection::new_inclusive(low, high); - FloatUniform { - uniform: Uniform::new(0, values.count), - values, + low, + high, + intervals: split_interval([low, high]), + inclusive: true, } } fn sample(&self, rng: &mut R) -> Self::X { - $wrapper(self.values.get(self.uniform.sample(rng))) + let mut intervals = self.intervals; + while intervals.count > 1 { + let new_interval = intervals.get(rng.gen_range(0..intervals.count)); + intervals = split_interval(new_interval); + } + let last = intervals.get(0); + let result = *last.choose(rng).expect("Slice is not empty"); + + // These results could happen because the first split might + // overshoot one of the bounds. We could resample in this + // case but for testing data this is not a problem. + let clamped_result = if result < self.low { + debug_assert!(self.low - result < self.intervals.step); + self.low + } else if result > self.high{ + debug_assert!(result - self.high < self.intervals.step); + self.high + } else { + result + }; + + if !self.inclusive && clamped_result == self.high { + return $wrapper(next_down(self.high)); + }; + + $wrapper(clamped_result) } } @@ -110,34 +142,29 @@ macro_rules! float_sampler { type Sampler = FloatUniform; } + // Divides the range [low, high] into intervals of size epsilon * max(abs(low, high)); + // Note that the one interval may extend out of the range. #[derive(Clone, Copy, Debug)] - struct SampleValueCollection { + struct IntervalCollection { start: $typ, - end: $typ, step: $typ, count: $int_typ, } - // Values greater than MAX_PRECISE_INT may be rounded when converted to float. - const MAX_PRECISE_INT: $int_typ = - (2 as $int_typ).pow($typ::MANTISSA_DIGITS); - - // The collection of sample values that may be generated by UniformF32U. - impl SampleValueCollection { - fn new_inclusive(low: $typ, high: $typ) -> Self { + fn split_interval([low, high]: [$typ; 2]) -> IntervalCollection { assert!(low.is_finite(), "low finite"); assert!(high.is_finite(), "high finite"); - assert!(high - low >= 0., "invalid range"); + assert!(high - low > 0., "invalid range"); let min_abs = $typ::min(low.abs(), high.abs()); let max_abs = $typ::max(low.abs(), high.abs()); let gap = ulp(max_abs); - let (start, end, step) = if low.abs() < high.abs() { - (high, low, -gap) + let (start, step) = if low.abs() < high.abs() { + (high, -gap) } else { - (low, high, gap) + (low, gap) }; let min_gaps = min_abs / gap; @@ -150,37 +177,54 @@ macro_rules! float_sampler { let count = if low.signum() == high.signum() { max_gaps as $int_typ - min_gaps.floor() as $int_typ } else { - max_gaps as $int_typ + min_gaps.ceil() as $int_typ - } + 1; + // `step` is a power of two so `min_gaps` won't be rounded + // except possibly to 0. + if min_gaps == 0. && min_abs > 0. { + max_gaps as $int_typ + 1 + } else { + max_gaps as $int_typ + min_gaps.ceil() as $int_typ + } + }; + debug_assert!(count - 1 <= 2 * MAX_PRECISE_INT); - Self { + IntervalCollection { start, - end, step, count, } - } + } - fn get(&self, index: $int_typ) -> $typ { - assert!(index < self.count, "index out of bounds"); - if index == self.count - 1 { - return self.end; - } + impl IntervalCollection { + fn get(&self, index: $int_typ) -> [$typ; 2] { + assert!(index < self.count, "index out of bounds"); - // `index` might be greater that `MAX_PERCISE_INT` which means - // `index as $typ` could round to a different integer and - // `index as $typ + self.start` would have a rounding error. - // Fortunately, `index` will never be larger than `2 * MAX_PRECISE_INT` - // (as asserted above) so the expression below will be free of rounding. - ((index / 2) as $typ).mul_add( + // `index` might be greater that `MAX_PERCISE_INT` + // which means `MAX_PRECIST_INT as $typ` would round + // to a different number. Fortunately, `index` will + // never be larger than `2 * MAX_PRECISE_INT` (as + // asserted above). + let x = ((index / 2) as $typ).mul_add( 2. * self.step, (index % 2) as $typ * self.step + self.start, - ) + ); + + let y = x + self.step; + + if self.step > 0. { + [x, y] + } else { + [y, x] + } } } + + // Values greater than MAX_PRECISE_INT may be rounded when converted to float. + const MAX_PRECISE_INT: $int_typ = + (2 as $int_typ).pow($typ::MANTISSA_DIGITS); + #[cfg(test)] mod test { @@ -204,7 +248,9 @@ macro_rules! float_sampler { } fn bounds() -> impl Strategy { - (finite(), finite()).prop_map(sort) + (finite(), finite()) + .prop_filter("Bounds can't be equal", |(a, b)| a != b) + .prop_map(sort) } #[test] @@ -291,13 +337,6 @@ macro_rules! float_sampler { assert!(samples.any(|x| x != 0.)); } - #[test] - // We treat [-0., 0.] as [0., 0.] since the distance between -0. and 0. is 0. - fn zero_sample_values() { - let values = SampleValueCollection::new_inclusive(-0., 0.); - assert_eq!((values.count, values.get(0)), (1, 0.)); - } - #[test] fn max_precise_int_plus_one_is_rounded_down() { assert_eq!(((MAX_PRECISE_INT + 1) as $typ) as $int_typ, MAX_PRECISE_INT); @@ -324,45 +363,93 @@ macro_rules! float_sampler { } #[test] - fn single_value_interval(value: $typ) { - let values = SampleValueCollection::new_inclusive(value, value); - prop_assert_eq!((values.count, values.get(0)), (1, value)); + fn indivisible_intervals_are_split_to_self(val in finite()) { + prop_assume!(val > $typ::MIN); + let prev = next_down(val); + let intervals = split_interval([prev, val]); + prop_assert_eq!(intervals.count, 1); } #[test] - fn incl_low_and_high_are_start_and_end((low, high) in bounds()) { - let values = SampleValueCollection::new_inclusive(low, high); - - let count = values.count; - - let bounds = (values.get(0), values.get(count - 1)); - prop_assert_eq!(sort(bounds), (low, high)); - } + fn split_intervals_are_the_same_size( + (low, high) in bounds(), + indices: [prop::sample::Index; 32]) { - #[test] - fn values_excluding_end_are_equally_spaced( - (low, high) in bounds(), indices: [prop::sample::Index; 32]) { - let values = SampleValueCollection::new_inclusive(low, high); + let intervals = split_interval([low, high]); - let size = (values.count - 1) as usize; + let size = (intervals.count - 1) as usize; prop_assume!(size > 0); - let all_equal = indices.iter() + let mut it = indices.iter() .map(|i| i.index(size) as $int_typ) - .map(|i| values.get(i + 1) - values.get(i)) - .all(|g| g == values.step); + .map(|i| intervals.get(i)) + .map(|[low, high]| high - low); + + let interval_size = it.next().unwrap(); + let all_equal = it.all(|g| g == interval_size); prop_assert!(all_equal); } #[test] - fn end_gap_smaller_but_positive((low, high) in bounds()) { - let values = SampleValueCollection::new_inclusive(low, high); + fn split_intervals_are_consecutive( + (low, high) in bounds(), + indices: [prop::sample::Index; 32]) { + + let intervals = split_interval([low, high]); + + let size = (intervals.count - 1) as usize; + prop_assume!(size > 1); + + let mut it = indices.iter() + .map(|i| i.index(size - 1) as $int_typ) + .map(|i| (intervals.get(i), intervals.get(i + 1))); - let n = values.count; - prop_assume!(n > 1); + let ascending = it.all(|([_, h1], [l2, _])| h1 == l2); + let descending = it.all(|([l1, _], [_, h2])| l1 == h2); - let gap = (values.get(n - 1) - values.get(n - 2)).abs(); - prop_assert!(0. < gap && gap <= values.step.abs()); + prop_assert!(ascending || descending); + } + + #[test] + fn first_split_might_slightly_overshoot_one_bound((low, high) in bounds()) { + let intervals = split_interval([low, high]); + let start = intervals.get(0); + let end = intervals.get(intervals.count - 1); + let (low_interval, high_interval) = if start[0] < end[0] { + (start, end) + } else { + (end, start) + }; + + prop_assert!( + low == low_interval[0] && high_interval[0] < high && high <= high_interval[1] || + low_interval[0] <= low && low < low_interval[1] && high == high_interval[1]); + } + + #[test] + fn subsequent_splits_always_match_bounds( + (low, high) in bounds(), + index: prop::sample::Index) { + // This property is true because the distances of split intervals of + // are powers of two so the smaller one always divides the larger. + + let intervals = split_interval([low, high]); + let size = (intervals.count - 1) as usize; + + let interval = intervals.get(index.index(size) as $int_typ); + let small_intervals = split_interval(interval); + + let start = small_intervals.get(0); + let end = small_intervals.get(small_intervals.count - 1); + let (low_interval, high_interval) = if start[0] < end[0] { + (start, end) + } else { + (end, start) + }; + + prop_assert!( + interval[0] == low_interval[0] && + interval[1] == high_interval[1]); } } }