1
Fork 0

core: fix const ptr::swap_nonoverlapping when there are pointers at odd offsets in the type

This commit is contained in:
Ralf Jung 2024-12-23 15:30:27 +01:00
parent 85c39893a7
commit af1c8da172
5 changed files with 107 additions and 49 deletions

View file

@ -3795,7 +3795,7 @@ where
/// See [`const_eval_select()`] for the rules and requirements around that intrinsic.
pub(crate) macro const_eval_select {
(
@capture { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
if const
$(#[$compiletime_attr:meta])* $compiletime:block
else
@ -3803,7 +3803,7 @@ pub(crate) macro const_eval_select {
) => {
// Use the `noinline` arm, after adding explicit `inline` attributes
$crate::intrinsics::const_eval_select!(
@capture { $($arg : $ty = $val),* } $(-> $ret)? :
@capture$([$($binders)*])? { $($arg : $ty = $val),* } $(-> $ret)? :
#[noinline]
if const
#[inline] // prevent codegen on this function
@ -3817,7 +3817,7 @@ pub(crate) macro const_eval_select {
},
// With a leading #[noinline], we don't add inline attributes
(
@capture { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty = $val:expr),* $(,)? } $( -> $ret:ty )? :
#[noinline]
if const
$(#[$compiletime_attr:meta])* $compiletime:block
@ -3825,12 +3825,12 @@ pub(crate) macro const_eval_select {
$(#[$runtime_attr:meta])* $runtime:block
) => {{
$(#[$runtime_attr])*
fn runtime($($arg: $ty),*) $( -> $ret )? {
fn runtime$(<$($binders)*>)?($($arg: $ty),*) $( -> $ret )? {
$runtime
}
$(#[$compiletime_attr])*
const fn compiletime($($arg: $ty),*) $( -> $ret )? {
const fn compiletime$(<$($binders)*>)?($($arg: $ty),*) $( -> $ret )? {
// Don't warn if one of the arguments is unused.
$(let _ = $arg;)*
@ -3842,14 +3842,14 @@ pub(crate) macro const_eval_select {
// We support leaving away the `val` expressions for *all* arguments
// (but not for *some* arguments, that's too tricky).
(
@capture { $($arg:ident : $ty:ty),* $(,)? } $( -> $ret:ty )? :
@capture$([$($binders:tt)*])? { $($arg:ident : $ty:ty),* $(,)? } $( -> $ret:ty )? :
if const
$(#[$compiletime_attr:meta])* $compiletime:block
else
$(#[$runtime_attr:meta])* $runtime:block
) => {
$crate::intrinsics::const_eval_select!(
@capture { $($arg : $ty = $arg),* } $(-> $ret)? :
@capture$([$($binders)*])? { $($arg : $ty = $arg),* } $(-> $ret)? :
if const
$(#[$compiletime_attr])* $compiletime
else

View file

@ -395,6 +395,7 @@
#![allow(clippy::not_unsafe_ptr_arg_deref)]
use crate::cmp::Ordering;
use crate::intrinsics::const_eval_select;
use crate::marker::FnPtr;
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::{fmt, hash, intrinsics, ub_checks};
@ -1074,25 +1075,6 @@ pub const unsafe fn swap<T>(x: *mut T, y: *mut T) {
#[rustc_const_unstable(feature = "const_swap_nonoverlapping", issue = "133668")]
#[rustc_diagnostic_item = "ptr_swap_nonoverlapping"]
pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
#[allow(unused)]
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
}
ub_checks::assert_unsafe_precondition!(
check_language_ub,
"ptr::swap_nonoverlapping requires that both pointer arguments are aligned and non-null \
@ -1111,19 +1093,48 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
}
);
// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}
const_eval_select!(
@capture[T] { x: *mut T, y: *mut T, count: usize }:
if const {
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
// of a pointer (which would not work).
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
} else {
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
}
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
}
)
}
/// Same behavior and safety conditions as [`swap_nonoverlapping`]

View file

@ -16,6 +16,7 @@
#![feature(const_black_box)]
#![feature(const_eval_select)]
#![feature(const_swap)]
#![feature(const_swap_nonoverlapping)]
#![feature(const_trait_impl)]
#![feature(core_intrinsics)]
#![feature(core_io_borrowed_buf)]

View file

@ -860,7 +860,10 @@ fn swap_copy_untyped() {
}
#[test]
fn test_const_copy() {
fn test_const_copy_ptr() {
// `copy` and `copy_nonoverlapping` are thin layers on top of intrinsics. Ensure they correctly
// deal with pointers even when the pointers cross the boundary from one "element" being copied
// to another.
const {
let ptr1 = &1;
let mut ptr2 = &666;
@ -899,21 +902,61 @@ fn test_const_copy() {
}
#[test]
fn test_const_swap() {
const {
let mut ptr1 = &1;
let mut ptr2 = &666;
fn test_const_swap_ptr() {
// The `swap` functions are implemented in the library, they are not primitives.
// Only `swap_nonoverlapping` takes a count; pointers that cross multiple elements
// are *not* supported.
// We put the pointer at an odd offset in the type and copy them as an array of bytes,
// which should catch most of the ways that the library implementation can get it wrong.
// Swap ptr1 and ptr2, bytewise. `swap` does not take a count
// so the best we can do is use an array.
type T = [u8; mem::size_of::<&i32>()];
#[cfg(target_pointer_width = "32")]
type HalfPtr = i16;
#[cfg(target_pointer_width = "64")]
type HalfPtr = i32;
#[repr(C, packed)]
#[allow(unused)]
struct S {
f1: HalfPtr,
// Crucially this field is at an offset that is not a multiple of the pointer size.
ptr: &'static i32,
// Make sure the entire type does not have a power-of-2 size:
// make it 3 pointers in size. This used to hit a bug in `swap_nonoverlapping`.
f2: [HalfPtr; 3],
}
// Ensure the entire thing is usize-aligned, so in principle this
// looks like it could be eligible for a `usize` copying loop.
#[cfg_attr(target_pointer_width = "32", repr(align(4)))]
#[cfg_attr(target_pointer_width = "64", repr(align(8)))]
struct A(S);
const {
let mut s1 = A(S { ptr: &1, f1: 0, f2: [0; 3] });
let mut s2 = A(S { ptr: &666, f1: 0, f2: [0; 3] });
// Swap ptr1 and ptr2, as an array.
type T = [u8; mem::size_of::<A>()];
unsafe {
ptr::swap(ptr::from_mut(&mut ptr1).cast::<T>(), ptr::from_mut(&mut ptr2).cast::<T>());
ptr::swap(ptr::from_mut(&mut s1).cast::<T>(), ptr::from_mut(&mut s2).cast::<T>());
}
// Make sure they still work.
assert!(*ptr1 == 666);
assert!(*ptr2 == 1);
assert!(*s1.0.ptr == 666);
assert!(*s2.0.ptr == 1);
// Swap them back, again as an array.
unsafe {
ptr::swap_nonoverlapping(
ptr::from_mut(&mut s1).cast::<T>(),
ptr::from_mut(&mut s2).cast::<T>(),
1,
);
}
// Make sure they still work.
assert!(*s1.0.ptr == 1);
assert!(*s2.0.ptr == 666);
};
}

View file

@ -7,6 +7,8 @@ note: inside `std::ptr::read::<MaybeUninit<MaybeUninit<u8>>>`
--> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
note: inside `std::ptr::swap_nonoverlapping_simple_untyped::<MaybeUninit<u8>>`
--> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
note: inside `swap_nonoverlapping::compiletime::<MaybeUninit<u8>>`
--> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
note: inside `swap_nonoverlapping::<MaybeUninit<u8>>`
--> $SRC_DIR/core/src/ptr/mod.rs:LL:COL
note: inside `X`
@ -20,6 +22,7 @@ note: inside `X`
| |_________^
= help: this code performed an operation that depends on the underlying bytes representing a pointer
= help: the absolute address of a pointer is not known at compile-time, so such operations are not supported
= note: this error originates in the macro `$crate::intrinsics::const_eval_select` which comes from the expansion of the macro `const_eval_select` (in Nightly builds, run with -Z macro-backtrace for more info)
error: aborting due to 1 previous error