fix OOB access in SIMD impl of str.contains()

The 8472 2022-11-22 20:59:19 +01:00
parent d576a9b241
commit 3ed8fccff5
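
Every bounds check touched below switches from second_probe_offset to last_byte_offset. The reason, as far as the diff itself shows, is that a candidate position flagged by the SIMD probes is verified by comparing the whole needle, so verification reads up to needle.len() - 1 bytes past the candidate index; second_probe_offset can sit before the last needle byte, so a bound built from it leaves too little headroom at the end of the haystack. The following is a minimal sketch of that arithmetic only, with made-up sizes and LANES standing in for Block::LANES; it is not the library code.

// Sketch of the bounds arithmetic behind the fix (hypothetical numbers,
// not the standard-library implementation).
fn main() {
    const LANES: usize = 16; // stands in for Block::LANES

    let haystack_len = 64;
    let needle_len = 5;
    // Hypothetical probe choice: the needle's trailing bytes equal its first
    // byte, so the second probe sits a couple of bytes before the end.
    let second_probe_offset = 2;
    let last_byte_offset = needle_len - 1;

    // Starting index of the final right-aligned chunk, old vs. fixed.
    let old_i = haystack_len - second_probe_offset - LANES; // 46
    let new_i = haystack_len - last_byte_offset - LANES; // 44

    // A candidate can start anywhere in the chunk, i.e. as late as
    // i + LANES - 1, and verifying it compares all needle_len bytes,
    // so the read ends (exclusive) at candidate + needle_len.
    let old_read_end = (old_i + LANES - 1) + needle_len; // 66: two bytes past the end
    let new_read_end = (new_i + LANES - 1) + needle_len; // 64: exactly flush with the end

    assert!(old_read_end > haystack_len);
    assert_eq!(new_read_end, haystack_len);
}

The same budget applies to the two while-loop conditions and the naive-search cutoff: each must guarantee that reading LANES-wide chunks plus a full needle comparison stays inside the haystack.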

@@ -1741,6 +1741,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     use crate::simd::{SimdPartialEq, ToBitMask};
     let first_probe = needle[0];
+    let last_byte_offset = needle.len() - 1;
     // the offset used for the 2nd vector
     let second_probe_offset = if needle.len() == 2 {
@@ -1758,7 +1759,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     };
     // do a naive search if the haystack is too small to fit
-    if haystack.len() < Block::LANES + second_probe_offset {
+    if haystack.len() < Block::LANES + last_byte_offset {
         return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }
@@ -1815,7 +1816,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     // The loop condition must ensure that there's enough headroom to read LANE bytes,
    // and not only at the current index but also at the index shifted by block_offset
     const UNROLL: usize = 4;
-    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+    while i + last_byte_offset + UNROLL * Block::LANES < haystack.len() && !result {
         let mut masks = [0u16; UNROLL];
         for j in 0..UNROLL {
             masks[j] = test_chunk(i + j * Block::LANES);
@@ -1828,7 +1829,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
         }
         i += UNROLL * Block::LANES;
     }
-    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+    while i + last_byte_offset + Block::LANES < haystack.len() && !result {
         let mask = test_chunk(i);
         if mask != 0 {
             result |= check_mask(i, mask, result);
@@ -1840,7 +1841,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
-    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let i = haystack.len() - last_byte_offset - Block::LANES;
     let mask = test_chunk(i);
     if mask != 0 {
         result |= check_mask(i, mask, result);
@@ -1860,6 +1861,7 @@ fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
 #[inline]
 unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    debug_assert_eq!(x.len(), y.len());
     // This function is adapted from
     // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
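
As a usage-level sanity check for the cases the changed bounds are about, a hypothetical test (not the one from this commit) can pin down the shape of input involved: a needle whose trailing bytes repeat its first byte, matched flush against the end of a haystack well past the naive-search cutoff. str::contains with a &str pattern is the public entry point that can reach this code on x86-64.

fn main() {
    // The needle's last two bytes repeat its first byte ('a'), so any probe
    // chosen to differ from 'a' sits before the final byte.
    let needle = "abcaa";
    // Haystack comfortably longer than one SIMD block, with the only match
    // flush against the end, so the right-aligned final chunk must find it.
    let haystack = format!("{}{}", "x".repeat(64), needle);
    assert!(haystack.contains(needle));
    // Dropping the last byte removes the match entirely.
    assert!(!haystack[..haystack.len() - 1].contains(needle));
}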