From 248d5b935d99c4a39cfee1f7669ec8308fda2234 Mon Sep 17 00:00:00 2001 From: lcnr Date: Tue, 21 May 2024 20:09:37 +0000 Subject: [PATCH] alias-relate: add fast reject optimization --- .../src/solve/alias_relate.rs | 113 +++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_trait_selection/src/solve/alias_relate.rs b/compiler/rustc_trait_selection/src/solve/alias_relate.rs index 33b30bef68328..076fb95ecde8d 100644 --- a/compiler/rustc_trait_selection/src/solve/alias_relate.rs +++ b/compiler/rustc_trait_selection/src/solve/alias_relate.rs @@ -16,9 +16,12 @@ //! relate them structurally. use super::EvalCtxt; +use rustc_data_structures::fx::FxHashSet; use rustc_infer::infer::InferCtxt; +use rustc_middle::traits::query::NoSolution; use rustc_middle::traits::solve::{Certainty, Goal, QueryResult}; -use rustc_middle::ty; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor}; impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { #[instrument(level = "trace", skip(self), ret)] @@ -30,6 +33,12 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { let Goal { param_env, predicate: (lhs, rhs, direction) } = goal; debug_assert!(lhs.to_alias_term().is_some() || rhs.to_alias_term().is_some()); + if self.fast_reject_unnameable_rigid_term(param_env, lhs, rhs) + || self.fast_reject_unnameable_rigid_term(param_env, rhs, lhs) + { + return Err(NoSolution); + } + // Structurally normalize the lhs. let lhs = if let Some(alias) = lhs.to_alias_term() { let term = self.next_term_infer_of_kind(lhs); @@ -85,3 +94,105 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { } } } + +enum IgnoreAliases { + Yes, + No, +} + +impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> { + /// In case a rigid term refers to a placeholder which is not referenced by the + /// alias, the alias cannot be normalized to that rigid term unless it contains + /// either inference variables or these placeholders are referenced in a term + /// of a `Projection`-clause in the environment. + fn fast_reject_unnameable_rigid_term( + &mut self, + param_env: ty::ParamEnv<'tcx>, + rigid_term: ty::Term<'tcx>, + alias: ty::Term<'tcx>, + ) -> bool { + // Check that the rigid term is actually rigid. + if rigid_term.to_alias_term().is_some() || alias.to_alias_term().is_none() { + return false; + } + + // If the alias has any type or const inference variables, + // do not try to apply the fast path as these inference variables + // may resolve to something containing placeholders. + if alias.has_non_region_infer() { + return false; + } + + let mut referenced_placeholders = + self.collect_placeholders_in_term(rigid_term, IgnoreAliases::Yes); + for clause in param_env.caller_bounds() { + match clause.kind().skip_binder() { + ty::ClauseKind::Projection(ty::ProjectionPredicate { term, .. }) => { + if term.has_non_region_infer() { + return false; + } + + let env_term_placeholders = + self.collect_placeholders_in_term(term, IgnoreAliases::No); + #[allow(rustc::potential_query_instability)] + referenced_placeholders.retain(|p| !env_term_placeholders.contains(p)); + } + ty::ClauseKind::Trait(_) + | ty::ClauseKind::TypeOutlives(_) + | ty::ClauseKind::RegionOutlives(_) + | ty::ClauseKind::ConstArgHasType(..) + | ty::ClauseKind::WellFormed(_) + | ty::ClauseKind::ConstEvaluatable(_) => continue, + } + } + + if referenced_placeholders.is_empty() { + return false; + } + + let alias_placeholders = self.collect_placeholders_in_term(alias, IgnoreAliases::No); + // If the rigid term references a placeholder not mentioned by the alias, + // they can never unify. + !referenced_placeholders.is_subset(&alias_placeholders) + } + + fn collect_placeholders_in_term( + &mut self, + term: ty::Term<'tcx>, + ignore_aliases: IgnoreAliases, + ) -> FxHashSet> { + // Fast path to avoid walking the term. + if !term.has_placeholders() { + return Default::default(); + } + + struct PlaceholderCollector<'tcx> { + ignore_aliases: IgnoreAliases, + placeholders: FxHashSet>, + } + impl<'tcx> TypeVisitor> for PlaceholderCollector<'tcx> { + type Result = (); + + fn visit_ty(&mut self, t: Ty<'tcx>) { + match t.kind() { + ty::Placeholder(_) => drop(self.placeholders.insert(t.into())), + ty::Alias(..) if matches!(self.ignore_aliases, IgnoreAliases::Yes) => {} + _ => t.super_visit_with(self), + } + } + + fn visit_const(&mut self, ct: ty::Const<'tcx>) { + match ct.kind() { + ty::ConstKind::Placeholder(_) => drop(self.placeholders.insert(ct.into())), + ty::ConstKind::Unevaluated(_) | ty::ConstKind::Expr(_) + if matches!(self.ignore_aliases, IgnoreAliases::Yes) => {} + _ => ct.super_visit_with(self), + } + } + } + + let mut visitor = PlaceholderCollector { ignore_aliases, placeholders: Default::default() }; + term.visit_with(&mut visitor); + visitor.placeholders + } +}