Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

valtree performance tuning #136593

Merged
merged 2 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions compiler/rustc_const_eval/src/const_eval/valtrees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use rustc_abi::{BackendRepr, VariantIdx};
use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_middle::mir::interpret::{EvalToValTreeResult, GlobalId, ReportedErrorInfo};
use rustc_middle::ty::layout::{LayoutCx, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::{bug, mir};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};
Expand All @@ -21,38 +21,36 @@ use crate::interpret::{
fn branches<'tcx>(
ecx: &CompileTimeInterpCx<'tcx>,
place: &MPlaceTy<'tcx>,
n: usize,
field_count: usize,
variant: Option<VariantIdx>,
num_nodes: &mut usize,
) -> ValTreeCreationResult<'tcx> {
let place = match variant {
Some(variant) => ecx.project_downcast(place, variant).unwrap(),
None => place.clone(),
};
let variant = variant.map(|variant| Some(ty::ValTree::Leaf(ScalarInt::from(variant.as_u32()))));
debug!(?place, ?variant);
debug!(?place);

let mut fields = Vec::with_capacity(n);
for i in 0..n {
let field = ecx.project_field(&place, i).unwrap();
let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
fields.push(Some(valtree));
}
let mut branches = Vec::with_capacity(field_count + variant.is_some() as usize);

// For enums, we prepend their variant index before the variant's fields so we can figure out
// the variant again when just seeing a valtree.
let branches = variant
.into_iter()
.chain(fields.into_iter())
.collect::<Option<Vec<_>>>()
.expect("should have already checked for errors in ValTree creation");
if let Some(variant) = variant {
branches.push(ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()));
}

for i in 0..field_count {
let field = ecx.project_field(&place, i).unwrap();
let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
branches.push(valtree);
}

// Have to account for ZSTs here
if branches.len() == 0 {
*num_nodes += 1;
}

Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(branches)))
Ok(ty::ValTree::from_branches(*ecx.tcx, branches))
}

#[instrument(skip(ecx), level = "debug")]
Expand All @@ -70,7 +68,7 @@ fn slice_branches<'tcx>(
elems.push(valtree);
}

Ok(ty::ValTree::Branch(ecx.tcx.arena.alloc_from_iter(elems)))
Ok(ty::ValTree::from_branches(*ecx.tcx, elems))
}

#[instrument(skip(ecx), level = "debug")]
Expand All @@ -79,6 +77,7 @@ fn const_to_valtree_inner<'tcx>(
place: &MPlaceTy<'tcx>,
num_nodes: &mut usize,
) -> ValTreeCreationResult<'tcx> {
let tcx = *ecx.tcx;
let ty = place.layout.ty;
debug!("ty kind: {:?}", ty.kind());

Expand All @@ -89,14 +88,14 @@ fn const_to_valtree_inner<'tcx>(
match ty.kind() {
ty::FnDef(..) => {
*num_nodes += 1;
Ok(ty::ValTree::zst())
Ok(ty::ValTree::zst(tcx))
}
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
let val = ecx.read_immediate(place).unwrap();
let val = val.to_scalar_int().unwrap();
*num_nodes += 1;

Ok(ty::ValTree::Leaf(val))
Ok(ty::ValTree::from_scalar_int(tcx, val))
}

ty::Pat(base, ..) => {
Expand Down Expand Up @@ -127,7 +126,7 @@ fn const_to_valtree_inner<'tcx>(
return Err(ValTreeCreationError::NonSupportedType(ty));
};
// It's just a ScalarInt!
Ok(ty::ValTree::Leaf(val))
Ok(ty::ValTree::from_scalar_int(tcx, val))
}

// Technically we could allow function pointers (represented as `ty::Instance`), but this is not guaranteed to
Expand Down Expand Up @@ -287,16 +286,11 @@ pub fn valtree_to_const_value<'tcx>(
// FIXME: Does this need an example?
match *cv.ty.kind() {
ty::FnDef(..) => {
assert!(cv.valtree.unwrap_branch().is_empty());
assert!(cv.valtree.is_zst());
mir::ConstValue::ZeroSized
}
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => {
match cv.valtree {
ty::ValTree::Leaf(scalar_int) => mir::ConstValue::Scalar(Scalar::Int(scalar_int)),
ty::ValTree::Branch(_) => bug!(
"ValTrees for Bool, Int, Uint, Float, Char or RawPtr should have the form ValTree::Leaf"
),
}
mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf()))
}
ty::Pat(ty, _) => {
let cv = ty::Value { valtree: cv.valtree, ty };
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2161,7 +2161,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
did,
path.segments.last().unwrap(),
);
ty::Const::new_value(tcx, ty::ValTree::zst(), Ty::new_fn_def(tcx, did, args))
ty::Const::zero_sized(tcx, Ty::new_fn_def(tcx, did, args))
}

// Exhaustive match to be clear about what exactly we're considering to be
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ macro_rules! arena_types {
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
[] pats: rustc_middle::ty::PatternKind<'tcx>,
[] valtree: rustc_middle::ty::ValTreeKind<'tcx>,

// Note that this deliberately duplicates items in the `rustc_hir::arena`,
// since we need to allocate this type on both the `rustc_hir` arena
Expand Down
15 changes: 9 additions & 6 deletions compiler/rustc_middle/src/ty/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::Pattern<'tcx> {
}
}

impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ty::ValTree<'tcx> {
fn encode(&self, e: &mut E) {
self.0.0.encode(e);
}
}

impl<'tcx, E: TyEncoder<I = TyCtxt<'tcx>>> Encodable<E> for ConstAllocation<'tcx> {
fn encode(&self, e: &mut E) {
self.inner().encode(e)
Expand Down Expand Up @@ -355,12 +361,9 @@ impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::Pattern<'tcx> {
}
}

impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for [ty::ValTree<'tcx>] {
fn decode(decoder: &mut D) -> &'tcx Self {
decoder
.interner()
.arena
.alloc_from_iter((0..decoder.read_usize()).map(|_| Decodable::decode(decoder)))
impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> Decodable<D> for ty::ValTree<'tcx> {
fn decode(decoder: &mut D) -> Self {
decoder.interner().intern_valtree(Decodable::decode(decoder))
}
}

Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
pub type UnevaluatedConst<'tcx> = ir::UnevaluatedConst<TyCtxt<'tcx>>;

#[cfg(target_pointer_width = "64")]
rustc_data_structures::static_assert_size!(ConstKind<'_>, 32);
rustc_data_structures::static_assert_size!(ConstKind<'_>, 24);

#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)]
#[rustc_pass_by_value]
Expand Down Expand Up @@ -190,15 +190,15 @@ impl<'tcx> Const<'tcx> {
.size;
ty::Const::new_value(
tcx,
ty::ValTree::from_scalar_int(ScalarInt::try_from_uint(bits, size).unwrap()),
ty::ValTree::from_scalar_int(tcx, ScalarInt::try_from_uint(bits, size).unwrap()),
ty,
)
}

#[inline]
/// Creates an interned zst constant.
pub fn zero_sized(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self {
ty::Const::new_value(tcx, ty::ValTree::zst(), ty)
ty::Const::new_value(tcx, ty::ValTree::zst(tcx), ty)
}

#[inline]
Expand Down
98 changes: 71 additions & 27 deletions compiler/rustc_middle/src/ty/consts/valtree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
use std::fmt;
use std::ops::Deref;

use rustc_data_structures::intern::Interned;
use rustc_macros::{HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};

use super::ScalarInt;
use crate::mir::interpret::Scalar;
Expand All @@ -16,9 +20,9 @@ use crate::ty::{self, Ty, TyCtxt};
///
/// `ValTree` does not have this problem with representation, as it only contains integers or
/// lists of (nested) `ValTree`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[derive(HashStable, TyEncodable, TyDecodable)]
pub enum ValTree<'tcx> {
pub enum ValTreeKind<'tcx> {
/// integers, `bool`, `char` are represented as scalars.
/// See the `ScalarInt` documentation for how `ScalarInt` guarantees that equal values
/// of these types have the same representation.
Expand All @@ -33,58 +37,98 @@ pub enum ValTree<'tcx> {
/// the fields of the variant.
///
/// ZST types are represented as an empty slice.
Branch(&'tcx [ValTree<'tcx>]),
Branch(Box<[ValTree<'tcx>]>),
}

impl<'tcx> ValTree<'tcx> {
pub fn zst() -> Self {
Self::Branch(&[])
}

impl<'tcx> ValTreeKind<'tcx> {
#[inline]
pub fn unwrap_leaf(self) -> ScalarInt {
pub fn unwrap_leaf(&self) -> ScalarInt {
match self {
Self::Leaf(s) => s,
Self::Leaf(s) => *s,
_ => bug!("expected leaf, got {:?}", self),
}
}

#[inline]
pub fn unwrap_branch(self) -> &'tcx [Self] {
pub fn unwrap_branch(&self) -> &[ValTree<'tcx>] {
match self {
Self::Branch(branch) => branch,
Self::Branch(branch) => &**branch,
_ => bug!("expected branch, got {:?}", self),
}
}

pub fn from_raw_bytes<'a>(tcx: TyCtxt<'tcx>, bytes: &'a [u8]) -> Self {
let branches = bytes.iter().map(|b| Self::Leaf(ScalarInt::from(*b)));
let interned = tcx.arena.alloc_from_iter(branches);
pub fn try_to_scalar(&self) -> Option<Scalar> {
self.try_to_scalar_int().map(Scalar::Int)
}

Self::Branch(interned)
pub fn try_to_scalar_int(&self) -> Option<ScalarInt> {
match self {
Self::Leaf(s) => Some(*s),
Self::Branch(_) => None,
}
}

pub fn from_scalar_int(i: ScalarInt) -> Self {
Self::Leaf(i)
pub fn try_to_branch(&self) -> Option<&[ValTree<'tcx>]> {
match self {
Self::Branch(branch) => Some(&**branch),
Self::Leaf(_) => None,
}
}
}

pub fn try_to_scalar(self) -> Option<Scalar> {
self.try_to_scalar_int().map(Scalar::Int)
/// An interned valtree. Use this rather than `ValTreeKind`, whenever possible.
///
/// See the docs of [`ValTreeKind`] or the [dev guide] for an explanation of this type.
///
/// [dev guide]: https://rustc-dev-guide.rust-lang.org/mir/index.html#valtrees
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
#[derive(HashStable)]
pub struct ValTree<'tcx>(pub(crate) Interned<'tcx, ValTreeKind<'tcx>>);

impl<'tcx> ValTree<'tcx> {
/// Returns the zero-sized valtree: `Branch([])`.
pub fn zst(tcx: TyCtxt<'tcx>) -> Self {
tcx.consts.valtree_zst
}

pub fn try_to_scalar_int(self) -> Option<ScalarInt> {
match self {
Self::Leaf(s) => Some(s),
Self::Branch(_) => None,
}
pub fn is_zst(self) -> bool {
matches!(*self, ValTreeKind::Branch(box []))
}

pub fn from_raw_bytes(tcx: TyCtxt<'tcx>, bytes: &[u8]) -> Self {
let branches = bytes.iter().map(|&b| Self::from_scalar_int(tcx, b.into()));
Self::from_branches(tcx, branches)
}

pub fn from_branches(tcx: TyCtxt<'tcx>, branches: impl IntoIterator<Item = Self>) -> Self {
tcx.intern_valtree(ValTreeKind::Branch(branches.into_iter().collect()))
}

pub fn from_scalar_int(tcx: TyCtxt<'tcx>, i: ScalarInt) -> Self {
tcx.intern_valtree(ValTreeKind::Leaf(i))
}
}

impl<'tcx> Deref for ValTree<'tcx> {
type Target = &'tcx ValTreeKind<'tcx>;

#[inline]
fn deref(&self) -> &&'tcx ValTreeKind<'tcx> {
&self.0.0
}
}

impl fmt::Debug for ValTree<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(**self).fmt(f)
}
}

/// A type-level constant value.
///
/// Represents a typed, fully evaluated constant.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable, Lift)]
pub struct Value<'tcx> {
pub ty: Ty<'tcx>,
pub valtree: ValTree<'tcx>,
Expand Down
Loading
Loading