diff --git a/compiler/rustc_middle/src/traits/mod.rs b/compiler/rustc_middle/src/traits/mod.rs index ffde1294ec6..e5d4dfe891e 100644 --- a/compiler/rustc_middle/src/traits/mod.rs +++ b/compiler/rustc_middle/src/traits/mod.rs @@ -143,10 +143,6 @@ impl<'tcx> ObligationCause<'tcx> { ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None } } - pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> { - Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE))) - } - pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span { match *self.code() { ObligationCauseCode::CompareImplMethodObligation { .. } @@ -173,6 +169,16 @@ impl<'tcx> ObligationCause<'tcx> { None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE), } } + + pub fn map_code( + &mut self, + f: impl FnOnce(Lrc>) -> Lrc>, + ) { + self.code = Some(f(match self.code.take() { + Some(code) => code, + None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE), + })); + } } #[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)] diff --git a/compiler/rustc_trait_selection/src/traits/wf.rs b/compiler/rustc_trait_selection/src/traits/wf.rs index de0ade64247..1ee5b385f9a 100644 --- a/compiler/rustc_trait_selection/src/traits/wf.rs +++ b/compiler/rustc_trait_selection/src/traits/wf.rs @@ -294,30 +294,28 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> { let obligations = self.nominal_obligations(trait_ref.def_id, trait_ref.substs); debug!("compute_trait_ref obligations {:?}", obligations); - let cause = self.cause(traits::MiscObligation); let param_env = self.param_env; let depth = self.recursion_depth; let item = self.item; - let extend = |obligation: traits::PredicateObligation<'tcx>| { - let mut cause = cause.clone(); - if let Some(parent_trait_pred) = obligation.predicate.to_opt_poly_trait_pred() { - let derived_cause = traits::DerivedObligationCause { - parent_trait_pred, - parent_code: obligation.cause.clone_code(), - }; - *cause.make_mut_code() = - traits::ObligationCauseCode::DerivedObligation(derived_cause); + let extend = |traits::PredicateObligation { predicate, mut cause, .. }| { + if let Some(parent_trait_pred) = predicate.to_opt_poly_trait_pred() { + cause.map_code(|parent_code| { + { + traits::ObligationCauseCode::DerivedObligation( + traits::DerivedObligationCause { parent_trait_pred, parent_code }, + ) + } + .into() + }); + } else { + cause = traits::ObligationCause::misc(self.span, self.body_id); } extend_cause_with_original_assoc_item_obligation( - tcx, - trait_ref, - item, - &mut cause, - obligation.predicate, + tcx, trait_ref, item, &mut cause, predicate, ); - traits::Obligation::with_depth(cause, depth, param_env, obligation.predicate) + traits::Obligation::with_depth(cause, depth, param_env, predicate) }; if let Elaborate::All = elaborate { @@ -339,17 +337,17 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> { }) .filter(|(_, arg)| !arg.has_escaping_bound_vars()) .map(|(i, arg)| { - let mut new_cause = cause.clone(); + let mut cause = traits::ObligationCause::misc(self.span, self.body_id); // The first subst is the self ty - use the correct span for it. if i == 0 { if let Some(hir::ItemKind::Impl(hir::Impl { self_ty, .. })) = item.map(|i| &i.kind) { - new_cause.span = self_ty.span; + cause.span = self_ty.span; } } traits::Obligation::with_depth( - new_cause, + cause, depth, param_env, ty::Binder::dummy(ty::PredicateKind::WellFormed(arg)).to_predicate(tcx), diff --git a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs index 277743e4a46..604228d57a3 100644 --- a/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs +++ b/compiler/rustc_typeck/src/check/fn_ctxt/checks.rs @@ -1668,13 +1668,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // We make sure that only *one* argument matches the obligation failure // and we assign the obligation's span to its expression's. error.obligation.cause.span = args[ref_in].span; - let parent_code = error.obligation.cause.clone_code(); - *error.obligation.cause.make_mut_code() = + error.obligation.cause.map_code(|parent_code| { ObligationCauseCode::FunctionArgumentObligation { arg_hir_id: args[ref_in].hir_id, call_hir_id: expr.hir_id, parent_code, - }; + } + .into() + }); } else if error.obligation.cause.span == call_sp { // Make function calls point at the callee, not the whole thing. if let hir::ExprKind::Call(callee, _) = expr.kind {