Skip to content

Commit

Permalink
Auto merge of #68358 - matthewjasper:spec-fix, r=nikomatsakis
Browse files Browse the repository at this point in the history
Remove some unsound specializations

This removes the unsound and exploitable specializations in the standard library

* The `PartialEq` and `Hash` implementations for  `RangeInclusive` are changed to avoid specialization.
* The `PartialOrd` specialization for slices now specializes on a limited set of concrete types.
* Added some tests for the soundness problems.
  • Loading branch information
bors committed Feb 8, 2020
2 parents 8498c5f + a81c59f commit 6cad754
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 57 deletions.
4 changes: 4 additions & 0 deletions src/libcore/iter/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,14 @@ impl<A: Step> Iterator for ops::RangeInclusive<A> {
}
Some(Equal) => {
self.is_empty = Some(true);
self.start = plus_n.clone();
return Some(plus_n);
}
_ => {}
}
}

self.start = self.end.clone();
self.is_empty = Some(true);
None
}
Expand Down Expand Up @@ -477,12 +479,14 @@ impl<A: Step> DoubleEndedIterator for ops::RangeInclusive<A> {
}
Some(Equal) => {
self.is_empty = Some(true);
self.end = minus_n.clone();
return Some(minus_n);
}
_ => {}
}
}

self.end = self.start.clone();
self.is_empty = Some(true);
None
}
Expand Down
38 changes: 15 additions & 23 deletions src/libcore/ops/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,38 +343,21 @@ pub struct RangeInclusive<Idx> {
pub(crate) is_empty: Option<bool>,
// This field is:
// - `None` when next() or next_back() was never called
// - `Some(false)` when `start <= end` assuming no overflow
// - `Some(true)` otherwise
// - `Some(false)` when `start < end`
// - `Some(true)` when `end < start`
// - `Some(false)` when `start == end` and the range hasn't yet completed iteration
// - `Some(true)` when `start == end` and the range has completed iteration
// The field cannot be a simple `bool` because the `..=` constructor can
// accept non-PartialOrd types, also we want the constructor to be const.
}

trait RangeInclusiveEquality: Sized {
fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool;
}

impl<T> RangeInclusiveEquality for T {
#[inline]
default fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool {
range.is_empty.unwrap_or_default()
}
}

impl<T: PartialOrd> RangeInclusiveEquality for T {
#[inline]
fn canonicalized_is_empty(range: &RangeInclusive<Self>) -> bool {
range.is_empty()
}
}

#[stable(feature = "inclusive_range", since = "1.26.0")]
impl<Idx: PartialEq> PartialEq for RangeInclusive<Idx> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.start == other.start
&& self.end == other.end
&& RangeInclusiveEquality::canonicalized_is_empty(self)
== RangeInclusiveEquality::canonicalized_is_empty(other)
&& self.is_exhausted() == other.is_exhausted()
}
}

Expand All @@ -386,7 +369,8 @@ impl<Idx: Hash> Hash for RangeInclusive<Idx> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.start.hash(state);
self.end.hash(state);
RangeInclusiveEquality::canonicalized_is_empty(self).hash(state);
// Ideally we would hash `is_exhausted` here as well, but there's no
// way for us to call it.
}
}

Expand Down Expand Up @@ -485,6 +469,14 @@ impl<Idx: fmt::Debug> fmt::Debug for RangeInclusive<Idx> {
}
}

impl<Idx: PartialEq<Idx>> RangeInclusive<Idx> {
// Returns true if this is a range that started non-empty, and was iterated
// to exhaustion.
fn is_exhausted(&self) -> bool {
Some(true) == self.is_empty && self.start == self.end
}
}

impl<Idx: PartialOrd<Idx>> RangeInclusive<Idx> {
/// Returns `true` if `item` is contained in the range.
///
Expand Down
80 changes: 51 additions & 29 deletions src/libcore/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5584,21 +5584,18 @@ where

#[doc(hidden)]
// intermediate trait for specialization of slice's PartialOrd
trait SlicePartialOrd<B> {
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
trait SlicePartialOrd: Sized {
fn partial_compare(left: &[Self], right: &[Self]) -> Option<Ordering>;
}

impl<A> SlicePartialOrd<A> for [A]
where
A: PartialOrd,
{
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
let l = cmp::min(self.len(), other.len());
impl<A: PartialOrd> SlicePartialOrd for A {
default fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
let l = cmp::min(left.len(), right.len());

// Slice to the loop iteration range to enable bound check
// elimination in the compiler
let lhs = &self[..l];
let rhs = &other[..l];
let lhs = &left[..l];
let rhs = &right[..l];

for i in 0..l {
match lhs[i].partial_cmp(&rhs[i]) {
Expand All @@ -5607,36 +5604,61 @@ where
}
}

self.len().partial_cmp(&other.len())
left.len().partial_cmp(&right.len())
}
}

impl<A> SlicePartialOrd<A> for [A]
// This is the impl that we would like to have. Unfortunately it's not sound.
// See `partial_ord_slice.rs`.
/*
impl<A> SlicePartialOrd for A
where
A: Ord,
{
default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
Some(SliceOrd::compare(self, other))
default fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
Some(SliceOrd::compare(left, right))
}
}
*/

impl<A: AlwaysApplicableOrd> SlicePartialOrd for A {
fn partial_compare(left: &[A], right: &[A]) -> Option<Ordering> {
Some(SliceOrd::compare(left, right))
}
}

trait AlwaysApplicableOrd: SliceOrd + Ord {}

macro_rules! always_applicable_ord {
($([$($p:tt)*] $t:ty,)*) => {
$(impl<$($p)*> AlwaysApplicableOrd for $t {})*
}
}

always_applicable_ord! {
[] u8, [] u16, [] u32, [] u64, [] u128, [] usize,
[] i8, [] i16, [] i32, [] i64, [] i128, [] isize,
[] bool, [] char,
[T: ?Sized] *const T, [T: ?Sized] *mut T,
[T: AlwaysApplicableOrd] &T,
[T: AlwaysApplicableOrd] &mut T,
[T: AlwaysApplicableOrd] Option<T>,
}

#[doc(hidden)]
// intermediate trait for specialization of slice's Ord
trait SliceOrd<B> {
fn compare(&self, other: &[B]) -> Ordering;
trait SliceOrd: Sized {
fn compare(left: &[Self], right: &[Self]) -> Ordering;
}

impl<A> SliceOrd<A> for [A]
where
A: Ord,
{
default fn compare(&self, other: &[A]) -> Ordering {
let l = cmp::min(self.len(), other.len());
impl<A: Ord> SliceOrd for A {
default fn compare(left: &[Self], right: &[Self]) -> Ordering {
let l = cmp::min(left.len(), right.len());

// Slice to the loop iteration range to enable bound check
// elimination in the compiler
let lhs = &self[..l];
let rhs = &other[..l];
let lhs = &left[..l];
let rhs = &right[..l];

for i in 0..l {
match lhs[i].cmp(&rhs[i]) {
Expand All @@ -5645,19 +5667,19 @@ where
}
}

self.len().cmp(&other.len())
left.len().cmp(&right.len())
}
}

// memcmp compares a sequence of unsigned bytes lexicographically.
// this matches the order we want for [u8], but no others (not even [i8]).
impl SliceOrd<u8> for [u8] {
impl SliceOrd for u8 {
#[inline]
fn compare(&self, other: &[u8]) -> Ordering {
fn compare(left: &[Self], right: &[Self]) -> Ordering {
let order =
unsafe { memcmp(self.as_ptr(), other.as_ptr(), cmp::min(self.len(), other.len())) };
unsafe { memcmp(left.as_ptr(), right.as_ptr(), cmp::min(left.len(), right.len())) };
if order == 0 {
self.len().cmp(&other.len())
left.len().cmp(&right.len())
} else if order < 0 {
Less
} else {
Expand Down
10 changes: 5 additions & 5 deletions src/libcore/str/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use self::pattern::{DoubleEndedSearcher, ReverseSearcher, SearchStep, Searcher};
use crate::char;
use crate::fmt::{self, Write};
use crate::iter::{Chain, FlatMap, Flatten};
use crate::iter::{Cloned, Filter, FusedIterator, Map, TrustedLen, TrustedRandomAccess};
use crate::iter::{Copied, Filter, FusedIterator, Map, TrustedLen, TrustedRandomAccess};
use crate::mem;
use crate::ops::Try;
use crate::option;
Expand Down Expand Up @@ -750,7 +750,7 @@ impl<'a> CharIndices<'a> {
/// [`str`]: ../../std/primitive.str.html
#[stable(feature = "rust1", since = "1.0.0")]
#[derive(Clone, Debug)]
pub struct Bytes<'a>(Cloned<slice::Iter<'a, u8>>);
pub struct Bytes<'a>(Copied<slice::Iter<'a, u8>>);

#[stable(feature = "rust1", since = "1.0.0")]
impl Iterator for Bytes<'_> {
Expand Down Expand Up @@ -2778,7 +2778,7 @@ impl str {
#[stable(feature = "rust1", since = "1.0.0")]
#[inline]
pub fn bytes(&self) -> Bytes<'_> {
Bytes(self.as_bytes().iter().cloned())
Bytes(self.as_bytes().iter().copied())
}

/// Splits a string slice by whitespace.
Expand Down Expand Up @@ -3895,7 +3895,7 @@ impl str {
debug_assert_eq!(
start, 0,
"The first search step from Searcher \
must include the first character"
must include the first character"
);
// SAFETY: `Searcher` is known to return valid indices.
unsafe { Some(self.get_unchecked(len..)) }
Expand Down Expand Up @@ -3934,7 +3934,7 @@ impl str {
end,
self.len(),
"The first search step from ReverseSearcher \
must include the last character"
must include the last character"
);
// SAFETY: `Searcher` is known to return valid indices.
unsafe { Some(self.get_unchecked(..start)) }
Expand Down
16 changes: 16 additions & 0 deletions src/libcore/tests/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1956,11 +1956,19 @@ fn test_range_inclusive_exhaustion() {
assert_eq!(r.next(), None);
assert_eq!(r.next(), None);

assert_eq!(*r.start(), 10);
assert_eq!(*r.end(), 10);
assert_ne!(r, 10..=10);

let mut r = 10..=10;
assert_eq!(r.next_back(), Some(10));
assert!(r.is_empty());
assert_eq!(r.next_back(), None);

assert_eq!(*r.start(), 10);
assert_eq!(*r.end(), 10);
assert_ne!(r, 10..=10);

let mut r = 10..=12;
assert_eq!(r.next(), Some(10));
assert_eq!(r.next(), Some(11));
Expand Down Expand Up @@ -2078,6 +2086,9 @@ fn test_range_inclusive_nth() {
assert_eq!((10..=15).nth(5), Some(15));
assert_eq!((10..=15).nth(6), None);

let mut exhausted_via_next = 10_u8..=20;
while exhausted_via_next.next().is_some() {}

let mut r = 10_u8..=20;
assert_eq!(r.nth(2), Some(12));
assert_eq!(r, 13..=20);
Expand All @@ -2087,6 +2098,7 @@ fn test_range_inclusive_nth() {
assert_eq!(ExactSizeIterator::is_empty(&r), false);
assert_eq!(r.nth(10), None);
assert_eq!(r.is_empty(), true);
assert_eq!(r, exhausted_via_next);
assert_eq!(ExactSizeIterator::is_empty(&r), true);
}

Expand All @@ -2098,6 +2110,9 @@ fn test_range_inclusive_nth_back() {
assert_eq!((10..=15).nth_back(6), None);
assert_eq!((-120..=80_i8).nth_back(200), Some(-120));

let mut exhausted_via_next_back = 10_u8..=20;
while exhausted_via_next_back.next_back().is_some() {}

let mut r = 10_u8..=20;
assert_eq!(r.nth_back(2), Some(18));
assert_eq!(r, 10..=17);
Expand All @@ -2107,6 +2122,7 @@ fn test_range_inclusive_nth_back() {
assert_eq!(ExactSizeIterator::is_empty(&r), false);
assert_eq!(r.nth_back(10), None);
assert_eq!(r.is_empty(), true);
assert_eq!(r, exhausted_via_next_back);
assert_eq!(ExactSizeIterator::is_empty(&r), true);
}

Expand Down
35 changes: 35 additions & 0 deletions src/test/ui/specialization/soundness/partial_eq_range_inclusive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// run-pass

use std::cell::RefCell;
use std::cmp::Ordering;

struct Evil<'a, 'b> {
values: RefCell<Vec<&'a str>>,
to_insert: &'b String,
}

impl<'a, 'b> PartialEq for Evil<'a, 'b> {
fn eq(&self, _other: &Self) -> bool {
true
}
}

impl<'a> PartialOrd for Evil<'a, 'a> {
fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
self.values.borrow_mut().push(self.to_insert);
None
}
}

fn main() {
let e;
let values;
{
let to_insert = String::from("Hello, world!");
e = Evil { values: RefCell::new(Vec::new()), to_insert: &to_insert };
let range = &e..=&e;
let _ = range == range;
values = e.values;
}
assert_eq!(*values.borrow(), Vec::<&str>::new());
}
Loading

0 comments on commit 6cad754

Please sign in to comment.