diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index b5bb7630ca6c9..977a8e21c776b 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; use std::ops::Deref; use std::{iter, ptr}; -use libc::{c_char, c_uint}; +use libc::{c_char, c_uint, size_t}; use rustc_abi as abi; use rustc_abi::{Align, Size, WrappingRange}; use rustc_codegen_ssa::MemFlags; @@ -30,7 +30,7 @@ use crate::abi::FnAbiLlvmExt; use crate::attributes; use crate::common::Funclet; use crate::context::CodegenCx; -use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True}; +use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; @@ -223,6 +223,50 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } + fn switch_with_weights( + &mut self, + v: Self::Value, + else_llbb: Self::BasicBlock, + else_is_cold: bool, + cases: impl ExactSizeIterator, + ) { + if self.cx.sess().opts.optimize == rustc_session::config::OptLevel::No { + self.switch(v, else_llbb, cases.map(|(val, dest, _)| (val, dest))); + return; + } + + let id_str = "branch_weights"; + let id = unsafe { + llvm::LLVMMDStringInContext2( + self.cx.llcx, + id_str.as_ptr() as *const c_char, + id_str.len() as size_t, + ) + }; + + let cold_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(1)) }; + let hot_weight = unsafe { llvm::LLVMValueAsMetadata(self.cx.const_u32(2000)) }; + let weight = + |is_cold: bool| -> &Metadata { if is_cold { cold_weight } else { hot_weight } }; + + let mut md: SmallVec<[&Metadata; 16]> = SmallVec::with_capacity(cases.len() + 2); + md.push(id); + md.push(weight(else_is_cold)); + + let switch = + unsafe { llvm::LLVMBuildSwitch(self.llbuilder, v, else_llbb, cases.len() as c_uint) }; + for (on_val, dest, is_cold) in cases { + let on_val = self.const_uint_big(self.val_ty(v), on_val); + unsafe { llvm::LLVMAddCase(switch, on_val, dest) } + md.push(weight(is_cold)); + } + + unsafe { + let md_node = llvm::LLVMMDNodeInContext2(self.cx.llcx, md.as_ptr(), md.len() as size_t); + self.cx.set_metadata(switch, llvm::MD_prof, md_node); + } + } + fn invoke( &mut self, llty: &'ll Type, diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index e3ed12b5ce61c..38a9a31b428d6 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -429,11 +429,34 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let cmp = bx.icmp(IntPredicate::IntEQ, discr_value, llval); bx.cond_br(cmp, ll1, ll2); } else { - bx.switch( - discr_value, - helper.llbb_with_cleanup(self, targets.otherwise()), - target_iter.map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))), - ); + let otherwise = targets.otherwise(); + let otherwise_cold = self.cold_blocks[otherwise]; + let otherwise_unreachable = self.mir[otherwise].is_empty_unreachable(); + let cold_count = targets.iter().filter(|(_, target)| self.cold_blocks[*target]).count(); + let none_cold = cold_count == 0; + let all_cold = cold_count == targets.iter().len(); + if (none_cold && (!otherwise_cold || otherwise_unreachable)) + || (all_cold && (otherwise_cold || otherwise_unreachable)) + { + // All targets have the same weight, + // or `otherwise` is unreachable and it's the only target with a different weight. + bx.switch( + discr_value, + helper.llbb_with_cleanup(self, targets.otherwise()), + target_iter + .map(|(value, target)| (value, helper.llbb_with_cleanup(self, target))), + ); + } else { + // Targets have different weights + bx.switch_with_weights( + discr_value, + helper.llbb_with_cleanup(self, targets.otherwise()), + otherwise_cold, + target_iter.map(|(value, target)| { + (value, helper.llbb_with_cleanup(self, target), self.cold_blocks[target]) + }), + ); + } } } diff --git a/compiler/rustc_codegen_ssa/src/mir/mod.rs b/compiler/rustc_codegen_ssa/src/mir/mod.rs index 0cbc5c45736e8..6637aa1fa6105 100644 --- a/compiler/rustc_codegen_ssa/src/mir/mod.rs +++ b/compiler/rustc_codegen_ssa/src/mir/mod.rs @@ -496,14 +496,25 @@ fn find_cold_blocks<'tcx>( for (bb, bb_data) in traversal::postorder(mir) { let terminator = bb_data.terminator(); - // If a BB ends with a call to a cold function, mark it as cold. - if let mir::TerminatorKind::Call { ref func, .. } = terminator.kind - && let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind() - && let attrs = tcx.codegen_fn_attrs(def_id) - && attrs.flags.contains(CodegenFnAttrFlags::COLD) - { - cold_blocks[bb] = true; - continue; + match terminator.kind { + // If a BB ends with a call to a cold function, mark it as cold. + mir::TerminatorKind::Call { ref func, .. } + | mir::TerminatorKind::TailCall { ref func, .. } + if let ty::FnDef(def_id, ..) = *func.ty(local_decls, tcx).kind() + && let attrs = tcx.codegen_fn_attrs(def_id) + && attrs.flags.contains(CodegenFnAttrFlags::COLD) => + { + cold_blocks[bb] = true; + continue; + } + + // If a BB ends with an `unreachable`, also mark it as cold. + mir::TerminatorKind::Unreachable => { + cold_blocks[bb] = true; + continue; + } + + _ => {} } // If all successors of a BB are cold and there's at least one of them, mark this BB as cold diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index b0138ac8bfed6..243ec7ac20486 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -110,6 +110,20 @@ pub trait BuilderMethods<'a, 'tcx>: else_llbb: Self::BasicBlock, cases: impl ExactSizeIterator, ); + + // This is like `switch()`, but every case has a bool flag indicating whether it's cold. + // + // Default implementation throws away the cold flags and calls `switch()`. + fn switch_with_weights( + &mut self, + v: Self::Value, + else_llbb: Self::BasicBlock, + _else_is_cold: bool, + cases: impl ExactSizeIterator, + ) { + self.switch(v, else_llbb, cases.map(|(val, bb, _)| (val, bb))) + } + fn invoke( &mut self, llty: Self::Type, diff --git a/tests/codegen/intrinsics/cold_path2.rs b/tests/codegen/intrinsics/cold_path2.rs new file mode 100644 index 0000000000000..a535f2c99ca1a --- /dev/null +++ b/tests/codegen/intrinsics/cold_path2.rs @@ -0,0 +1,35 @@ +//@ compile-flags: -O +#![crate_type = "lib"] +#![feature(core_intrinsics)] + +use std::intrinsics::cold_path; + +#[inline(never)] +#[no_mangle] +pub fn path_a() { + println!("path a"); +} + +#[inline(never)] +#[no_mangle] +pub fn path_b() { + println!("path b"); +} + +#[no_mangle] +pub fn test(x: Option) { + if let Some(_) = x { + path_a(); + } else { + cold_path(); + path_b(); + } +} + +// CHECK-LABEL: @test( +// CHECK: br i1 %1, label %bb2, label %bb1, !prof ![[NUM:[0-9]+]] +// CHECK: bb1: +// CHECK: path_a +// CHECK: bb2: +// CHECK: path_b +// CHECK: ![[NUM]] = !{!"branch_weights", {{(!"expected", )?}}i32 1, i32 2000} diff --git a/tests/codegen/intrinsics/cold_path3.rs b/tests/codegen/intrinsics/cold_path3.rs new file mode 100644 index 0000000000000..1c90e77f4f479 --- /dev/null +++ b/tests/codegen/intrinsics/cold_path3.rs @@ -0,0 +1,87 @@ +//@ compile-flags: -O +#![crate_type = "lib"] +#![feature(core_intrinsics)] + +use std::intrinsics::cold_path; + +#[inline(never)] +#[no_mangle] +pub fn path_a() { + println!("path a"); +} + +#[inline(never)] +#[no_mangle] +pub fn path_b() { + println!("path b"); +} + +#[inline(never)] +#[no_mangle] +pub fn path_c() { + println!("path c"); +} + +#[inline(never)] +#[no_mangle] +pub fn path_d() { + println!("path d"); +} + +#[no_mangle] +pub fn test(x: Option) { + match x { + Some(0) => path_a(), + Some(1) => { + cold_path(); + path_b() + } + Some(2) => path_c(), + Some(3) => { + cold_path(); + path_d() + } + _ => path_a(), + } +} + +#[no_mangle] +pub fn test2(x: Option) { + match x { + Some(10) => path_a(), + Some(11) => { + cold_path(); + path_b() + } + Some(12) => { + unsafe { core::intrinsics::unreachable() }; + path_c() + } + Some(13) => { + cold_path(); + path_d() + } + _ => { + cold_path(); + path_a() + } + } +} + +// CHECK-LABEL: @test( +// CHECK: switch i32 %1, label %bb1 [ +// CHECK: i32 0, label %bb6 +// CHECK: i32 1, label %bb5 +// CHECK: i32 2, label %bb4 +// CHECK: i32 3, label %bb3 +// CHECK: ], !prof !3 + +// CHECK-LABEL: @test2( +// CHECK: switch i32 %1, label %bb1 [ +// CHECK: i32 10, label %bb5 +// CHECK: i32 11, label %bb4 +// CHECK: i32 13, label %bb3 +// CHECK: ], !prof !5 + +// CHECK: !3 = !{!"branch_weights", i32 2000, i32 2000, i32 1, i32 2000, i32 1} +// CHECK: !5 = !{!"branch_weights", i32 1, i32 2000, i32 1, i32 1}