1
Fork 0

Auto merge of #89341 - audunhalland:derive-type-params-with-bound-generic-params, r=jackh726

Deriving: Include bound generic params in type parameters for where clause

Fixes #89188.

The `derive` macro ignored the `for<'s>` needed with the `Fn` trait in that code example.

edit: I'm unsure if this might cause regressions. I'm not an experienced compiler developer so I'm not used to thinking about unwanted side effects code changes like this might have.
This commit is contained in:
bors 2021-10-02 18:46:27 +00:00
commit f03eb6bef8
2 changed files with 79 additions and 10 deletions

View file

@ -332,20 +332,27 @@ pub fn combine_substructure(
RefCell::new(f) RefCell::new(f)
} }
struct TypeParameter {
bound_generic_params: Vec<ast::GenericParam>,
ty: P<ast::Ty>,
}
/// This method helps to extract all the type parameters referenced from a /// This method helps to extract all the type parameters referenced from a
/// type. For a type parameter `<T>`, it looks for either a `TyPath` that /// type. For a type parameter `<T>`, it looks for either a `TyPath` that
/// is not global and starts with `T`, or a `TyQPath`. /// is not global and starts with `T`, or a `TyQPath`.
/// Also include bound generic params from the input type.
fn find_type_parameters( fn find_type_parameters(
ty: &ast::Ty, ty: &ast::Ty,
ty_param_names: &[Symbol], ty_param_names: &[Symbol],
cx: &ExtCtxt<'_>, cx: &ExtCtxt<'_>,
) -> Vec<P<ast::Ty>> { ) -> Vec<TypeParameter> {
use rustc_ast::visit; use rustc_ast::visit;
struct Visitor<'a, 'b> { struct Visitor<'a, 'b> {
cx: &'a ExtCtxt<'b>, cx: &'a ExtCtxt<'b>,
ty_param_names: &'a [Symbol], ty_param_names: &'a [Symbol],
types: Vec<P<ast::Ty>>, bound_generic_params_stack: Vec<ast::GenericParam>,
type_params: Vec<TypeParameter>,
} }
impl<'a, 'b> visit::Visitor<'a> for Visitor<'a, 'b> { impl<'a, 'b> visit::Visitor<'a> for Visitor<'a, 'b> {
@ -353,7 +360,10 @@ fn find_type_parameters(
if let ast::TyKind::Path(_, ref path) = ty.kind { if let ast::TyKind::Path(_, ref path) = ty.kind {
if let Some(segment) = path.segments.first() { if let Some(segment) = path.segments.first() {
if self.ty_param_names.contains(&segment.ident.name) { if self.ty_param_names.contains(&segment.ident.name) {
self.types.push(P(ty.clone())); self.type_params.push(TypeParameter {
bound_generic_params: self.bound_generic_params_stack.clone(),
ty: P(ty.clone()),
});
} }
} }
} }
@ -361,15 +371,35 @@ fn find_type_parameters(
visit::walk_ty(self, ty) visit::walk_ty(self, ty)
} }
// Place bound generic params on a stack, to extract them when a type is encountered.
fn visit_poly_trait_ref(
&mut self,
trait_ref: &'a ast::PolyTraitRef,
modifier: &'a ast::TraitBoundModifier,
) {
let stack_len = self.bound_generic_params_stack.len();
self.bound_generic_params_stack
.extend(trait_ref.bound_generic_params.clone().into_iter());
visit::walk_poly_trait_ref(self, trait_ref, modifier);
self.bound_generic_params_stack.truncate(stack_len);
}
fn visit_mac_call(&mut self, mac: &ast::MacCall) { fn visit_mac_call(&mut self, mac: &ast::MacCall) {
self.cx.span_err(mac.span(), "`derive` cannot be used on items with type macros"); self.cx.span_err(mac.span(), "`derive` cannot be used on items with type macros");
} }
} }
let mut visitor = Visitor { cx, ty_param_names, types: Vec::new() }; let mut visitor = Visitor {
cx,
ty_param_names,
bound_generic_params_stack: Vec::new(),
type_params: Vec::new(),
};
visit::Visitor::visit_ty(&mut visitor, ty); visit::Visitor::visit_ty(&mut visitor, ty);
visitor.types visitor.type_params
} }
impl<'a> TraitDef<'a> { impl<'a> TraitDef<'a> {
@ -617,11 +647,11 @@ impl<'a> TraitDef<'a> {
ty_params.map(|ty_param| ty_param.ident.name).collect(); ty_params.map(|ty_param| ty_param.ident.name).collect();
for field_ty in field_tys { for field_ty in field_tys {
let tys = find_type_parameters(&field_ty, &ty_param_names, cx); let field_ty_params = find_type_parameters(&field_ty, &ty_param_names, cx);
for ty in tys { for field_ty_param in field_ty_params {
// if we have already handled this type, skip it // if we have already handled this type, skip it
if let ast::TyKind::Path(_, ref p) = ty.kind { if let ast::TyKind::Path(_, ref p) = field_ty_param.ty.kind {
if p.segments.len() == 1 if p.segments.len() == 1
&& ty_param_names.contains(&p.segments[0].ident.name) && ty_param_names.contains(&p.segments[0].ident.name)
{ {
@ -639,8 +669,8 @@ impl<'a> TraitDef<'a> {
let predicate = ast::WhereBoundPredicate { let predicate = ast::WhereBoundPredicate {
span: self.span, span: self.span,
bound_generic_params: Vec::new(), bound_generic_params: field_ty_param.bound_generic_params,
bounded_ty: ty, bounded_ty: field_ty_param.ty,
bounds, bounds,
}; };

View file

@ -0,0 +1,39 @@
// check-pass
#![feature(generic_associated_types)]
trait CallWithShim: Sized {
type Shim<'s>
where
Self: 's;
}
#[derive(Clone)]
struct ShimMethod<T: CallWithShim + 'static>(pub &'static dyn for<'s> Fn(&'s mut T::Shim<'s>));
trait CallWithShim2: Sized {
type Shim<T>;
}
struct S<'s>(&'s ());
#[derive(Clone)]
struct ShimMethod2<T: CallWithShim2 + 'static>(pub &'static dyn for<'s> Fn(&'s mut T::Shim<S<'s>>));
trait Trait<'s, 't, 'u> {}
#[derive(Clone)]
struct ShimMethod3<T: CallWithShim2 + 'static>(
pub &'static dyn for<'s> Fn(
&'s mut T::Shim<dyn for<'t> Fn(&'s mut T::Shim<dyn for<'u> Trait<'s, 't, 'u>>)>,
),
);
trait Trait2 {
type As;
}
#[derive(Clone)]
struct ShimMethod4<T: Trait2 + 'static>(pub &'static dyn for<'s> Fn(&'s mut T::As));
pub fn main() {}