Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "integer square root" method to integer primitive types #116176

Merged
merged 10 commits into from
Sep 29, 2023
1 change: 1 addition & 0 deletions library/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
#![feature(ip)]
#![feature(ip_bits)]
#![feature(is_ascii_octdigit)]
#![feature(isqrt)]
#![feature(maybe_uninit_uninit_array)]
#![feature(ptr_alignment_type)]
#![feature(ptr_metadata)]
Expand Down
54 changes: 54 additions & 0 deletions library/core/src/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,30 @@ macro_rules! int_impl {
acc.checked_mul(base)
}

/// Returns the square root of the number, rounded down.
///
/// Returns `None` if `self` is negative.
///
/// # Examples
///
/// Basic usage:
/// ```
/// #![feature(isqrt)]
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".checked_isqrt(), Some(3));")]
/// ```
#[unstable(feature = "isqrt", issue = "116226")]
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn checked_isqrt(self) -> Option<Self> {
if self < 0 {
None
} else {
Some((self as $UnsignedT).isqrt() as Self)
}
}

/// Saturating integer addition. Computes `self + rhs`, saturating at the numeric
/// bounds instead of overflowing.
///
Expand Down Expand Up @@ -2061,6 +2085,36 @@ macro_rules! int_impl {
acc * base
}

/// Returns the square root of the number, rounded down.
///
/// # Panics
///
/// This function will panic if `self` is negative.
///
/// # Examples
///
/// Basic usage:
/// ```
/// #![feature(isqrt)]
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
/// ```
#[unstable(feature = "isqrt", issue = "116226")]
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
// I would like to implement it as
// ```
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
// ```
// but `expect` is not yet stable as a `const fn`.
match self.checked_isqrt() {
Some(sqrt) => sqrt,
None => panic!("argument of integer square root must be non-negative"),
}
}

/// Calculates the quotient of Euclidean division of `self` by `rhs`.
///
/// This computes the integer `q` such that `self = q * rhs + r`, with
Expand Down
48 changes: 48 additions & 0 deletions library/core/src/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1979,6 +1979,54 @@ macro_rules! uint_impl {
acc * base
}

/// Returns the square root of the number, rounded down.
///
/// # Examples
///
/// Basic usage:
/// ```
/// #![feature(isqrt)]
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
/// ```
#[unstable(feature = "isqrt", issue = "116226")]
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
dtolnay marked this conversation as resolved.
Show resolved Hide resolved
if self < 2 {
FedericoStra marked this conversation as resolved.
Show resolved Hide resolved
return self;
}

// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = self;
let mut res = 0;
let mut one = 1 << (self.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
}

// SAFETY: the result is positive and fits in an integer with half as many bits.
// Inform the optimizer about it.
unsafe {
intrinsics::assume(0 < res);
intrinsics::assume(res < 1 << (Self::BITS / 2));
}

res
}

/// Performs Euclidean division.
///
/// Since, for the positive integers, all common
Expand Down
1 change: 1 addition & 0 deletions library/core/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#![feature(min_specialization)]
#![feature(numfmt)]
#![feature(num_midpoint)]
#![feature(isqrt)]
#![feature(step_trait)]
#![feature(str_internals)]
#![feature(std_internals)]
Expand Down
32 changes: 32 additions & 0 deletions library/core/tests/num/int_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,38 @@ macro_rules! int_module {
assert_eq!(r.saturating_pow(0), 1 as $T);
}

#[test]
fn test_isqrt() {
assert_eq!($T::MIN.checked_isqrt(), None);
assert_eq!((-1 as $T).checked_isqrt(), None);
assert_eq!((0 as $T).isqrt(), 0 as $T);
assert_eq!((1 as $T).isqrt(), 1 as $T);
assert_eq!((2 as $T).isqrt(), 1 as $T);
assert_eq!((99 as $T).isqrt(), 9 as $T);
assert_eq!((100 as $T).isqrt(), 10 as $T);
}

#[cfg(not(miri))] // Miri is too slow
#[test]
fn test_lots_of_isqrt() {
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
for n in 0..=n_max {
let isqrt: $T = n.isqrt();

assert!(isqrt.pow(2) <= n);
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
assert!(overflow || square > n);
}

for n in ($T::MAX - 127)..=$T::MAX {
let isqrt: $T = n.isqrt();

assert!(isqrt.pow(2) <= n);
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
assert!(overflow || square > n);
}
}

#[test]
fn test_div_floor() {
let a: $T = 8;
Expand Down
29 changes: 29 additions & 0 deletions library/core/tests/num/uint_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ macro_rules! uint_module {
assert_eq!(r.saturating_pow(2), MAX);
}

#[test]
fn test_isqrt() {
assert_eq!((0 as $T).isqrt(), 0 as $T);
assert_eq!((1 as $T).isqrt(), 1 as $T);
assert_eq!((2 as $T).isqrt(), 1 as $T);
assert_eq!((99 as $T).isqrt(), 9 as $T);
assert_eq!((100 as $T).isqrt(), 10 as $T);
assert_eq!($T::MAX.isqrt(), (1 << ($T::BITS / 2)) - 1);
}

#[cfg(not(miri))] // Miri is too slow
#[test]
fn test_lots_of_isqrt() {
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
for n in 0..=n_max {
let isqrt: $T = n.isqrt();

assert!(isqrt.pow(2) <= n);
assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
}

for n in ($T::MAX - 255)..=$T::MAX {
let isqrt: $T = n.isqrt();

assert!(isqrt.pow(2) <= n);
assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
}
}

#[test]
fn test_div_floor() {
assert_eq!((8 as $T).div_floor(3), 2);
Expand Down