Fix pretty printing of unsafe binders

This commit is contained in:
Michael Goulet 2025-03-01 19:28:04 +00:00
parent f4a216d28e
commit 83fa2faf23
5 changed files with 130 additions and 41 deletions

View file

@ -133,6 +133,20 @@ pub macro with_no_queries($e:expr) {{
))
}}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum WrapBinderMode {
ForAll,
Unsafe,
}
impl WrapBinderMode {
pub fn start_str(self) -> &'static str {
match self {
WrapBinderMode::ForAll => "for<",
WrapBinderMode::Unsafe => "unsafe<",
}
}
}
/// The "region highlights" are used to control region printing during
/// specific error messages. When a "region highlight" is enabled, it
/// gives an alternate way to print specific regions. For now, we
@ -219,7 +233,11 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
self.print_def_path(def_id, args)
}
fn in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), PrintError>
fn in_binder<T>(
&mut self,
value: &ty::Binder<'tcx, T>,
_mode: WrapBinderMode,
) -> Result<(), PrintError>
where
T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
{
@ -229,6 +247,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
fn wrap_binder<T, F: FnOnce(&T, &mut Self) -> Result<(), fmt::Error>>(
&mut self,
value: &ty::Binder<'tcx, T>,
_mode: WrapBinderMode,
f: F,
) -> Result<(), PrintError>
where
@ -703,8 +722,9 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
}
ty::FnPtr(ref sig_tys, hdr) => p!(print(sig_tys.with(hdr))),
ty::UnsafeBinder(ref bound_ty) => {
// FIXME(unsafe_binders): Make this print `unsafe<>` rather than `for<>`.
self.wrap_binder(bound_ty, |ty, cx| cx.pretty_print_type(*ty))?;
self.wrap_binder(bound_ty, WrapBinderMode::Unsafe, |ty, cx| {
cx.pretty_print_type(*ty)
})?;
}
ty::Infer(infer_ty) => {
if self.should_print_verbose() {
@ -1067,29 +1087,33 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
};
if let Some(return_ty) = entry.return_ty {
self.wrap_binder(&bound_args_and_self_ty, |(args, _), cx| {
define_scoped_cx!(cx);
p!(write("{}", tcx.item_name(trait_def_id)));
p!("(");
self.wrap_binder(
&bound_args_and_self_ty,
WrapBinderMode::ForAll,
|(args, _), cx| {
define_scoped_cx!(cx);
p!(write("{}", tcx.item_name(trait_def_id)));
p!("(");
for (idx, ty) in args.iter().enumerate() {
if idx > 0 {
p!(", ");
for (idx, ty) in args.iter().enumerate() {
if idx > 0 {
p!(", ");
}
p!(print(ty));
}
p!(print(ty));
}
p!(")");
if let Some(ty) = return_ty.skip_binder().as_type() {
if !ty.is_unit() {
p!(" -> ", print(return_ty));
p!(")");
if let Some(ty) = return_ty.skip_binder().as_type() {
if !ty.is_unit() {
p!(" -> ", print(return_ty));
}
}
}
p!(write("{}", if paren_needed { ")" } else { "" }));
p!(write("{}", if paren_needed { ")" } else { "" }));
first = false;
Ok(())
})?;
first = false;
Ok(())
},
)?;
} else {
// Otherwise, render this like a regular trait.
traits.insert(
@ -1110,7 +1134,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
for (trait_pred, assoc_items) in traits {
write!(self, "{}", if first { "" } else { " + " })?;
self.wrap_binder(&trait_pred, |trait_pred, cx| {
self.wrap_binder(&trait_pred, WrapBinderMode::ForAll, |trait_pred, cx| {
define_scoped_cx!(cx);
if trait_pred.polarity == ty::PredicatePolarity::Negative {
@ -1302,7 +1326,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let mut first = true;
if let Some(bound_principal) = predicates.principal() {
self.wrap_binder(&bound_principal, |principal, cx| {
self.wrap_binder(&bound_principal, WrapBinderMode::ForAll, |principal, cx| {
define_scoped_cx!(cx);
p!(print_def_path(principal.def_id, &[]));
@ -1927,7 +1951,7 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
let kind = closure.kind_ty().to_opt_closure_kind().unwrap_or(ty::ClosureKind::Fn);
write!(self, "impl ")?;
self.wrap_binder(&sig, |sig, cx| {
self.wrap_binder(&sig, WrapBinderMode::ForAll, |sig, cx| {
define_scoped_cx!(cx);
p!(write("{kind}("));
@ -2367,22 +2391,27 @@ impl<'tcx> PrettyPrinter<'tcx> for FmtPrinter<'_, 'tcx> {
Ok(())
}
fn in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), PrintError>
fn in_binder<T>(
&mut self,
value: &ty::Binder<'tcx, T>,
mode: WrapBinderMode,
) -> Result<(), PrintError>
where
T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
{
self.pretty_in_binder(value)
self.pretty_in_binder(value, mode)
}
fn wrap_binder<T, C: FnOnce(&T, &mut Self) -> Result<(), PrintError>>(
&mut self,
value: &ty::Binder<'tcx, T>,
mode: WrapBinderMode,
f: C,
) -> Result<(), PrintError>
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
self.pretty_wrap_binder(value, f)
self.pretty_wrap_binder(value, mode, f)
}
fn typed_value(
@ -2632,6 +2661,7 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
pub fn name_all_regions<T>(
&mut self,
value: &ty::Binder<'tcx, T>,
mode: WrapBinderMode,
) -> Result<(T, UnordMap<ty::BoundRegion, ty::Region<'tcx>>), fmt::Error>
where
T: TypeFoldable<TyCtxt<'tcx>>,
@ -2705,9 +2735,13 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
// anyways.
let (new_value, map) = if self.should_print_verbose() {
for var in value.bound_vars().iter() {
start_or_continue(self, "for<", ", ");
start_or_continue(self, mode.start_str(), ", ");
write!(self, "{var:?}")?;
}
// Unconditionally render `unsafe<>`.
if value.bound_vars().is_empty() && mode == WrapBinderMode::Unsafe {
start_or_continue(self, mode.start_str(), "");
}
start_or_continue(self, "", "> ");
(value.clone().skip_binder(), UnordMap::default())
} else {
@ -2772,8 +2806,9 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
}
};
if !trim_path {
start_or_continue(self, "for<", ", ");
// Unconditionally render `unsafe<>`.
if !trim_path || mode == WrapBinderMode::Unsafe {
start_or_continue(self, mode.start_str(), ", ");
do_continue(self, name);
}
ty::Region::new_bound(tcx, ty::INNERMOST, ty::BoundRegion { var: br.var, kind })
@ -2786,9 +2821,12 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
};
let new_value = value.clone().skip_binder().fold_with(&mut folder);
let region_map = folder.region_map;
if !trim_path {
start_or_continue(self, "", "> ");
if mode == WrapBinderMode::Unsafe && region_map.is_empty() {
start_or_continue(self, mode.start_str(), "");
}
start_or_continue(self, "", "> ");
(new_value, region_map)
};
@ -2797,12 +2835,16 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
Ok((new_value, map))
}
pub fn pretty_in_binder<T>(&mut self, value: &ty::Binder<'tcx, T>) -> Result<(), fmt::Error>
pub fn pretty_in_binder<T>(
&mut self,
value: &ty::Binder<'tcx, T>,
mode: WrapBinderMode,
) -> Result<(), fmt::Error>
where
T: Print<'tcx, Self> + TypeFoldable<TyCtxt<'tcx>>,
{
let old_region_index = self.region_index;
let (new_value, _) = self.name_all_regions(value)?;
let (new_value, _) = self.name_all_regions(value, mode)?;
new_value.print(self)?;
self.region_index = old_region_index;
self.binder_depth -= 1;
@ -2812,13 +2854,14 @@ impl<'tcx> FmtPrinter<'_, 'tcx> {
pub fn pretty_wrap_binder<T, C: FnOnce(&T, &mut Self) -> Result<(), fmt::Error>>(
&mut self,
value: &ty::Binder<'tcx, T>,
mode: WrapBinderMode,
f: C,
) -> Result<(), fmt::Error>
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
let old_region_index = self.region_index;
let (new_value, _) = self.name_all_regions(value)?;
let (new_value, _) = self.name_all_regions(value, mode)?;
f(&new_value, self)?;
self.region_index = old_region_index;
self.binder_depth -= 1;
@ -2877,7 +2920,7 @@ where
T: Print<'tcx, P> + TypeFoldable<TyCtxt<'tcx>>,
{
fn print(&self, cx: &mut P) -> Result<(), PrintError> {
cx.in_binder(self)
cx.in_binder(self, WrapBinderMode::ForAll)
}
}