From 36840272e3fde2b6b049f81e4a5c96a5582580f8 Mon Sep 17 00:00:00 2001 From: Trevor Gross Date: Thu, 13 Jun 2024 08:11:16 -0500 Subject: [PATCH] Enable const evaluation for `f16` and `f128` This excludes casting, which needs more tests. --- .../rustc_const_eval/src/interpret/cast.rs | 27 ++++++++++++------- .../src/interpret/operator.rs | 10 +++++-- .../rustc_middle/src/mir/interpret/value.rs | 14 ++++++++++ 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index 6961e13c2399e..c83fe1413087f 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -1,6 +1,6 @@ use std::assert_matches::assert_matches; -use rustc_apfloat::ieee::{Double, Single}; +use rustc_apfloat::ieee::{Double, Half, Quad, Single}; use rustc_apfloat::{Float, FloatConvert}; use rustc_middle::mir::interpret::{InterpResult, PointerArithmetic, Scalar}; use rustc_middle::mir::CastKind; @@ -189,10 +189,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { bug!("FloatToFloat/FloatToInt cast: source type {} is not a float type", src.layout.ty) }; let val = match fty { - FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F16 => self.cast_from_float(src.to_scalar().to_f16()?, cast_to.ty), FloatTy::F32 => self.cast_from_float(src.to_scalar().to_f32()?, cast_to.ty), FloatTy::F64 => self.cast_from_float(src.to_scalar().to_f64()?, cast_to.ty), - FloatTy::F128 => unimplemented!("f16_f128"), + FloatTy::F128 => self.cast_from_float(src.to_scalar().to_f128()?, cast_to.ty), }; Ok(ImmTy::from_scalar(val, cast_to)) } @@ -298,18 +298,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { Float(fty) if signed => { let v = v as i128; match fty { - FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F16 => Scalar::from_f16(Half::from_i128(v).value), FloatTy::F32 => Scalar::from_f32(Single::from_i128(v).value), FloatTy::F64 => Scalar::from_f64(Double::from_i128(v).value), - FloatTy::F128 => unimplemented!("f16_f128"), + FloatTy::F128 => Scalar::from_f128(Quad::from_i128(v).value), } } // unsigned int -> float Float(fty) => match fty { - FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F16 => Scalar::from_f16(Half::from_u128(v).value), FloatTy::F32 => Scalar::from_f32(Single::from_u128(v).value), FloatTy::F64 => Scalar::from_f64(Double::from_u128(v).value), - FloatTy::F128 => unimplemented!("f16_f128"), + FloatTy::F128 => Scalar::from_f128(Quad::from_u128(v).value), }, // u8 -> char @@ -323,7 +323,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { /// Low-level cast helper function. Converts an apfloat `f` into int or float types. fn cast_from_float(&self, f: F, dest_ty: Ty<'tcx>) -> Scalar where - F: Float + Into> + FloatConvert + FloatConvert, + F: Float + + Into> + + FloatConvert + + FloatConvert + + FloatConvert + + FloatConvert, { use rustc_type_ir::TyKind::*; @@ -360,10 +365,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { } // float -> float Float(fty) => match fty { - FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F16 => Scalar::from_f16(adjust_nan(self, f, f.convert(&mut false).value)), FloatTy::F32 => Scalar::from_f32(adjust_nan(self, f, f.convert(&mut false).value)), FloatTy::F64 => Scalar::from_f64(adjust_nan(self, f, f.convert(&mut false).value)), - FloatTy::F128 => unimplemented!("f16_f128"), + FloatTy::F128 => { + Scalar::from_f128(adjust_nan(self, f, f.convert(&mut false).value)) + } }, // That's it. _ => span_bug!(self.cur_span(), "invalid float to {} cast", dest_ty), diff --git a/compiler/rustc_const_eval/src/interpret/operator.rs b/compiler/rustc_const_eval/src/interpret/operator.rs index a6eef9f5662ca..a6924371632b2 100644 --- a/compiler/rustc_const_eval/src/interpret/operator.rs +++ b/compiler/rustc_const_eval/src/interpret/operator.rs @@ -362,14 +362,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let left = left.to_scalar(); let right = right.to_scalar(); Ok(match fty { - FloatTy::F16 => unimplemented!("f16_f128"), + FloatTy::F16 => { + self.binary_float_op(bin_op, layout, left.to_f16()?, right.to_f16()?) + } FloatTy::F32 => { self.binary_float_op(bin_op, layout, left.to_f32()?, right.to_f32()?) } FloatTy::F64 => { self.binary_float_op(bin_op, layout, left.to_f64()?, right.to_f64()?) } - FloatTy::F128 => unimplemented!("f16_f128"), + FloatTy::F128 => { + self.binary_float_op(bin_op, layout, left.to_f128()?, right.to_f128()?) + } }) } _ if left.layout.ty.is_integral() => { @@ -431,8 +435,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let val = val.to_scalar(); // No NaN adjustment here, `-` is a bitwise operation! let res = match (un_op, fty) { + (Neg, FloatTy::F16) => Scalar::from_f16(-val.to_f16()?), (Neg, FloatTy::F32) => Scalar::from_f32(-val.to_f32()?), (Neg, FloatTy::F64) => Scalar::from_f64(-val.to_f64()?), + (Neg, FloatTy::F128) => Scalar::from_f128(-val.to_f128()?), _ => span_bug!(self.cur_span(), "Invalid float op {:?}", un_op), }; Ok(ImmTy::from_scalar(res, layout)) diff --git a/compiler/rustc_middle/src/mir/interpret/value.rs b/compiler/rustc_middle/src/mir/interpret/value.rs index 70e5ad0635ba1..a84a4c583edd2 100644 --- a/compiler/rustc_middle/src/mir/interpret/value.rs +++ b/compiler/rustc_middle/src/mir/interpret/value.rs @@ -69,6 +69,13 @@ impl fmt::LowerHex for Scalar { } } +impl From for Scalar { + #[inline(always)] + fn from(f: Half) -> Self { + Scalar::from_f16(f) + } +} + impl From for Scalar { #[inline(always)] fn from(f: Single) -> Self { @@ -83,6 +90,13 @@ impl From for Scalar { } } +impl From for Scalar { + #[inline(always)] + fn from(f: Quad) -> Self { + Scalar::from_f128(f) + } +} + impl From for Scalar { #[inline(always)] fn from(ptr: ScalarInt) -> Self {