1
Fork 0

Safe Transmute: Change Answer type to Result

This patch updates the `Answer` type from `rustc_transmute` so that it just a
type alias to `Result`. This makes it so that the standard methods for `Result`
can be used to process the `Answer` tree, including being able to make use of
the `?` operator on `Answer`s.

Also, remove some unused functions
This commit is contained in:
Bryan Garza 2023-04-27 14:38:32 -07:00
parent 8f1cec8d84
commit 263a4f2cb6
6 changed files with 113 additions and 140 deletions

View file

@ -675,11 +675,11 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
scope, scope,
assume, assume,
) { ) {
rustc_transmute::Answer::Yes => Ok(Certainty::Yes), Ok(None) => Ok(Certainty::Yes),
rustc_transmute::Answer::No(_) Err(_)
| rustc_transmute::Answer::IfTransmutable { .. } | Ok(Some(rustc_transmute::Condition::IfTransmutable { .. }))
| rustc_transmute::Answer::IfAll(_) | Ok(Some(rustc_transmute::Condition::IfAll(_)))
| rustc_transmute::Answer::IfAny(_) => Err(NoSolution), | Ok(Some(rustc_transmute::Condition::IfAny(_))) => Err(NoSolution),
} }
} }
} }

View file

@ -2751,13 +2751,14 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
rustc_transmute::Assume::from_const(self.infcx.tcx, obligation.param_env, trait_ref.substs.const_at(3)) else { rustc_transmute::Assume::from_const(self.infcx.tcx, obligation.param_env, trait_ref.substs.const_at(3)) else {
span_bug!(span, "Unable to construct rustc_transmute::Assume where it was previously possible"); span_bug!(span, "Unable to construct rustc_transmute::Assume where it was previously possible");
}; };
// FIXME(bryangarza): Need to flatten here too
match rustc_transmute::TransmuteTypeEnv::new(self.infcx).is_transmutable( match rustc_transmute::TransmuteTypeEnv::new(self.infcx).is_transmutable(
obligation.cause, obligation.cause,
src_and_dst, src_and_dst,
scope, scope,
assume, assume,
) { ) {
rustc_transmute::Answer::No(reason) => { Err(reason) => {
let dst = trait_ref.substs.type_at(0); let dst = trait_ref.substs.type_at(0);
let src = trait_ref.substs.type_at(1); let src = trait_ref.substs.type_at(1);
let custom_err_msg = format!( let custom_err_msg = format!(
@ -2795,7 +2796,7 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> {
(custom_err_msg, Some(reason_msg)) (custom_err_msg, Some(reason_msg))
} }
// Should never get a Yes at this point! We already ran it before, and did not get a Yes. // Should never get a Yes at this point! We already ran it before, and did not get a Yes.
rustc_transmute::Answer::Yes => span_bug!( Ok(None) => span_bug!(
span, span,
"Inconsistent rustc_transmute::is_transmutable(...) result, got Yes", "Inconsistent rustc_transmute::is_transmutable(...) result, got Yes",
), ),

View file

@ -279,11 +279,12 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
ImplSourceBuiltinData { nested: obligations } ImplSourceBuiltinData { nested: obligations }
} }
#[instrument(skip(self))] #[instrument(level = "debug", skip(self))]
fn confirm_transmutability_candidate( fn confirm_transmutability_candidate(
&mut self, &mut self,
obligation: &TraitObligation<'tcx>, obligation: &TraitObligation<'tcx>,
) -> Result<ImplSourceBuiltinData<PredicateObligation<'tcx>>, SelectionError<'tcx>> { ) -> Result<ImplSourceBuiltinData<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
#[instrument(level = "debug", skip(tcx, obligation, predicate))]
fn flatten_answer_tree<'tcx>( fn flatten_answer_tree<'tcx>(
tcx: TyCtxt<'tcx>, tcx: TyCtxt<'tcx>,
obligation: &TraitObligation<'tcx>, obligation: &TraitObligation<'tcx>,
@ -291,11 +292,11 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
answer: rustc_transmute::Answer<rustc_transmute::layout::rustc::Ref<'tcx>>, answer: rustc_transmute::Answer<rustc_transmute::layout::rustc::Ref<'tcx>>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> { ) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
match answer { match answer {
rustc_transmute::Answer::Yes => Ok(vec![]), Ok(None) => Ok(vec![]),
rustc_transmute::Answer::No(_) => Err(Unimplemented), Err(_) => Err(Unimplemented),
// FIXME(bryangarza): Add separate `IfAny` case, instead of treating as `IfAll` // FIXME(bryangarza): Add separate `IfAny` case, instead of treating as `IfAll`
rustc_transmute::Answer::IfAll(answers) Ok(Some(rustc_transmute::Condition::IfAll(answers)))
| rustc_transmute::Answer::IfAny(answers) => { | Ok(Some(rustc_transmute::Condition::IfAny(answers))) => {
let mut nested = vec![]; let mut nested = vec![];
for flattened in answers for flattened in answers
.into_iter() .into_iter()
@ -305,7 +306,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
} }
Ok(nested) Ok(nested)
} }
rustc_transmute::Answer::IfTransmutable { src, dst } => { Ok(Some(rustc_transmute::Condition::IfTransmutable { src, dst })) => {
let trait_def_id = obligation.predicate.def_id(); let trait_def_id = obligation.predicate.def_id();
let scope = predicate.trait_ref.substs.type_at(2); let scope = predicate.trait_ref.substs.type_at(2);
let assume_const = predicate.trait_ref.substs.const_at(3); let assume_const = predicate.trait_ref.substs.const_at(3);
@ -334,8 +335,6 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
} }
} }
debug!(?obligation, "confirm_transmutability_candidate");
// We erase regions here because transmutability calls layout queries, // We erase regions here because transmutability calls layout queries,
// which does not handle inference regions and doesn't particularly // which does not handle inference regions and doesn't particularly
// care about other regions. Erasing late-bound regions is equivalent // care about other regions. Erasing late-bound regions is equivalent
@ -352,21 +351,21 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
return Err(Unimplemented); return Err(Unimplemented);
}; };
let dst = predicate.trait_ref.substs.type_at(0);
let src = predicate.trait_ref.substs.type_at(1);
let mut transmute_env = rustc_transmute::TransmuteTypeEnv::new(self.infcx); let mut transmute_env = rustc_transmute::TransmuteTypeEnv::new(self.infcx);
let maybe_transmutable = transmute_env.is_transmutable( let maybe_transmutable = transmute_env.is_transmutable(
obligation.cause.clone(), obligation.cause.clone(),
rustc_transmute::Types { rustc_transmute::Types { dst, src },
dst: predicate.trait_ref.substs.type_at(0),
src: predicate.trait_ref.substs.type_at(1),
},
predicate.trait_ref.substs.type_at(2), predicate.trait_ref.substs.type_at(2),
assume, assume,
); );
info!(?maybe_transmutable); debug!(?src, ?dst);
let nested = flatten_answer_tree(self.tcx(), obligation, predicate, maybe_transmutable)?; let fully_flattened =
info!(?nested); flatten_answer_tree(self.tcx(), obligation, predicate, maybe_transmutable)?;
Ok(ImplSourceBuiltinData { nested }) debug!(?fully_flattened);
Ok(ImplSourceBuiltinData { nested: fully_flattened })
} }
/// This handles the case where an `auto trait Foo` impl is being used. /// This handles the case where an `auto trait Foo` impl is being used.

View file

@ -19,15 +19,12 @@ pub struct Assume {
pub validity: bool, pub validity: bool,
} }
/// The type encodes answers to the question: "Are these types transmutable?" /// Either we have an error, or we have an optional Condition that must hold.
#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone)] pub type Answer<R> = Result<Option<Condition<R>>, Reason>;
pub enum Answer<R> {
/// `Src` is transmutable into `Dst`.
Yes,
/// `Src` is NOT transmutable into `Dst`.
No(Reason),
/// A condition which must hold for safe transmutation to be possible
#[derive(Debug, Hash, Eq, PartialEq, Clone)]
pub enum Condition<R> {
/// `Src` is transmutable into `Dst`, if `src` is transmutable into `dst`. /// `Src` is transmutable into `Dst`, if `src` is transmutable into `dst`.
IfTransmutable { src: R, dst: R }, IfTransmutable { src: R, dst: R },

View file

@ -5,7 +5,7 @@ mod tests;
use crate::{ use crate::{
layout::{self, dfa, Byte, Dfa, Nfa, Ref, Tree, Uninhabited}, layout::{self, dfa, Byte, Dfa, Nfa, Ref, Tree, Uninhabited},
maybe_transmutable::query_context::QueryContext, maybe_transmutable::query_context::QueryContext,
Answer, Map, Reason, Answer, Condition, Map, Reason,
}; };
pub(crate) struct MaybeTransmutableQuery<L, C> pub(crate) struct MaybeTransmutableQuery<L, C>
@ -76,12 +76,12 @@ mod rustc {
let dst = Tree::from_ty(dst, context); let dst = Tree::from_ty(dst, context);
match (src, dst) { match (src, dst) {
// Answer `Yes` here, because 'unknown layout' and type errors will already // Answer `Ok(None)` here, because 'unknown layout' and type errors will already
// be reported by rustc. No need to spam the user with more errors. // be reported by rustc. No need to spam the user with more errors.
(Err(Err::TypeError(_)), _) | (_, Err(Err::TypeError(_))) => Err(Answer::Yes), (Err(Err::TypeError(_)), _) | (_, Err(Err::TypeError(_))) => Err(Ok(None)),
(Err(Err::Unknown), _) | (_, Err(Err::Unknown)) => Err(Answer::Yes), (Err(Err::Unknown), _) | (_, Err(Err::Unknown)) => Err(Ok(None)),
(Err(Err::Unspecified), _) | (_, Err(Err::Unspecified)) => { (Err(Err::Unspecified), _) | (_, Err(Err::Unspecified)) => {
Err(Answer::No(Reason::SrcIsUnspecified)) Err(Err(Reason::SrcIsUnspecified))
} }
(Ok(src), Ok(dst)) => Ok((src, dst)), (Ok(src), Ok(dst)) => Ok((src, dst)),
} }
@ -127,13 +127,12 @@ where
// Convert `src` from a tree-based representation to an NFA-based representation. // Convert `src` from a tree-based representation to an NFA-based representation.
// If the conversion fails because `src` is uninhabited, conclude that the transmutation // If the conversion fails because `src` is uninhabited, conclude that the transmutation
// is acceptable, because instances of the `src` type do not exist. // is acceptable, because instances of the `src` type do not exist.
let src = Nfa::from_tree(src).map_err(|Uninhabited| Answer::Yes)?; let src = Nfa::from_tree(src).map_err(|Uninhabited| Ok(None))?;
// Convert `dst` from a tree-based representation to an NFA-based representation. // Convert `dst` from a tree-based representation to an NFA-based representation.
// If the conversion fails because `src` is uninhabited, conclude that the transmutation // If the conversion fails because `src` is uninhabited, conclude that the transmutation
// is unacceptable, because instances of the `dst` type do not exist. // is unacceptable, because instances of the `dst` type do not exist.
let dst = let dst = Nfa::from_tree(dst).map_err(|Uninhabited| Err(Reason::DstIsPrivate))?;
Nfa::from_tree(dst).map_err(|Uninhabited| Answer::No(Reason::DstIsPrivate))?;
Ok((src, dst)) Ok((src, dst))
}); });
@ -205,13 +204,13 @@ where
} else { } else {
let answer = if dst_state == self.dst.accepting { let answer = if dst_state == self.dst.accepting {
// truncation: `size_of(Src) >= size_of(Dst)` // truncation: `size_of(Src) >= size_of(Dst)`
Answer::Yes Ok(None)
} else if src_state == self.src.accepting { } else if src_state == self.src.accepting {
// extension: `size_of(Src) >= size_of(Dst)` // extension: `size_of(Src) >= size_of(Dst)`
if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) { if let Some(dst_state_prime) = self.dst.byte_from(dst_state, Byte::Uninit) {
self.answer_memo(cache, src_state, dst_state_prime) self.answer_memo(cache, src_state, dst_state_prime)
} else { } else {
Answer::No(Reason::DstIsTooBig) Err(Reason::DstIsTooBig)
} }
} else { } else {
let src_quantifier = if self.assume.validity { let src_quantifier = if self.assume.validity {
@ -244,7 +243,7 @@ where
} else { } else {
// otherwise, we've exhausted our options. // otherwise, we've exhausted our options.
// the DFAs, from this point onwards, are bit-incompatible. // the DFAs, from this point onwards, are bit-incompatible.
Answer::No(Reason::DstIsBitIncompatible) Err(Reason::DstIsBitIncompatible)
} }
}, },
), ),
@ -252,16 +251,16 @@ where
// The below early returns reflect how this code would behave: // The below early returns reflect how this code would behave:
// if self.assume.validity { // if self.assume.validity {
// bytes_answer.or(refs_answer) // or(bytes_answer, refs_answer)
// } else { // } else {
// bytes_answer.and(refs_answer) // and(bytes_answer, refs_answer)
// } // }
// ...if `refs_answer` was computed lazily. The below early // ...if `refs_answer` was computed lazily. The below early
// returns can be deleted without impacting the correctness of // returns can be deleted without impacting the correctness of
// the algoritm; only its performance. // the algoritm; only its performance.
match bytes_answer { match bytes_answer {
Answer::No(..) if !self.assume.validity => return bytes_answer, Err(_) if !self.assume.validity => return bytes_answer,
Answer::Yes if self.assume.validity => return bytes_answer, Ok(None) if self.assume.validity => return bytes_answer,
_ => {} _ => {}
}; };
@ -277,20 +276,25 @@ where
.into_iter() .into_iter()
.map(|(&dst_ref, &dst_state_prime)| { .map(|(&dst_ref, &dst_state_prime)| {
if !src_ref.is_mutable() && dst_ref.is_mutable() { if !src_ref.is_mutable() && dst_ref.is_mutable() {
Answer::No(Reason::DstIsMoreUnique) Err(Reason::DstIsMoreUnique)
} else if !self.assume.alignment } else if !self.assume.alignment
&& src_ref.min_align() < dst_ref.min_align() && src_ref.min_align() < dst_ref.min_align()
{ {
Answer::No(Reason::DstHasStricterAlignment) Err(Reason::DstHasStricterAlignment)
} else { } else {
// ...such that `src` is transmutable into `dst`, if // ...such that `src` is transmutable into `dst`, if
// `src_ref` is transmutability into `dst_ref`. // `src_ref` is transmutability into `dst_ref`.
Answer::IfTransmutable { src: src_ref, dst: dst_ref } and(
.and(self.answer_memo( Ok(Some(Condition::IfTransmutable {
src: src_ref,
dst: dst_ref,
})),
self.answer_memo(
cache, cache,
src_state_prime, src_state_prime,
dst_state_prime, dst_state_prime,
)) ),
)
} }
}), }),
) )
@ -299,9 +303,9 @@ where
); );
if self.assume.validity { if self.assume.validity {
bytes_answer.or(refs_answer) or(bytes_answer, refs_answer)
} else { } else {
bytes_answer.and(refs_answer) and(bytes_answer, refs_answer)
} }
}; };
if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) { if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
@ -312,81 +316,55 @@ where
} }
} }
impl<R> Answer<R> fn and<R>(lhs: Answer<R>, rhs: Answer<R>) -> Answer<R> {
where // Should propagate errors on the right side, because the initial value
R: layout::Ref, // used in `apply` is on the left side.
{ let rhs = rhs?;
pub(crate) fn and(self, rhs: Self) -> Self { let lhs = lhs?;
match (self, rhs) { Ok(match (lhs, rhs) {
(_, Self::No(reason)) | (Self::No(reason), _) => Self::No(reason), // If only one side has a condition, pass it along
(None, other) | (other, None) => other,
(Self::Yes, other) | (other, Self::Yes) => other, // If both sides have IfAll conditions, merge them
(Some(Condition::IfAll(mut lhs)), Some(Condition::IfAll(ref mut rhs))) => {
(Self::IfAll(mut lhs), Self::IfAll(ref mut rhs)) => { lhs.append(rhs);
lhs.append(rhs); Some(Condition::IfAll(lhs))
Self::IfAll(lhs)
}
(constraint, Self::IfAll(mut constraints))
| (Self::IfAll(mut constraints), constraint) => {
constraints.push(constraint);
Self::IfAll(constraints)
}
(lhs, rhs) => Self::IfAll(vec![lhs, rhs]),
} }
} // If only one side is an IfAll, add the other Condition to it
(constraint, Some(Condition::IfAll(mut constraints)))
pub(crate) fn or(self, rhs: Self) -> Self { | (Some(Condition::IfAll(mut constraints)), constraint) => {
match (self, rhs) { constraints.push(Ok(constraint));
(Self::Yes, _) | (_, Self::Yes) => Self::Yes, Some(Condition::IfAll(constraints))
(other, Self::No(reason)) | (Self::No(reason), other) => other,
(Self::IfAny(mut lhs), Self::IfAny(ref mut rhs)) => {
lhs.append(rhs);
Self::IfAny(lhs)
}
(constraint, Self::IfAny(mut constraints))
| (Self::IfAny(mut constraints), constraint) => {
constraints.push(constraint);
Self::IfAny(constraints)
}
(lhs, rhs) => Self::IfAny(vec![lhs, rhs]),
} }
} // Otherwise, both lhs and rhs conditions can be combined in a parent IfAll
(lhs, rhs) => Some(Condition::IfAll(vec![Ok(lhs), Ok(rhs)])),
})
} }
pub fn for_all<R, I, F>(iter: I, f: F) -> Answer<R> fn or<R>(lhs: Answer<R>, rhs: Answer<R>) -> Answer<R> {
where // If both are errors, then we should return the one on the right
R: layout::Ref, if lhs.is_err() && rhs.is_err() {
I: IntoIterator, return rhs;
F: FnMut(<I as IntoIterator>::Item) -> Answer<R>, }
{ // Otherwise, errors can be ignored for the rest of the pattern matching
use std::ops::ControlFlow::{Break, Continue}; let lhs = lhs.unwrap_or(None);
let (Continue(result) | Break(result)) = let rhs = rhs.unwrap_or(None);
iter.into_iter().map(f).try_fold(Answer::Yes, |constraints, constraint| { Ok(match (lhs, rhs) {
match constraint.and(constraints) { // If only one side has a condition, pass it along
Answer::No(reason) => Break(Answer::No(reason)), (None, other) | (other, None) => other,
maybe => Continue(maybe), // If both sides have IfAny conditions, merge them
} (Some(Condition::IfAny(mut lhs)), Some(Condition::IfAny(ref mut rhs))) => {
}); lhs.append(rhs);
result Some(Condition::IfAny(lhs))
} }
// If only one side is an IfAny, add the other Condition to it
pub fn there_exists<R, I, F>(iter: I, f: F) -> Answer<R> (constraint, Some(Condition::IfAny(mut constraints)))
where | (Some(Condition::IfAny(mut constraints)), constraint) => {
R: layout::Ref, constraints.push(Ok(constraint));
I: IntoIterator, Some(Condition::IfAny(constraints))
F: FnMut(<I as IntoIterator>::Item) -> Answer<R>, }
{ // Otherwise, both lhs and rhs conditions can be combined in a parent IfAny
use std::ops::ControlFlow::{Break, Continue}; (lhs, rhs) => Some(Condition::IfAny(vec![Ok(lhs), Ok(rhs)])),
let (Continue(result) | Break(result)) = iter.into_iter().map(f).try_fold( })
Answer::No(Reason::DstIsBitIncompatible),
|constraints, constraint| match constraint.or(constraints) {
Answer::Yes => Break(Answer::Yes),
maybe => Continue(maybe),
},
);
result
} }
pub enum Quantifier { pub enum Quantifier {
@ -403,16 +381,14 @@ impl Quantifier {
use std::ops::ControlFlow::{Break, Continue}; use std::ops::ControlFlow::{Break, Continue};
let (init, try_fold_f): (_, fn(_, _) -> _) = match self { let (init, try_fold_f): (_, fn(_, _) -> _) = match self {
Self::ThereExists => { Self::ThereExists => (Err(Reason::DstIsBitIncompatible), |accum: Answer<R>, next| {
(Answer::No(Reason::DstIsBitIncompatible), |accum: Answer<R>, next| { match or(accum, next) {
match accum.or(next) { Ok(None) => Break(Ok(None)),
Answer::Yes => Break(Answer::Yes), maybe => Continue(maybe),
maybe => Continue(maybe), }
} }),
}) Self::ForAll => (Ok(None), |accum: Answer<R>, next| match and(accum, next) {
} Err(reason) => Break(Err(reason)),
Self::ForAll => (Answer::Yes, |accum: Answer<R>, next| match accum.and(next) {
Answer::No(reason) => Break(Answer::No(reason)),
maybe => Continue(maybe), maybe => Continue(maybe),
}), }),
}; };

View file

@ -1,6 +1,6 @@
use super::query_context::test::{Def, UltraMinimal}; use super::query_context::test::{Def, UltraMinimal};
use crate::maybe_transmutable::MaybeTransmutableQuery; use crate::maybe_transmutable::MaybeTransmutableQuery;
use crate::{layout, Answer, Reason}; use crate::{layout, Reason};
use itertools::Itertools; use itertools::Itertools;
mod bool { mod bool {
@ -17,7 +17,7 @@ mod bool {
UltraMinimal, UltraMinimal,
) )
.answer(); .answer();
assert_eq!(answer, Answer::Yes); assert_eq!(answer, Ok(None));
} }
#[test] #[test]
@ -30,7 +30,7 @@ mod bool {
UltraMinimal, UltraMinimal,
) )
.answer(); .answer();
assert_eq!(answer, Answer::Yes); assert_eq!(answer, Ok(None));
} }
#[test] #[test]
@ -65,7 +65,7 @@ mod bool {
if src_set.is_subset(&dst_set) { if src_set.is_subset(&dst_set) {
assert_eq!( assert_eq!(
Answer::Yes, Ok(None),
MaybeTransmutableQuery::new( MaybeTransmutableQuery::new(
src_layout.clone(), src_layout.clone(),
dst_layout.clone(), dst_layout.clone(),
@ -80,7 +80,7 @@ mod bool {
); );
} else if !src_set.is_disjoint(&dst_set) { } else if !src_set.is_disjoint(&dst_set) {
assert_eq!( assert_eq!(
Answer::Yes, Ok(None),
MaybeTransmutableQuery::new( MaybeTransmutableQuery::new(
src_layout.clone(), src_layout.clone(),
dst_layout.clone(), dst_layout.clone(),
@ -95,7 +95,7 @@ mod bool {
); );
} else { } else {
assert_eq!( assert_eq!(
Answer::No(Reason::DstIsBitIncompatible), Err(Reason::DstIsBitIncompatible),
MaybeTransmutableQuery::new( MaybeTransmutableQuery::new(
src_layout.clone(), src_layout.clone(),
dst_layout.clone(), dst_layout.clone(),