diff --git a/crates/core_simd/src/comparisons.rs b/crates/core_simd/src/comparisons.rs index 88270a9b7e9..7b0d0a6864b 100644 --- a/crates/core_simd/src/comparisons.rs +++ b/crates/core_simd/src/comparisons.rs @@ -67,36 +67,54 @@ where } } -macro_rules! impl_min_max_vector { +macro_rules! impl_ord_methods_vector { { $type:ty } => { impl Simd<$type, LANES> where LaneCount: SupportedLaneCount, { - /// Returns the lane-wise minimum with other + /// Returns the lane-wise minimum with `other`. #[must_use = "method returns a new vector and does not mutate the original value"] #[inline] pub fn min(self, other: Self) -> Self { self.lanes_gt(other).select(other, self) } - /// Returns the lane-wise maximum with other + /// Returns the lane-wise maximum with `other`. #[must_use = "method returns a new vector and does not mutate the original value"] #[inline] pub fn max(self, other: Self) -> Self { self.lanes_lt(other).select(other, self) } + + /// Restrict each lane to a certain interval. + /// + /// For each lane, returns `max` if `self` is greater than `max`, and `min` if `self` is + /// less than `min`. Otherwise returns `self`. + /// + /// # Panics + /// + /// Panics if `min > max` on any lane. + #[must_use = "method returns a new vector and does not mutate the original value"] + #[inline] + pub fn clamp(self, min: Self, max: Self) -> Self { + assert!( + min.lanes_le(max).all(), + "each lane in `min` must be less than or equal to the corresponding lane in `max`", + ); + self.max(min).min(max) + } } } } -impl_min_max_vector!(i8); -impl_min_max_vector!(i16); -impl_min_max_vector!(i32); -impl_min_max_vector!(i64); -impl_min_max_vector!(isize); -impl_min_max_vector!(u8); -impl_min_max_vector!(u16); -impl_min_max_vector!(u32); -impl_min_max_vector!(u64); -impl_min_max_vector!(usize); +impl_ord_methods_vector!(i8); +impl_ord_methods_vector!(i16); +impl_ord_methods_vector!(i32); +impl_ord_methods_vector!(i64); +impl_ord_methods_vector!(isize); +impl_ord_methods_vector!(u8); +impl_ord_methods_vector!(u16); +impl_ord_methods_vector!(u32); +impl_ord_methods_vector!(u64); +impl_ord_methods_vector!(usize); diff --git a/crates/core_simd/tests/i16_ops.rs b/crates/core_simd/tests/i16_ops.rs index cd6cadc2d5e..171e5b472fa 100644 --- a/crates/core_simd/tests/i16_ops.rs +++ b/crates/core_simd/tests/i16_ops.rs @@ -18,3 +18,15 @@ fn min_is_not_lexicographic() { let b = i16x2::from_array([12, -4]); assert_eq!(a.min(b), i16x2::from_array([10, -4])); } + +#[test] +fn clamp_is_not_lexicographic() { + let a = i16x2::splat(10); + let lo = i16x2::from_array([-12, -4]); + let up = i16x2::from_array([-4, 12]); + assert_eq!(a.clamp(lo, up), i16x2::from_array([-4, 10])); + + let x = i16x2::from_array([1, 10]); + let y = x.clamp(i16x2::splat(0), i16x2::splat(9)); + assert_eq!(y, i16x2::from_array([1, 9])); +} diff --git a/crates/core_simd/tests/ops_macros.rs b/crates/core_simd/tests/ops_macros.rs index 96da8c1b8dc..bea02750ef2 100644 --- a/crates/core_simd/tests/ops_macros.rs +++ b/crates/core_simd/tests/ops_macros.rs @@ -239,6 +239,18 @@ macro_rules! impl_signed_tests { let b = Vector::::splat(0); assert_eq!(a.max(b), a); } + + fn clamp() { + let min = Vector::::splat(Scalar::MIN); + let max = Vector::::splat(Scalar::MAX); + let zero = Vector::::splat(0); + let one = Vector::::splat(1); + let negone = Vector::::splat(-1); + assert_eq!(zero.clamp(min, max), zero); + assert_eq!(zero.clamp(min, one), zero); + assert_eq!(zero.clamp(one, max), one); + assert_eq!(zero.clamp(min, negone), negone); + } } test_helpers::test_lanes_panic! {