Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#[cold] on match arms #120193

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4220,6 +4220,7 @@ dependencies = [
"rustc_target",
"rustc_trait_selection",
"smallvec",
"thin-vec",
"tracing",
]

Expand Down
29 changes: 25 additions & 4 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,30 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
cond: &'ll Value,
then_llbb: &'ll BasicBlock,
else_llbb: &'ll BasicBlock,
cold_br: Option<bool>,
) {
unsafe {
llvm::LLVMBuildCondBr(self.llbuilder, cond, then_llbb, else_llbb);
// emit the branch instruction
let n = unsafe { llvm::LLVMBuildCondBr(self.llbuilder, cond, then_llbb, else_llbb) };

// if one of the branches is cold, emit metadata with branch weights
if let Some(cold_br) = cold_br {
unsafe {
let s = "branch_weights";
let v = [
llvm::LLVMMDStringInContext(
self.cx.llcx,
s.as_ptr() as *const c_char,
s.len() as c_uint,
),
self.cx.const_u32(if cold_br { 1 } else { 2000 }), // 'then' branch weight
self.cx.const_u32(if cold_br { 2000 } else { 1 }), // 'else' branch weight
];
llvm::LLVMSetMetadata(
n,
llvm::MD_prof as c_uint,
llvm::LLVMMDNodeInContext(self.cx.llcx, v.as_ptr(), v.len() as c_uint),
);
}
}
}

Expand Down Expand Up @@ -605,7 +626,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
let i = header_bx.phi(self.val_ty(zero), &[zero], &[self.llbb()]);

let keep_going = header_bx.icmp(IntPredicate::IntULT, i, count);
header_bx.cond_br(keep_going, body_bb, next_bb);
header_bx.cond_br(keep_going, body_bb, next_bb, None);

let mut body_bx = Self::build(self.cx, body_bb);
let dest_elem = dest.project_index(&mut body_bx, i);
Expand Down Expand Up @@ -1556,7 +1577,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
let cond = self.type_test(llfn, typeid_metadata);
let bb_pass = self.append_sibling_block("type_test.pass");
let bb_fail = self.append_sibling_block("type_test.fail");
self.cond_br(cond, bb_pass, bb_fail);
self.cond_br(cond, bb_pass, bb_fail, None);

self.switch_to_block(bb_fail);
self.abort();
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_llvm/src/va_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ fn emit_aapcs_va_arg<'ll, 'tcx>(
// if the offset >= 0 then the value will be on the stack
let mut reg_off_v = bx.load(bx.type_i32(), reg_off, offset_align);
let use_stack = bx.icmp(IntPredicate::IntSGE, reg_off_v, zero);
bx.cond_br(use_stack, on_stack, maybe_reg);
bx.cond_br(use_stack, on_stack, maybe_reg, None);

// The value at this point might be in a register, but there is a chance that
// it could be on the stack so we have to update the offset and then check
Expand All @@ -137,7 +137,7 @@ fn emit_aapcs_va_arg<'ll, 'tcx>(
// Check to see if we have overflowed the registers as a result of this.
// If we have then we need to use the stack for this value
let use_stack = bx.icmp(IntPredicate::IntSGT, new_reg_off_v, zero);
bx.cond_br(use_stack, on_stack, in_reg);
bx.cond_br(use_stack, on_stack, in_reg, None);

bx.switch_to_block(in_reg);
let top_type = bx.type_ptr();
Expand Down Expand Up @@ -203,7 +203,7 @@ fn emit_s390x_va_arg<'ll, 'tcx>(
);
let reg_count_v = bx.load(bx.type_i64(), reg_count, Align::from_bytes(8).unwrap());
let use_regs = bx.icmp(IntPredicate::IntULT, reg_count_v, bx.const_u64(max_regs));
bx.cond_br(use_regs, in_reg, in_mem);
bx.cond_br(use_regs, in_reg, in_mem, None);

// Emit code to load the value if it was passed in a register.
bx.switch_to_block(in_reg);
Expand Down
18 changes: 12 additions & 6 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,18 +327,24 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let (test_value, target) = target_iter.next().unwrap();
let lltrue = helper.llbb_with_cleanup(self, target);
let llfalse = helper.llbb_with_cleanup(self, targets.otherwise());
let cold_br =
targets.cold_target().and_then(|t| if t == 0 { Some(true) } else { Some(false) });

if switch_ty == bx.tcx().types.bool {
// Don't generate trivial icmps when switching on bool.
match test_value {
0 => bx.cond_br(discr.immediate(), llfalse, lltrue),
1 => bx.cond_br(discr.immediate(), lltrue, llfalse),
0 => {
let cold_br = cold_br.and_then(|t| Some(!t));
bx.cond_br(discr.immediate(), llfalse, lltrue, cold_br);
}
1 => bx.cond_br(discr.immediate(), lltrue, llfalse, cold_br),
_ => bug!(),
}
} else {
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
let llval = bx.const_uint_big(switch_llty, test_value);
let cmp = bx.icmp(IntPredicate::IntEQ, discr.immediate(), llval);
bx.cond_br(cmp, lltrue, llfalse);
bx.cond_br(cmp, lltrue, llfalse, cold_br);
}
} else if self.cx.sess().opts.optimize == OptLevel::No
&& target_iter.len() == 2
Expand All @@ -363,7 +369,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let switch_llty = bx.immediate_backend_type(bx.layout_of(switch_ty));
let llval = bx.const_uint_big(switch_llty, test_value1);
let cmp = bx.icmp(IntPredicate::IntEQ, discr.immediate(), llval);
bx.cond_br(cmp, ll1, ll2);
bx.cond_br(cmp, ll1, ll2, None);
} else {
bx.switch(
discr.immediate(),
Expand Down Expand Up @@ -595,9 +601,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let lltarget = helper.llbb_with_cleanup(self, target);
let panic_block = bx.append_sibling_block("panic");
if expected {
bx.cond_br(cond, lltarget, panic_block);
bx.cond_br(cond, lltarget, panic_block, None);
} else {
bx.cond_br(cond, panic_block, lltarget);
bx.cond_br(cond, panic_block, lltarget, None);
}

// After this point, bx is the block for the call to panic.
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub trait BuilderMethods<'a, 'tcx>:
cond: Self::Value,
then_llbb: Self::BasicBlock,
else_llbb: Self::BasicBlock,
cold_br: Option<bool>,
);
fn switch(
&mut self,
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_data_structures/src/stable_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::fmt;
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
use std::mem;
use thin_vec::ThinVec;

#[cfg(test)]
mod tests;
Expand Down Expand Up @@ -499,6 +500,16 @@ where
}
}

impl<T, CTX> HashStable<CTX> for ThinVec<T>
where
T: HashStable<CTX>,
{
#[inline]
fn hash_stable(&self, ctx: &mut CTX, hasher: &mut StableHasher) {
self[..].hash_stable(ctx, hasher);
}
}

impl<T: ?Sized + HashStable<CTX>, CTX> HashStable<CTX> for Box<T> {
#[inline]
fn hash_stable(&self, ctx: &mut CTX, hasher: &mut StableHasher) {
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use rustc_span::symbol::Symbol;
use rustc_span::Span;
use rustc_target::asm::InlineAsmRegOrRegClass;
use smallvec::SmallVec;
use thin_vec::ThinVec;

/// Represents the "flavors" of MIR.
///
Expand Down Expand Up @@ -844,6 +845,12 @@ pub struct SwitchTargets {
// However we’ve decided to keep this as-is until we figure a case
// where some other approach seems to be strictly better than other.
pub(super) targets: SmallVec<[BasicBlock; 2]>,

// Targets that are marked 'cold', if any.
// This vector contains indices into `targets`.
// It can also contain 'targets.len()' to indicate that the otherwise
// branch is cold.
pub(super) cold_targets: ThinVec<usize>,
}

/// Action to be taken when a stack unwind happens.
Expand Down
43 changes: 38 additions & 5 deletions compiler/rustc_middle/src/mir/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use smallvec::SmallVec;
use super::TerminatorKind;
use rustc_macros::HashStable;
use std::slice;
use thin_vec::{thin_vec, ThinVec};

use super::*;

Expand All @@ -16,13 +17,26 @@ impl SwitchTargets {
pub fn new(targets: impl Iterator<Item = (u128, BasicBlock)>, otherwise: BasicBlock) -> Self {
let (values, mut targets): (SmallVec<_>, SmallVec<_>) = targets.unzip();
targets.push(otherwise);
Self { values, targets }
Self { values, targets, cold_targets: ThinVec::new() }
}

/// Builds a switch targets definition that jumps to `then` if the tested value equals `value`,
/// and to `else_` if not.
pub fn static_if(value: u128, then: BasicBlock, else_: BasicBlock) -> Self {
Self { values: smallvec![value], targets: smallvec![then, else_] }
/// If cold_br is some bool value, the given outcome is considered cold (i.e., unlikely).
pub fn static_if(
value: u128,
then: BasicBlock,
else_: BasicBlock,
cold_br: Option<bool>,
) -> Self {
Self {
values: smallvec![value],
targets: smallvec![then, else_],
cold_targets: match cold_br {
Some(br) => thin_vec![if br { 0 } else { 1 }],
None => ThinVec::new(),
},
}
}

/// Inverse of `SwitchTargets::static_if`.
Expand All @@ -37,6 +51,11 @@ impl SwitchTargets {
}
}

// If this switch has exactly one cold target, returns it.
pub fn cold_target(&self) -> Option<usize> {
if self.cold_targets.len() == 1 { Some(self.cold_targets[0]) } else { None }
}

/// Returns the fallback target that is jumped to when none of the values match the operand.
#[inline]
pub fn otherwise(&self) -> BasicBlock {
Expand Down Expand Up @@ -361,8 +380,22 @@ impl<'tcx> Terminator<'tcx> {

impl<'tcx> TerminatorKind<'tcx> {
#[inline]
pub fn if_(cond: Operand<'tcx>, t: BasicBlock, f: BasicBlock) -> TerminatorKind<'tcx> {
TerminatorKind::SwitchInt { discr: cond, targets: SwitchTargets::static_if(0, f, t) }
pub fn if_(
cond: Operand<'tcx>,
t: BasicBlock,
f: BasicBlock,
cold_branch: Option<bool>,
) -> TerminatorKind<'tcx> {
TerminatorKind::SwitchInt {
discr: cond,
targets: SwitchTargets::static_if(
0,
f,
t,
// we compare to zero, so have to invert the branch
cold_branch.and_then(|b| Some(!b)),
),
}
}

#[inline]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ pub struct Arm<'tcx> {
pub lint_level: LintLevel,
pub scope: region::Scope,
pub span: Span,
pub is_cold: bool,
}

#[derive(Copy, Clone, Debug, HashStable)]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ rustc_span = { path = "../rustc_span" }
rustc_target = { path = "../rustc_target" }
rustc_trait_selection = { path = "../rustc_trait_selection" }
smallvec = { version = "1.8.1", features = ["union", "may_dangle"] }
thin-vec = "0.2.12"
tracing = "0.1"
# tidy-alphabetical-end
2 changes: 2 additions & 0 deletions compiler/rustc_mir_build/src/build/expr/into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
condition_scope,
source_info,
true,
None,
));

this.expr_into_dest(destination, then_blk, then)
Expand Down Expand Up @@ -176,6 +177,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
condition_scope,
source_info,
true,
None,
)
});
let (short_circuit, continuation, constant) = match op {
Expand Down
Loading