Consistently use the most significant bit of vector masks
This improves the codegen for vector `select`, `gather`, `scatter` and boolean reduction intrinsics and fixes rust-lang/portable-simd#316.

The current behavior of most mask operations during llvm codegen is to truncate the mask vector to <N x i1>, telling llvm to use the least significant bit. The exception is the `simd_bitmask` intrinsic, which already used the most significant bit.

Since sse/avx instructions are defined to use the most significant bit, truncating means that llvm has to insert a left shift to move the bit into the most significant position before the mask can actually be used. Similarly, on aarch64, mask operations like blend work bit by bit, so repeating the least significant bit across the whole lane involves shifting it into the sign position and then comparing against zero.

By shifting before truncating to <N x i1>, we tell llvm that we only consider the most significant bit, removing the need for additional shift instructions in the assembly.
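For context, a minimal sketch (not part of this commit) of the kind of code whose codegen improves. It needs a nightly compiler with the `portable_simd` feature; the function name and values are illustrative:

```rust
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

pub fn blend(m: Mask<i32, 4>, a: Simd<i32, 4>, b: Simd<i32, 4>) -> Simd<i32, 4> {
    // Lowers to the `simd_select` intrinsic; on x86 the mask can now feed a
    // `blendv`-style instruction directly, since those instructions test only
    // the sign bit of each lane.
    m.select(a, b)
}
```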
parent c2270becb6
commit 3779b8e32e
13 changed files with 280 additions and 172 deletions
@@ -1182,6 +1182,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         }};
     }
 
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
+    macro_rules! require_int_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
+    macro_rules! require_int_or_uint_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
+    /// down to an i1 based mask that can be used by llvm intrinsics.
+    ///
+    /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
+    /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
+    /// but codegen for several targets is better if we consider the highest bit by shifting.
+    ///
+    /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
+    /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
+    /// instead the mask can be used as is.
+    ///
+    /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
+    /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
+    fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
+        bx: &mut Builder<'a, 'll, 'tcx>,
+        i_xn: &'ll Value,
+        in_elem_bitwidth: u64,
+        in_len: u64,
+    ) -> &'ll Value {
+        // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
+        let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
+        let shift_indices = vec![shift_idx; in_len as _];
+        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
+        // Truncate vector to an <i1 x N>
+        bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
+    }
+
     let tcx = bx.tcx();
     let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), callee_ty.fn_sig(tcx));
     let arg_tys = sig.inputs();
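For intuition, here is a hypothetical scalar analogue of what `vector_mask_to_bitmask` emits per lane: logically shift the sign bit down, then truncate to a single bit. The helper below is illustrative only, not part of the patch:

```rust
fn lane_msb_to_bool(lane: i32) -> bool {
    // `lshr` by (bitwidth - 1) moves the most significant bit into bit 0,
    // then masking to one bit is the scalar equivalent of the i1 truncation.
    ((lane as u32) >> (i32::BITS - 1)) & 1 == 1
}

fn main() {
    // Well-formed mask lanes are all ones or all zeros, so the most
    // significant bit and the least significant bit agree.
    assert!(lane_msb_to_bool(-1)); // all ones  -> true
    assert!(!lane_msb_to_bool(0)); // all zeros -> false
}
```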
@@ -1433,14 +1487,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             m_len,
             v_len
         });
-        match m_elem_ty.kind() {
-            ty::Int(_) => {}
-            _ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
-        }
-        // truncate the mask to a vector of i1s
-        let i1 = bx.type_i1();
-        let i1xn = bx.type_vector(i1, m_len as u64);
-        let m_i1s = bx.trunc(args[0].immediate(), i1xn);
+        let in_elem_bitwidth =
+            require_int_ty!(m_elem_ty.kind(), InvalidMonomorphization::MaskType {
+                span,
+                name,
+                ty: m_elem_ty
+            });
+        let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
         return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
     }
@@ -1457,33 +1510,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let expected_bytes = in_len.div_ceil(8);
 
-        // Integer vector <i{in_bitwidth} x in_len>:
-        let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
-            ty::Int(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            ty::Uint(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            _ => return_error!(InvalidMonomorphization::VectorArgument {
-                span,
-                name,
-                in_ty,
-                in_elem
-            }),
-        };
+        let in_elem_bitwidth =
+            require_int_or_uint_ty!(in_elem.kind(), InvalidMonomorphization::VectorArgument {
+                span,
+                name,
+                in_ty,
+                in_elem
+            });
 
-        // LLVM doesn't always know the inputs are `0` or `!0`, so we shift here so it optimizes to
-        // `pmovmskb` and similar on x86.
-        let shift_indices =
-            vec![
-                bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
-                in_len as _
-            ];
-        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
-        // Truncate vector to an <i1 x N>
-        let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
+        let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
         // Bitcast <i1 x N> to iN:
         let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
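A hedged usage sketch of the `simd_bitmask` path that this hunk routes through the shared helper (nightly, `portable_simd`; values illustrative):

```rust
#![feature(portable_simd)]
use std::simd::Mask;

fn main() {
    let m = Mask::<i8, 8>::from_array([true, false, true, true, false, false, false, true]);
    // Lowers to `simd_bitmask`: lane 0 becomes bit 0, lane 7 becomes bit 7.
    // On x86 this can compile to a single `pmovmskb`, which reads exactly the
    // most significant bit of each lane.
    assert_eq!(m.to_bitmask(), 0b1000_1101);
}
```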
@@ -1704,28 +1739,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         // Type of the vector of pointers:
         let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
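Assuming the current `std::simd` API, a sketch of a gather whose enable mask now goes through `vector_mask_to_bitmask` (nightly, `portable_simd`; values illustrative):

```rust
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

fn main() {
    let data = [10i32, 11, 12, 13, 14, 15];
    let idxs = Simd::from_array([0usize, 2, 4, 9]);
    let enable = Mask::from_array([true, true, true, false]);
    let or = Simd::splat(-1);
    // Lowers to the `simd_gather` intrinsic; disabled or out-of-bounds lanes
    // (here lane 3, index 9) are taken from `or` instead of being read.
    let v = Simd::gather_select(&data, enable, idxs, or);
    assert_eq!(v.to_array(), [10, 12, 14, -1]);
}
```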
@@ -1810,27 +1838,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
-                span,
-                name,
-                expected_element: values_elem,
-                third_arg: mask_ty,
-            }
-        );
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, mask_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let llvm_pointer = bx.type_ptr();
 
         // Type of the vector of elements:
@@ -1901,27 +1923,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
-                span,
-                name,
-                expected_element: values_elem,
-                third_arg: mask_ty,
-            }
-        );
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let ret_t = bx.type_void();
 
         let llvm_pointer = bx.type_ptr();
@@ -1995,28 +2011,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         );
 
         // The element type of the third argument must be a signed integer type of any width:
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         let ret_t = bx.type_void();
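Likewise for scatter, a sketch under the same assumptions (nightly, `portable_simd`; values illustrative):

```rust
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

fn main() {
    let mut data = [0i32; 6];
    let idxs = Simd::from_array([0usize, 2, 4, 9]);
    let enable = Mask::from_array([true, true, true, false]);
    // Lowers to the `simd_scatter` intrinsic; writes for disabled or
    // out-of-bounds lanes (here lane 3, index 9) are suppressed.
    Simd::from_array([1, 2, 3, 4]).scatter_select(&mut data, enable, idxs);
    assert_eq!(data, [1, 0, 2, 0, 3, 0]);
}
```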
@@ -2164,8 +2173,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             });
             args[0].immediate()
         } else {
-            match in_elem.kind() {
-                ty::Int(_) | ty::Uint(_) => {}
+            let bitwidth = match in_elem.kind() {
+                ty::Int(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
                 _ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
                     span,
                     name,
@@ -2174,12 +2188,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
                     in_elem,
                     ret_ty
                 }),
-            }
+            };
 
-            // boolean reductions operate on vectors of i1s:
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len as u64);
-            bx.trunc(args[0].immediate(), i1xn)
+            vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
         };
         return match in_elem.kind() {
             ty::Int(_) | ty::Uint(_) => {
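Finally, a sketch of the boolean reductions touched by the last two hunks (nightly, `portable_simd`; values illustrative):

```rust
#![feature(portable_simd)]
use std::simd::Mask;

fn main() {
    let m = Mask::<i32, 4>::from_array([true, true, false, true]);
    // `any`/`all` lower to the boolean reduction intrinsics; with the sign-bit
    // convention they can use `movmskps`-style sign-bit extraction directly.
    assert!(m.any());
    assert!(!m.all());
}
```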