Fix array::IntoIter::fold
to use the optimized Range::fold
It was using `Iterator::by_ref` in the implementation, which ended up pessimizing it enough that, for example, it didn't vectorize when we tried it in the <https://rust-lang.zulipchat.com/#narrow/stream/257879-project-portable-simd/topic/Reducing.20sum.20into.20wider.20types> conversation. Demonstration that the codegen test doesn't pass on the current nightly: <https://rust.godbolt.org/z/Taxev5eMn>
This commit is contained in:
parent
eb82facb16
commit
83595f9242
4 changed files with 130 additions and 1 deletions
|
@ -266,7 +266,7 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
|
|||
Fold: FnMut(Acc, Self::Item) -> Acc,
|
||||
{
|
||||
let data = &mut self.data;
|
||||
self.alive.by_ref().fold(init, |acc, idx| {
|
||||
iter::ByRefSized(&mut self.alive).fold(init, |acc, idx| {
|
||||
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
|
||||
// value is currently considered alive but as the range is being consumed each value
|
||||
// we read here will only be read once and then considered dead.
|
||||
|
@ -323,6 +323,20 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
|
|||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc
|
||||
where
|
||||
Fold: FnMut(Acc, Self::Item) -> Acc,
|
||||
{
|
||||
let data = &mut self.data;
|
||||
iter::ByRefSized(&mut self.alive).rfold(init, |acc, idx| {
|
||||
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
|
||||
// value is currently considered alive but as the range is being consumed each value
|
||||
// we read here will only be read once and then considered dead.
|
||||
rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
|
||||
})
|
||||
}
|
||||
|
||||
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
|
||||
let len = self.len();
|
||||
|
||||
|
|
|
@ -40,3 +40,32 @@ impl<I: Iterator> Iterator for ByRefSized<'_, I> {
|
|||
self.0.try_fold(init, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<I: DoubleEndedIterator> DoubleEndedIterator for ByRefSized<'_, I> {
|
||||
fn next_back(&mut self) -> Option<Self::Item> {
|
||||
self.0.next_back()
|
||||
}
|
||||
|
||||
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
|
||||
self.0.advance_back_by(n)
|
||||
}
|
||||
|
||||
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
|
||||
self.0.nth_back(n)
|
||||
}
|
||||
|
||||
fn rfold<B, F>(self, init: B, f: F) -> B
|
||||
where
|
||||
F: FnMut(B, Self::Item) -> B,
|
||||
{
|
||||
self.0.rfold(init, f)
|
||||
}
|
||||
|
||||
fn try_rfold<B, F, R>(&mut self, init: B, f: F) -> R
|
||||
where
|
||||
F: FnMut(B, Self::Item) -> R,
|
||||
R: Try<Output = B>,
|
||||
{
|
||||
self.0.try_rfold(init, f)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -668,3 +668,35 @@ fn array_mixed_equality_nans() {
|
|||
assert!(!(mut3 == array3));
|
||||
assert!(mut3 != array3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_into_iter_fold() {
|
||||
// Strings to help MIRI catch if we double-free or something
|
||||
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
|
||||
let mut s = "s".to_string();
|
||||
a.into_iter().for_each(|b| s += &b);
|
||||
assert_eq!(s, "sAaBbCc");
|
||||
|
||||
let a = [1, 2, 3, 4, 5, 6];
|
||||
let mut it = a.into_iter();
|
||||
it.advance_by(1).unwrap();
|
||||
it.advance_back_by(2).unwrap();
|
||||
let s = it.fold(10, |a, b| 10 * a + b);
|
||||
assert_eq!(s, 10234);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_into_iter_rfold() {
|
||||
// Strings to help MIRI catch if we double-free or something
|
||||
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
|
||||
let mut s = "s".to_string();
|
||||
a.into_iter().rev().for_each(|b| s += &b);
|
||||
assert_eq!(s, "sCcBbAa");
|
||||
|
||||
let a = [1, 2, 3, 4, 5, 6];
|
||||
let mut it = a.into_iter();
|
||||
it.advance_by(1).unwrap();
|
||||
it.advance_back_by(2).unwrap();
|
||||
let s = it.rfold(10, |a, b| 10 * a + b);
|
||||
assert_eq!(s, 10432);
|
||||
}
|
||||
|
|
54
src/test/codegen/simd-wide-sum.rs
Normal file
54
src/test/codegen/simd-wide-sum.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
// compile-flags: -C opt-level=3 --edition=2021
|
||||
// only-x86_64
|
||||
// ignore-debug: the debug assertions get in the way
|
||||
|
||||
#![crate_type = "lib"]
|
||||
#![feature(portable_simd)]
|
||||
|
||||
use std::simd::Simd;
|
||||
const N: usize = 8;
|
||||
|
||||
#[no_mangle]
|
||||
// CHECK-LABEL: @wider_reduce_simd
|
||||
pub fn wider_reduce_simd(x: Simd<u8, N>) -> u16 {
|
||||
// CHECK: zext <8 x i8>
|
||||
// CHECK-SAME: to <8 x i16>
|
||||
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
|
||||
let x: Simd<u16, N> = x.cast();
|
||||
x.reduce_sum()
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
// CHECK-LABEL: @wider_reduce_loop
|
||||
pub fn wider_reduce_loop(x: Simd<u8, N>) -> u16 {
|
||||
// CHECK: zext <8 x i8>
|
||||
// CHECK-SAME: to <8 x i16>
|
||||
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
|
||||
let mut sum = 0_u16;
|
||||
for i in 0..N {
|
||||
sum += u16::from(x[i]);
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
// CHECK-LABEL: @wider_reduce_iter
|
||||
pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 {
|
||||
// CHECK: zext <8 x i8>
|
||||
// CHECK-SAME: to <8 x i16>
|
||||
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
|
||||
x.as_array().iter().copied().map(u16::from).sum()
|
||||
}
|
||||
|
||||
// This iterator one is the most interesting, as it's the one
|
||||
// which used to not auto-vectorize due to a suboptimality in the
|
||||
// `<array::IntoIter as Iterator>::fold` implementation.
|
||||
|
||||
#[no_mangle]
|
||||
// CHECK-LABEL: @wider_reduce_into_iter
|
||||
pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 {
|
||||
// CHECK: zext <8 x i8>
|
||||
// CHECK-SAME: to <8 x i16>
|
||||
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
|
||||
x.to_array().into_iter().map(u16::from).sum()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue