From 00444bab2602ca487be08d5e2eaa6179833333b8 Mon Sep 17 00:00:00 2001 From: Urgau Date: Sat, 26 Oct 2024 18:44:05 +0200 Subject: [PATCH 1/2] Round negative signed integer towards zero in `iN::midpoint` Instead of towards negative infinity as is currently the case. This is done so that the obvious expectations of `midpoint(a, b) == midpoint(b, a)` and `midpoint(-a, -b) == -midpoint(a, b)` are true, which makes the even more obvious implementation `(a + b) / 2` true. https://github.com/rust-lang/rust/issues/110840#issuecomment-2336753931 --- library/core/src/num/int_macros.rs | 38 ---------------- library/core/src/num/mod.rs | 65 ++++++++++++++++++++++++++++ library/core/tests/num/int_macros.rs | 4 +- 3 files changed, 67 insertions(+), 40 deletions(-) diff --git a/library/core/src/num/int_macros.rs b/library/core/src/num/int_macros.rs index 1d640ea74c4a8..01ecaf2710ff6 100644 --- a/library/core/src/num/int_macros.rs +++ b/library/core/src/num/int_macros.rs @@ -3181,44 +3181,6 @@ macro_rules! int_impl { } } - /// Calculates the middle point of `self` and `rhs`. - /// - /// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a - /// sufficiently-large signed integral type. This implies that the result is - /// always rounded towards negative infinity and that no overflow will ever occur. 
- /// - /// # Examples - /// - /// ``` - /// #![feature(num_midpoint)] - #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")] - #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")] - #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")] - /// ``` - #[unstable(feature = "num_midpoint", issue = "110840")] - #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")] - #[rustc_allow_const_fn_unstable(const_num_midpoint)] - #[must_use = "this returns the result of the operation, \ - without modifying the original"] - #[inline] - pub const fn midpoint(self, rhs: Self) -> Self { - const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs(); - - // Map an $SelfT to an $UnsignedT - // ex: i8 [-128; 127] to [0; 255] - const fn map(a: $SelfT) -> $UnsignedT { - (a as $UnsignedT) ^ U - } - - // Map an $UnsignedT to an $SelfT - // ex: u8 [0; 255] to [-128; 127] - const fn demap(a: $UnsignedT) -> $SelfT { - (a ^ U) as $SelfT - } - - demap(<$UnsignedT>::midpoint(map(self), map(rhs))) - } - /// Returns the logarithm of the number with respect to an arbitrary base, /// rounded down. /// diff --git a/library/core/src/num/mod.rs b/library/core/src/num/mod.rs index f95cfd33ae5d2..9a5e211dd6087 100644 --- a/library/core/src/num/mod.rs +++ b/library/core/src/num/mod.rs @@ -124,6 +124,37 @@ macro_rules! midpoint_impl { ((self ^ rhs) >> 1) + (self & rhs) } }; + ($SelfT:ty, signed) => { + /// Calculates the middle point of `self` and `rhs`. + /// + /// `midpoint(a, b)` is `(a + b) / 2` as if it were performed in a + /// sufficiently-large signed integral type. This implies that the result is + /// always rounded towards zero and that no overflow will ever occur. 
+ /// + /// # Examples + /// + /// ``` + /// #![feature(num_midpoint)] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")] + #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(2), 0);")] + #[doc = concat!("assert_eq!((-7", stringify!($SelfT), ").midpoint(0), -3);")] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-7), -3);")] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(7), 3);")] + /// ``` + #[unstable(feature = "num_midpoint", issue = "110840")] + #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")] + #[must_use = "this returns the result of the operation, \ + without modifying the original"] + #[inline] + pub const fn midpoint(self, rhs: Self) -> Self { + // Use the well known branchless algorithm from Hacker's Delight to compute + // `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`. + let t = ((self ^ rhs) >> 1) + (self & rhs); + // Except that it fails for integers whose sum is an odd negative number as + // their floor is one less than their average. So we adjust the result. + t + (if t < 0 { 1 } else { 0 } & (self ^ rhs)) + } + }; ($SelfT:ty, $WideT:ty, unsigned) => { /// Calculates the middle point of `self` and `rhs`. /// @@ -147,6 +178,32 @@ macro_rules! midpoint_impl { ((self as $WideT + rhs as $WideT) / 2) as $SelfT } }; + ($SelfT:ty, $WideT:ty, signed) => { + /// Calculates the middle point of `self` and `rhs`. + /// + /// `midpoint(a, b)` is `(a + b) / 2` as if it were performed in a + /// sufficiently-large signed integral type. This implies that the result is + /// always rounded towards zero and that no overflow will ever occur. 
+ /// + /// # Examples + /// + /// ``` + /// #![feature(num_midpoint)] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")] + #[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(2), 0);")] + #[doc = concat!("assert_eq!((-7", stringify!($SelfT), ").midpoint(0), -3);")] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-7), -3);")] + #[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(7), 3);")] + /// ``` + #[unstable(feature = "num_midpoint", issue = "110840")] + #[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")] + #[must_use = "this returns the result of the operation, \ + without modifying the original"] + #[inline] + pub const fn midpoint(self, rhs: $SelfT) -> $SelfT { + ((self as $WideT + rhs as $WideT) / 2) as $SelfT + } + }; } macro_rules! widening_impl { @@ -300,6 +357,7 @@ impl i8 { from_xe_bytes_doc = "", bound_condition = "", } + midpoint_impl! { i8, i16, signed } } impl i16 { @@ -323,6 +381,7 @@ impl i16 { from_xe_bytes_doc = "", bound_condition = "", } + midpoint_impl! { i16, i32, signed } } impl i32 { @@ -346,6 +405,7 @@ impl i32 { from_xe_bytes_doc = "", bound_condition = "", } + midpoint_impl! { i32, i64, signed } } impl i64 { @@ -369,6 +429,7 @@ impl i64 { from_xe_bytes_doc = "", bound_condition = "", } + midpoint_impl! { i64, i128, signed } } impl i128 { @@ -394,6 +455,7 @@ impl i128 { from_xe_bytes_doc = "", bound_condition = "", } + midpoint_impl! { i128, signed } } #[cfg(target_pointer_width = "16")] @@ -418,6 +480,7 @@ impl isize { from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(), bound_condition = " on 16-bit targets", } + midpoint_impl! { isize, i32, signed } } #[cfg(target_pointer_width = "32")] @@ -442,6 +505,7 @@ impl isize { from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(), bound_condition = " on 32-bit targets", } + midpoint_impl! 
{ isize, i64, signed } } #[cfg(target_pointer_width = "64")] @@ -466,6 +530,7 @@ impl isize { from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(), bound_condition = " on 64-bit targets", } + midpoint_impl! { isize, i128, signed } } /// If the 6th bit is set ascii is lower case. diff --git a/library/core/tests/num/int_macros.rs b/library/core/tests/num/int_macros.rs index 1608080d6b605..474d57049ab65 100644 --- a/library/core/tests/num/int_macros.rs +++ b/library/core/tests/num/int_macros.rs @@ -369,8 +369,8 @@ macro_rules! int_module { assert_eq_const_safe!(<$T>::midpoint(3, 4), 3); assert_eq_const_safe!(<$T>::midpoint(4, 3), 3); - assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1); - assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1); + assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), 0); + assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), 0); assert_eq_const_safe!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN); assert_eq_const_safe!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX); From 74b9de4af2c1200a82bfa9193423cc7889ddc924 Mon Sep 17 00:00:00 2001 From: Urgau Date: Sat, 26 Oct 2024 22:08:34 +0200 Subject: [PATCH 2/2] Add test for all midpoint expectations --- library/core/tests/num/midpoint.rs | 54 ++++++++++++++++++++++++++++++ library/core/tests/num/mod.rs | 1 + 2 files changed, 55 insertions(+) create mode 100644 library/core/tests/num/midpoint.rs diff --git a/library/core/tests/num/midpoint.rs b/library/core/tests/num/midpoint.rs new file mode 100644 index 0000000000000..71e980067842a --- /dev/null +++ b/library/core/tests/num/midpoint.rs @@ -0,0 +1,54 @@ +//! Test the following expectations: +//! - midpoint(a, b) == (a + b) / 2 +//! - midpoint(a, b) == midpoint(b, a) +//! 
- midpoint(-a, -b) == -midpoint(a, b) + +#[test] +#[cfg(not(miri))] +fn midpoint_obvious_impl_i8() { + for a in i8::MIN..=i8::MAX { + for b in i8::MIN..=i8::MAX { + assert_eq!(i8::midpoint(a, b), ((a as i16 + b as i16) / 2) as i8); + } + } +} + +#[test] +#[cfg(not(miri))] +fn midpoint_obvious_impl_u8() { + for a in u8::MIN..=u8::MAX { + for b in u8::MIN..=u8::MAX { + assert_eq!(u8::midpoint(a, b), ((a as u16 + b as u16) / 2) as u8); + } + } +} + +#[test] +#[cfg(not(miri))] +fn midpoint_order_expectation_i8() { + for a in i8::MIN..=i8::MAX { + for b in i8::MIN..=i8::MAX { + assert_eq!(i8::midpoint(a, b), i8::midpoint(b, a)); + } + } +} + +#[test] +#[cfg(not(miri))] +fn midpoint_order_expectation_u8() { + for a in u8::MIN..=u8::MAX { + for b in u8::MIN..=u8::MAX { + assert_eq!(u8::midpoint(a, b), u8::midpoint(b, a)); + } + } +} + +#[test] +#[cfg(not(miri))] +fn midpoint_negative_expectation() { + for a in 0..=i8::MAX { + for b in 0..=i8::MAX { + assert_eq!(i8::midpoint(-a, -b), -i8::midpoint(a, b)); + } + } +} diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs index 6da9b9a13293a..0add9a01e682d 100644 --- a/library/core/tests/num/mod.rs +++ b/library/core/tests/num/mod.rs @@ -28,6 +28,7 @@ mod dec2flt; mod flt2dec; mod int_log; mod int_sqrt; +mod midpoint; mod ops; mod wrapping;