Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve some deriving code and add a test #98376

Merged
merged 5 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions compiler/rustc_builtin_macros/src/deriving/clone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use crate::deriving::generic::*;
use crate::deriving::path_std;

use rustc_ast::ptr::P;
use rustc_ast::{self as ast, Expr, GenericArg, Generics, ItemKind, MetaItem, VariantData};
use rustc_ast::{self as ast, Expr, Generics, ItemKind, MetaItem, VariantData};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::Span;

pub fn expand_deriving_clone(
Expand Down Expand Up @@ -107,44 +107,38 @@ fn cs_clone_shallow(
substr: &Substructure<'_>,
is_union: bool,
) -> P<Expr> {
fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
helper_name: &str,
) {
// Generate statement `let _: helper_name<ty>;`,
// set the expn ID so we can use the unstable struct.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(
span,
true,
cx.std_path(&[sym::clone, Symbol::intern(helper_name)]),
vec![GenericArg::Type(ty)],
);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
fn process_variant(cx: &mut ExtCtxt<'_>, stmts: &mut Vec<ast::Stmt>, variant: &VariantData) {
let mut stmts = Vec::new();
let mut process_variant = |variant: &VariantData| {
for field in variant.fields() {
// let _: AssertParamIsClone<FieldTy>;
assert_ty_bounds(cx, stmts, field.ty.clone(), field.span, "AssertParamIsClone");
super::assert_ty_bounds(
cx,
&mut stmts,
field.ty.clone(),
field.span,
&[sym::clone, sym::AssertParamIsClone],
);
}
}
};

let mut stmts = Vec::new();
if is_union {
// let _: AssertParamIsCopy<Self>;
let self_ty = cx.ty_path(cx.path_ident(trait_span, Ident::with_dummy_span(kw::SelfUpper)));
assert_ty_bounds(cx, &mut stmts, self_ty, trait_span, "AssertParamIsCopy");
super::assert_ty_bounds(
cx,
&mut stmts,
self_ty,
trait_span,
&[sym::clone, sym::AssertParamIsCopy],
);
} else {
match *substr.fields {
StaticStruct(vdata, ..) => {
process_variant(cx, &mut stmts, vdata);
process_variant(vdata);
}
StaticEnum(enum_def, ..) => {
for variant in &enum_def.variants {
process_variant(cx, &mut stmts, &variant.data);
process_variant(&variant.data);
}
}
_ => cx.span_bug(
Expand Down
44 changes: 14 additions & 30 deletions compiler/rustc_builtin_macros/src/deriving/cmp/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use crate::deriving::generic::*;
use crate::deriving::path_std;

use rustc_ast::ptr::P;
use rustc_ast::{self as ast, Expr, GenericArg, MetaItem};
use rustc_ast::{self as ast, Expr, MetaItem};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{sym, Ident, Symbol};
use rustc_span::symbol::{sym, Ident};
use rustc_span::Span;

pub fn expand_deriving_eq(
Expand Down Expand Up @@ -55,43 +55,27 @@ fn cs_total_eq_assert(
trait_span: Span,
substr: &Substructure<'_>,
) -> P<Expr> {
fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
helper_name: &str,
) {
// Generate statement `let _: helper_name<ty>;`,
// set the expn ID so we can use the unstable struct.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(
span,
true,
cx.std_path(&[sym::cmp, Symbol::intern(helper_name)]),
vec![GenericArg::Type(ty)],
);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
fn process_variant(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
variant: &ast::VariantData,
) {
let mut stmts = Vec::new();
let mut process_variant = |variant: &ast::VariantData| {
for field in variant.fields() {
// let _: AssertParamIsEq<FieldTy>;
assert_ty_bounds(cx, stmts, field.ty.clone(), field.span, "AssertParamIsEq");
super::assert_ty_bounds(
cx,
&mut stmts,
field.ty.clone(),
field.span,
&[sym::cmp, sym::AssertParamIsEq],
);
}
}
};

let mut stmts = Vec::new();
match *substr.fields {
StaticStruct(vdata, ..) => {
process_variant(cx, &mut stmts, vdata);
process_variant(vdata);
}
StaticEnum(enum_def, ..) => {
for variant in &enum_def.variants {
process_variant(cx, &mut stmts, &variant.data);
process_variant(&variant.data);
}
}
_ => cx.span_bug(trait_span, "unexpected substructure in `derive(Eq)`"),
Expand Down
109 changes: 36 additions & 73 deletions compiler/rustc_builtin_macros/src/deriving/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1126,75 +1126,43 @@ impl<'a> MethodDef<'a> {
/// A1,
/// A2(i32)
/// }
///
/// // is equivalent to
///
/// impl PartialEq for A {
/// ```
/// is equivalent to:
/// ```
/// impl ::core::cmp::PartialEq for A {
/// #[inline]
/// fn eq(&self, other: &A) -> bool {
/// use A::*;
/// match (&*self, &*other) {
/// (&A1, &A1) => true,
/// (&A2(ref self_0),
/// &A2(ref __arg_1_0)) => (*self_0).eq(&(*__arg_1_0)),
/// _ => {
/// let __self_vi = match *self { A1 => 0, A2(..) => 1 };
/// let __arg_1_vi = match *other { A1 => 0, A2(..) => 1 };
/// false
/// {
/// let __self_vi = ::core::intrinsics::discriminant_value(&*self);
/// let __arg_1_vi = ::core::intrinsics::discriminant_value(&*other);
/// if true && __self_vi == __arg_1_vi {
/// match (&*self, &*other) {
/// (&A::A2(ref __self_0), &A::A2(ref __arg_1_0)) =>
/// (*__self_0) == (*__arg_1_0),
/// _ => true,
/// }
/// } else {
/// false // catch-all handler
/// }
/// }
/// }
/// }
/// ```
///
/// (Of course `__self_vi` and `__arg_1_vi` are unused for
/// `PartialEq`, and those subcomputations will hopefully be removed
/// as their results are unused. The point of `__self_vi` and
/// `__arg_1_vi` is for `PartialOrd`; see #15503.)
fn expand_enum_method_body<'b>(
&self,
cx: &mut ExtCtxt<'_>,
trait_: &TraitDef<'b>,
enum_def: &'b EnumDef,
type_ident: Ident,
self_args: Vec<P<Expr>>,
nonself_args: &[P<Expr>],
) -> P<Expr> {
self.build_enum_match_tuple(cx, trait_, enum_def, type_ident, self_args, nonself_args)
}

/// Creates a match for a tuple of all `self_args`, where either all
/// variants match, or it falls into a catch-all for when one variant
/// does not match.

///
/// There are N + 1 cases because is a case for each of the N
/// variants where all of the variants match, and one catch-all for
/// when one does not match.

///
/// As an optimization we generate code which checks whether all variants
/// match first which makes llvm see that C-like enums can be compiled into
/// a simple equality check (for PartialEq).

///
/// The catch-all handler is provided access the variant index values
/// for each of the self-args, carried in precomputed variables.

/// ```{.text}
/// let __self0_vi = std::intrinsics::discriminant_value(&self);
/// let __self1_vi = std::intrinsics::discriminant_value(&arg1);
/// let __self2_vi = std::intrinsics::discriminant_value(&arg2);
///
/// if __self0_vi == __self1_vi && __self0_vi == __self2_vi && ... {
/// match (...) {
/// (Variant1, Variant1, ...) => Body1
/// (Variant2, Variant2, ...) => Body2,
/// ...
/// _ => ::core::intrinsics::unreachable()
/// }
/// }
/// else {
/// ... // catch-all remainder can inspect above variant index values.
/// }
/// ```
fn build_enum_match_tuple<'b>(
fn expand_enum_method_body<'b>(
&self,
cx: &mut ExtCtxt<'_>,
trait_: &TraitDef<'b>,
Expand Down Expand Up @@ -1392,37 +1360,32 @@ impl<'a> MethodDef<'a> {
//
// i.e., for `enum E<T> { A, B(1), C(T, T) }`, and a deriving
// with three Self args, builds three statements:
//
// ```
// let __self0_vi = std::intrinsics::discriminant_value(&self);
// let __self1_vi = std::intrinsics::discriminant_value(&arg1);
// let __self2_vi = std::intrinsics::discriminant_value(&arg2);
// let __self_vi = std::intrinsics::discriminant_value(&self);
// let __arg_1_vi = std::intrinsics::discriminant_value(&arg1);
// let __arg_2_vi = std::intrinsics::discriminant_value(&arg2);
nnethercote marked this conversation as resolved.
Show resolved Hide resolved
// ```
let mut index_let_stmts: Vec<ast::Stmt> = Vec::with_capacity(vi_idents.len() + 1);

// We also build an expression which checks whether all discriminants are equal
// discriminant_test = __self0_vi == __self1_vi && __self0_vi == __self2_vi && ...
// We also build an expression which checks whether all discriminants are equal:
// `__self_vi == __arg_1_vi && __self_vi == __arg_2_vi && ...`
let mut discriminant_test = cx.expr_bool(span, true);

let mut first_ident = None;
for (&ident, self_arg) in iter::zip(&vi_idents, &self_args) {
for (i, (&ident, self_arg)) in iter::zip(&vi_idents, &self_args).enumerate() {
let self_addr = cx.expr_addr_of(span, self_arg.clone());
let variant_value =
deriving::call_intrinsic(cx, span, sym::discriminant_value, vec![self_addr]);
let let_stmt = cx.stmt_let(span, false, ident, variant_value);
index_let_stmts.push(let_stmt);

match first_ident {
Some(first) => {
let first_expr = cx.expr_ident(span, first);
let id = cx.expr_ident(span, ident);
let test = cx.expr_binary(span, BinOpKind::Eq, first_expr, id);
discriminant_test =
cx.expr_binary(span, BinOpKind::And, discriminant_test, test)
}
None => {
first_ident = Some(ident);
}
if i > 0 {
let id0 = cx.expr_ident(span, vi_idents[0]);
let id = cx.expr_ident(span, ident);
let test = cx.expr_binary(span, BinOpKind::Eq, id0, id);
discriminant_test = if i == 1 {
test
} else {
cx.expr_binary(span, BinOpKind::And, discriminant_test, test)
};
}
}

Expand Down Expand Up @@ -1453,7 +1416,7 @@ impl<'a> MethodDef<'a> {
// }
// }
// else {
// <delegated expression referring to __self0_vi, et al.>
// <delegated expression referring to __self_vi, et al.>
// }
let all_match = cx.expr_match(span, match_arg, match_arms);
let arm_expr = cx.expr_if(span, discriminant_test, all_match, Some(arm_expr));
Expand Down
15 changes: 14 additions & 1 deletion compiler/rustc_builtin_macros/src/deriving/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use rustc_ast as ast;
use rustc_ast::ptr::P;
use rustc_ast::{Impl, ItemKind, MetaItem};
use rustc_ast::{GenericArg, Impl, ItemKind, MetaItem};
use rustc_expand::base::{Annotatable, ExpandResult, ExtCtxt, MultiItemModifier};
use rustc_span::symbol::{sym, Ident, Symbol};
use rustc_span::Span;
Expand Down Expand Up @@ -193,3 +193,16 @@ fn inject_impl_of_structural_trait(

push(Annotatable::Item(newitem));
}

fn assert_ty_bounds(
cx: &mut ExtCtxt<'_>,
stmts: &mut Vec<ast::Stmt>,
ty: P<ast::Ty>,
span: Span,
assert_path: &[Symbol],
) {
// Generate statement `let _: assert_path<ty>;`.
let span = cx.with_def_site_ctxt(span);
let assert_path = cx.path_all(span, true, cx.std_path(assert_path), vec![GenericArg::Type(ty)]);
stmts.push(cx.stmt_let_type_only(span, cx.ty_path(assert_path)));
}
3 changes: 3 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ symbols! {
Arguments,
AsMut,
AsRef,
AssertParamIsClone,
AssertParamIsCopy,
AssertParamIsEq,
AtomicBool,
AtomicI128,
AtomicI16,
Expand Down
Loading