1
Fork 0

Handle tags better.

Currently, for the enums and comparison traits we always check the tag
for equality before doing anything else. This is a bit clumsy. This
commit changes things so that the tags are handled very much like a
zeroth field in the enum.

For `eq`/ne` this makes the code slightly cleaner.

For `partial_cmp` and `cmp` it's a more notable change: in the case
where the tags aren't equal, instead of having a tag equality check
followed by a tag comparison, it just does a single tag comparison.

The commit also improves how `Hash` works for enums: instead of having
duplicated code to hash the tag for every arm within the match, we do
it just once before the match.

All this required replacing the `EnumNonMatchingCollapsed` value with a
new `EnumTag` value.

For fieldless enums the new code is particularly improved. All the code
now produced is close to optimal, being very similar to what you'd write
by hand.
This commit is contained in:
Nicholas Nethercote 2022-07-08 15:32:27 +10:00
parent 4bcbd76bc9
commit 10144e29af
9 changed files with 245 additions and 329 deletions

View file

@ -21,21 +21,14 @@
//! `struct T(i32, char)`).
//! - `EnumMatching`, when `Self` is an enum and all the arguments are the
//! same variant of the enum (e.g., `Some(1)`, `Some(3)` and `Some(4)`)
//! - `EnumNonMatchingCollapsed` when `Self` is an enum and the arguments
//! are not the same variant (e.g., `None`, `Some(1)` and `None`).
//! - `EnumTag` when `Self` is an enum, for comparing the enum tags.
//! - `StaticEnum` and `StaticStruct` for static methods, where the type
//! being derived upon is either an enum or struct respectively. (Any
//! argument with type Self is just grouped among the non-self
//! arguments.)
//!
//! In the first two cases, the values from the corresponding fields in
//! all the arguments are grouped together. For `EnumNonMatchingCollapsed`
//! this isn't possible (different variants have different fields), so the
//! fields are inaccessible. (Previous versions of the deriving infrastructure
//! had a way to expand into code that could access them, at the cost of
//! generating exponential amounts of code; see issue #15375). There are no
//! fields with values in the static cases, so these are treated entirely
//! differently.
//! all the arguments are grouped together.
//!
//! The non-static cases have `Option<ident>` in several places associated
//! with field `expr`s. This represents the name of the field it is
@ -142,21 +135,15 @@
//! }])
//! ```
//!
//! For `C0(a)` and `C1 {x}` ,
//! For the tags,
//!
//! ```{.text}
//! EnumNonMatchingCollapsed(
//! &[<ident for self index value>, <ident of __arg1 index value>])
//! EnumTag(
//! &[<ident of self tag>, <ident of other tag>], <expr to combine with>)
//! ```
//!
//! It is the same for when the arguments are flipped to `C1 {x}` and
//! `C0(a)`; the only difference is what the values of the identifiers
//! <ident for self index value> and <ident of __arg1 index value> will
//! be in the generated code.
//!
//! `EnumNonMatchingCollapsed` deliberately provides far less information
//! than is generally available for a given pair of variants; see #15375
//! for discussion.
//! Note that this setup doesn't allow for the brute-force "match every variant
//! against every other variant" approach, which is bad because it produces a
//! quadratic amount of code (see #15375).
//!
//! ## Static
//!
@ -180,7 +167,7 @@ use std::iter;
use std::vec;
use rustc_ast::ptr::P;
use rustc_ast::{self as ast, BinOpKind, EnumDef, Expr, Generics, PatKind};
use rustc_ast::{self as ast, EnumDef, Expr, Generics, PatKind};
use rustc_ast::{GenericArg, GenericParamKind, VariantData};
use rustc_attr as attr;
use rustc_expand::base::{Annotatable, ExtCtxt};
@ -235,6 +222,8 @@ pub struct MethodDef<'a> {
pub attributes: Vec<ast::Attribute>,
/// Can we combine fieldless variants for enums into a single match arm?
/// If true, indicates that the trait operation uses the enum tag in some
/// way.
pub unify_fieldless_variants: bool,
pub combine_substructure: RefCell<CombineSubstructureFunc<'a>>,
@ -274,19 +263,22 @@ pub enum StaticFields {
/// A summary of the possible sets of fields.
pub enum SubstructureFields<'a> {
/// A non-static method with `Self` is a struct.
Struct(&'a ast::VariantData, Vec<FieldInfo>),
/// Matching variants of the enum: variant index, variant count, ast::Variant,
/// fields: the field name is only non-`None` in the case of a struct
/// variant.
EnumMatching(usize, usize, &'a ast::Variant, Vec<FieldInfo>),
/// Non-matching variants of the enum, but with all state hidden from the
/// consequent code. The field is a list of `Ident`s bound to the variant
/// index values for each of the actual input `Self` arguments.
EnumNonMatchingCollapsed(&'a [Ident]),
/// The tag of an enum. The first field is a `FieldInfo` for the tags, as
/// if they were fields. The second field is the expression to combine the
/// tag expression with; it will be `None` if no match is necessary.
EnumTag(FieldInfo, Option<P<Expr>>),
/// A static method where `Self` is a struct.
StaticStruct(&'a ast::VariantData, StaticFields),
/// A static method where `Self` is an enum.
StaticEnum(&'a ast::EnumDef, Vec<(Ident, Span, StaticFields)>),
}
@ -324,8 +316,8 @@ impl BlockOrExpr {
BlockOrExpr(vec![], Some(expr))
}
pub fn new_mixed(stmts: Vec<ast::Stmt>, expr: P<Expr>) -> BlockOrExpr {
BlockOrExpr(stmts, Some(expr))
pub fn new_mixed(stmts: Vec<ast::Stmt>, expr: Option<P<Expr>>) -> BlockOrExpr {
BlockOrExpr(stmts, expr)
}
// Converts it into a block.
@ -339,7 +331,6 @@ impl BlockOrExpr {
// Converts it into an expression.
fn into_expr(self, cx: &ExtCtxt<'_>, span: Span) -> P<Expr> {
if self.0.is_empty() {
// No statements.
match self.1 {
None => cx.expr_block(cx.block(span, vec![])),
Some(expr) => expr,
@ -1135,44 +1126,34 @@ impl<'a> MethodDef<'a> {
/// fn eq(&self, other: &A) -> bool {
/// let __self_tag = ::core::intrinsics::discriminant_value(self);
/// let __arg1_tag = ::core::intrinsics::discriminant_value(other);
/// if __self_tag == __arg1_tag {
/// __self_tag == __arg1_tag &&
/// match (self, other) {
/// (A::A2(__self_0), A::A2(__arg1_0)) =>
/// *__self_0 == *__arg1_0,
/// _ => true,
/// }
/// } else {
/// false // catch-all handler
/// }
/// }
/// }
/// ```
/// Creates a match for a tuple of all `selflike_args`, where either all
/// variants match, or it falls into a catch-all for when one variant
/// does not match.
///
/// There are N + 1 cases because is a case for each of the N
/// variants where all of the variants match, and one catch-all for
/// when one does not match.
///
/// As an optimization we generate code which checks whether all variants
/// match first which makes llvm see that C-like enums can be compiled into
/// a simple equality check (for PartialEq).
///
/// The catch-all handler is provided access the variant index values
/// for each of the selflike_args, carried in precomputed variables.
/// Creates a tag check combined with a match for a tuple of all
/// `selflike_args`, with an arm for each variant with fields, possibly an
/// arm for each fieldless variant (if `!unify_fieldless_variants` is not
/// true), and possibly a default arm.
fn expand_enum_method_body<'b>(
&self,
cx: &mut ExtCtxt<'_>,
trait_: &TraitDef<'b>,
enum_def: &'b EnumDef,
type_ident: Ident,
mut selflike_args: Vec<P<Expr>>,
selflike_args: Vec<P<Expr>>,
nonselflike_args: &[P<Expr>],
) -> BlockOrExpr {
let span = trait_.span;
let variants = &enum_def.variants;
// Traits that unify fieldless variants always use the tag(s).
let uses_tags = self.unify_fieldless_variants;
// There is no sensible code to be generated for *any* deriving on a
// zero-variant enum. So we just generate a failing expression.
if variants.is_empty() {
@ -1189,27 +1170,82 @@ impl<'a> MethodDef<'a> {
)
.collect::<Vec<String>>();
// The `tag_idents` will be bound, solely in the catch-all, to
// a series of let statements mapping each selflike_arg to an int
// value corresponding to its discriminant.
let tag_idents = prefixes
.iter()
.map(|name| Ident::from_str_and_span(&format!("{}_tag", name), span))
.collect::<Vec<Ident>>();
// Build a series of let statements mapping each selflike_arg
// to its discriminant value.
//
// e.g. for `PartialEq::eq` builds two statements:
// ```
// let __self_tag = ::core::intrinsics::discriminant_value(self);
// let __arg1_tag = ::core::intrinsics::discriminant_value(other);
// ```
let get_tag_pieces = |cx: &ExtCtxt<'_>| {
let tag_idents: Vec<_> = prefixes
.iter()
.map(|name| Ident::from_str_and_span(&format!("{}_tag", name), span))
.collect();
// Builds, via callback to call_substructure_method, the
// delegated expression that handles the catch-all case,
// using `__variants_tuple` to drive logic if necessary.
let catch_all_substructure = EnumNonMatchingCollapsed(&tag_idents);
let mut tag_exprs: Vec<_> = tag_idents
.iter()
.map(|&ident| cx.expr_addr_of(span, cx.expr_ident(span, ident)))
.collect();
let first_fieldless = variants.iter().find(|v| v.data.fields().is_empty());
let self_expr = tag_exprs.remove(0);
let other_selflike_exprs = tag_exprs;
let tag_field = FieldInfo { span, name: None, self_expr, other_selflike_exprs };
let tag_let_stmts: Vec<_> = iter::zip(&tag_idents, &selflike_args)
.map(|(&ident, selflike_arg)| {
let variant_value = deriving::call_intrinsic(
cx,
span,
sym::discriminant_value,
vec![selflike_arg.clone()],
);
cx.stmt_let(span, false, ident, variant_value)
})
.collect();
(tag_field, tag_let_stmts)
};
// There are some special cases involving fieldless enums where no
// match is necessary.
let all_fieldless = variants.iter().all(|v| v.data.fields().is_empty());
if all_fieldless {
if uses_tags && variants.len() > 1 {
// If the type is fieldless and the trait uses the tag and
// there are multiple variants, we need just an operation on
// the tag(s).
let (tag_field, mut tag_let_stmts) = get_tag_pieces(cx);
let mut tag_check = self.call_substructure_method(
cx,
trait_,
type_ident,
nonselflike_args,
&EnumTag(tag_field, None),
);
tag_let_stmts.append(&mut tag_check.0);
return BlockOrExpr(tag_let_stmts, tag_check.1);
}
if variants.len() == 1 {
// If there is a single variant, we don't need an operation on
// the tag(s). Just use the most degenerate result.
return self.call_substructure_method(
cx,
trait_,
type_ident,
nonselflike_args,
&EnumMatching(0, 1, &variants[0], Vec::new()),
);
};
}
// These arms are of the form:
// (Variant1, Variant1, ...) => Body1
// (Variant2, Variant2, ...) => Body2
// ...
// where each tuple has length = selflike_args.len()
let mut match_arms: Vec<ast::Arm> = variants
.iter()
.enumerate()
@ -1233,7 +1269,7 @@ impl<'a> MethodDef<'a> {
use_ref_pat,
);
// Here is the pat = `(&VariantK, &VariantK, ...)`
// `(VariantK, VariantK, ...)` or just `VariantK`.
let single_pat = if subpats.len() == 1 {
subpats.pop().unwrap()
} else {
@ -1263,27 +1299,28 @@ impl<'a> MethodDef<'a> {
})
.collect();
// Add a default arm to the match, if necessary.
let first_fieldless = variants.iter().find(|v| v.data.fields().is_empty());
let default = match first_fieldless {
Some(v) if self.unify_fieldless_variants => {
// We need a default case that handles the fieldless variants.
// The index and actual variant aren't meaningful in this case,
// so just use whatever
let substructure = EnumMatching(0, variants.len(), v, Vec::new());
// We need a default case that handles all the fieldless
// variants. The index and actual variant aren't meaningful in
// this case, so just use dummy values.
Some(
self.call_substructure_method(
cx,
trait_,
type_ident,
nonselflike_args,
&substructure,
&EnumMatching(0, variants.len(), v, Vec::new()),
)
.into_expr(cx, span),
)
}
_ if variants.len() > 1 && selflike_args.len() > 1 => {
// Since we know that all the arguments will match if we reach
// Because we know that all the arguments will match if we reach
// the match expression we add the unreachable intrinsics as the
// result of the catch all which should help llvm in optimizing it
// result of the default which should help llvm in optimizing it.
Some(deriving::call_unreachable(cx, span))
}
_ => None,
@ -1292,92 +1329,41 @@ impl<'a> MethodDef<'a> {
match_arms.push(cx.arm(span, cx.pat_wild(span), arm));
}
// We will usually need the catch-all after matching the
// tuples `(VariantK, VariantK, ...)` for each VariantK of the
// enum. But:
//
// * when there is only one Self arg, the arms above suffice
// (and the deriving we call back into may not be prepared to
// handle EnumNonMatchCollapsed), and,
//
// * when the enum has only one variant, the single arm that
// is already present always suffices.
//
// * In either of the two cases above, if we *did* add a
// catch-all `_` match, it would trigger the
// unreachable-pattern error.
//
if variants.len() > 1 && selflike_args.len() > 1 {
// Build a series of let statements mapping each selflike_arg
// to its discriminant value.
//
// i.e., for `enum E<T> { A, B(1), C(T, T) }` for `PartialEq::eq`,
// builds two statements:
// ```
// let __self_tag = ::core::intrinsics::discriminant_value(self);
// let __arg1_tag = ::core::intrinsics::discriminant_value(other);
// ```
let mut index_let_stmts: Vec<ast::Stmt> = Vec::with_capacity(tag_idents.len() + 1);
// We also build an expression which checks whether all discriminants are equal, e.g.
// `__self_tag == __arg1_tag`.
let mut discriminant_test = cx.expr_bool(span, true);
for (i, (&ident, selflike_arg)) in iter::zip(&tag_idents, &selflike_args).enumerate() {
let variant_value = deriving::call_intrinsic(
cx,
span,
sym::discriminant_value,
vec![selflike_arg.clone()],
);
let let_stmt = cx.stmt_let(span, false, ident, variant_value);
index_let_stmts.push(let_stmt);
if i > 0 {
let id0 = cx.expr_ident(span, tag_idents[0]);
let id = cx.expr_ident(span, ident);
let test = cx.expr_binary(span, BinOpKind::Eq, id0, id);
discriminant_test = if i == 1 {
test
} else {
cx.expr_binary(span, BinOpKind::And, discriminant_test, test)
};
}
}
let arm_expr = self
.call_substructure_method(
cx,
trait_,
type_ident,
nonselflike_args,
&catch_all_substructure,
)
.into_expr(cx, span);
let match_arg = cx.expr(span, ast::ExprKind::Tup(selflike_args));
// Lastly we create an expression which branches on all discriminants being equal, e.g.
// if __self_tag == _arg1_tag {
// match (self, other) {
// (Variant1, Variant1, ...) => Body1
// (Variant2, Variant2, ...) => Body2,
// ...
// _ => ::core::intrinsics::unreachable()
// }
// }
// else {
// <delegated expression referring to __self_tag, et al.>
// }
let all_match = cx.expr_match(span, match_arg, match_arms);
let arm_expr = cx.expr_if(span, discriminant_test, all_match, Some(arm_expr));
BlockOrExpr(index_let_stmts, Some(arm_expr))
} else {
// Create a match expression with one arm per discriminant plus
// possibly a default arm, e.g.:
// match (self, other) {
// (Variant1, Variant1, ...) => Body1
// (Variant2, Variant2, ...) => Body2,
// ...
// _ => ::core::intrinsics::unreachable()
// }
let get_match_expr = |mut selflike_args: Vec<P<Expr>>| {
let match_arg = if selflike_args.len() == 1 {
selflike_args.pop().unwrap()
} else {
cx.expr(span, ast::ExprKind::Tup(selflike_args))
};
BlockOrExpr(vec![], Some(cx.expr_match(span, match_arg, match_arms)))
cx.expr_match(span, match_arg, match_arms)
};
// If the trait uses the tag and there are multiple variants, we need
// to add a tag check operation before the match. Otherwise, the match
// is enough.
if uses_tags && variants.len() > 1 {
let (tag_field, mut tag_let_stmts) = get_tag_pieces(cx);
// Combine a tag check with the match.
let mut tag_check_plus_match = self.call_substructure_method(
cx,
trait_,
type_ident,
nonselflike_args,
&EnumTag(tag_field, Some(get_match_expr(selflike_args))),
);
tag_let_stmts.append(&mut tag_check_plus_match.0);
BlockOrExpr(tag_let_stmts, tag_check_plus_match.1)
} else {
BlockOrExpr(vec![], Some(get_match_expr(selflike_args)))
}
}
@ -1591,11 +1577,6 @@ pub enum CsFold<'a> {
// The fallback case for a struct or enum variant with no fields.
Fieldless,
/// The fallback case for non-matching enum variants. The slice is the
/// identifiers holding the variant index value for each of the `Self`
/// arguments.
EnumNonMatching(Span, &'a [Ident]),
}
/// Folds over fields, combining the expressions for each field in a sequence.
@ -1610,8 +1591,8 @@ pub fn cs_fold<F>(
where
F: FnMut(&mut ExtCtxt<'_>, CsFold<'_>) -> P<Expr>,
{
match *substructure.fields {
EnumMatching(.., ref all_fields) | Struct(_, ref all_fields) => {
match substructure.fields {
EnumMatching(.., all_fields) | Struct(_, all_fields) => {
if all_fields.is_empty() {
return f(cx, CsFold::Fieldless);
}
@ -1635,7 +1616,18 @@ where
rest.iter().rfold(base_expr, op)
}
}
EnumNonMatchingCollapsed(tuple) => f(cx, CsFold::EnumNonMatching(trait_span, tuple)),
EnumTag(tag_field, match_expr) => {
let tag_check_expr = f(cx, CsFold::Single(tag_field));
if let Some(match_expr) = match_expr {
if use_foldl {
f(cx, CsFold::Combine(trait_span, tag_check_expr, match_expr.clone()))
} else {
f(cx, CsFold::Combine(trait_span, match_expr.clone(), tag_check_expr))
}
} else {
tag_check_expr
}
}
StaticEnum(..) | StaticStruct(..) => cx.span_bug(trait_span, "static function in `derive`"),
}
}