Add min/max/clamp
This commit is contained in:
parent
b936f34a5c
commit
74e6262ce4
3 changed files with 132 additions and 0 deletions
|
@ -136,6 +136,47 @@ macro_rules! impl_float_vector {
|
|||
let magnitude = self.to_bits() & !Self::splat(-0.).to_bits();
|
||||
Self::from_bits(sign_bit | magnitude)
|
||||
}
|
||||
|
||||
/// Returns the minimum of each lane.
|
||||
///
|
||||
/// If one of the values is `NAN`, then the other value is returned.
|
||||
#[inline]
|
||||
pub fn min(self, other: Self) -> Self {
|
||||
// TODO consider using an intrinsic
|
||||
self.is_nan().select(
|
||||
other,
|
||||
self.lanes_ge(other).select(other, self)
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the maximum of each lane.
|
||||
///
|
||||
/// If one of the values is `NAN`, then the other value is returned.
|
||||
#[inline]
|
||||
pub fn max(self, other: Self) -> Self {
|
||||
// TODO consider using an intrinsic
|
||||
self.is_nan().select(
|
||||
other,
|
||||
self.lanes_le(other).select(other, self)
|
||||
)
|
||||
}
|
||||
|
||||
/// Restrict each lane to a certain interval unless it is NaN.
|
||||
///
|
||||
/// For each lane in `self`, returns the corresponding lane in `max` if the lane is
|
||||
/// greater than `max`, and the corresponding lane in `min` if the lane is less
|
||||
/// than `min`. Otherwise returns the lane in `self`.
|
||||
#[inline]
|
||||
pub fn clamp(self, min: Self, max: Self) -> Self {
|
||||
assert!(
|
||||
min.lanes_le(max).all(),
|
||||
"each lane in `min` must be less than or equal to the corresponding lane in `max`",
|
||||
);
|
||||
let mut x = self;
|
||||
x = x.lanes_lt(min).select(min, x);
|
||||
x = x.lanes_gt(max).select(max, x);
|
||||
x
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -483,6 +483,76 @@ macro_rules! impl_float_tests {
|
|||
)
|
||||
}
|
||||
|
||||
fn min<const LANES: usize>() {
|
||||
// Regular conditions (both values aren't zero)
|
||||
test_helpers::test_binary_elementwise(
|
||||
&Vector::<LANES>::min,
|
||||
&Scalar::min,
|
||||
// Reject the case where both values are zero with different signs
|
||||
&|a, b| {
|
||||
for (a, b) in a.iter().zip(b.iter()) {
|
||||
if *a == 0. && *b == 0. && a.signum() != b.signum() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
);
|
||||
|
||||
// Special case where both values are zero
|
||||
let p_zero = Vector::<LANES>::splat(0.);
|
||||
let n_zero = Vector::<LANES>::splat(-0.);
|
||||
assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
|
||||
assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
|
||||
}
|
||||
|
||||
fn max<const LANES: usize>() {
|
||||
// Regular conditions (both values aren't zero)
|
||||
test_helpers::test_binary_elementwise(
|
||||
&Vector::<LANES>::max,
|
||||
&Scalar::max,
|
||||
// Reject the case where both values are zero with different signs
|
||||
&|a, b| {
|
||||
for (a, b) in a.iter().zip(b.iter()) {
|
||||
if *a == 0. && *b == 0. && a.signum() != b.signum() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
);
|
||||
|
||||
// Special case where both values are zero
|
||||
let p_zero = Vector::<LANES>::splat(0.);
|
||||
let n_zero = Vector::<LANES>::splat(-0.);
|
||||
assert!(p_zero.min(n_zero).to_array().iter().all(|x| *x == 0.));
|
||||
assert!(n_zero.min(p_zero).to_array().iter().all(|x| *x == 0.));
|
||||
}
|
||||
|
||||
fn clamp<const LANES: usize>() {
|
||||
test_helpers::test_3(&|value: [Scalar; LANES], mut min: [Scalar; LANES], mut max: [Scalar; LANES]| {
|
||||
for (min, max) in min.iter_mut().zip(max.iter_mut()) {
|
||||
if max < min {
|
||||
core::mem::swap(min, max);
|
||||
}
|
||||
if min.is_nan() {
|
||||
*min = Scalar::NEG_INFINITY;
|
||||
}
|
||||
if max.is_nan() {
|
||||
*max = Scalar::INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
let mut result_scalar = [Scalar::default(); LANES];
|
||||
for i in 0..LANES {
|
||||
result_scalar[i] = value[i].clamp(min[i], max[i]);
|
||||
}
|
||||
let result_vector = Vector::from_array(value).clamp(min.into(), max.into()).to_array();
|
||||
test_helpers::prop_assert_biteq!(result_scalar, result_vector);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
fn horizontal_sum<const LANES: usize>() {
|
||||
test_helpers::test_1(&|x| {
|
||||
test_helpers::prop_assert_biteq! (
|
||||
|
|
|
@ -97,6 +97,27 @@ pub fn test_2<A: core::fmt::Debug + DefaultStrategy, B: core::fmt::Debug + Defau
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
/// Test a function that takes two values.
|
||||
pub fn test_3<
|
||||
A: core::fmt::Debug + DefaultStrategy,
|
||||
B: core::fmt::Debug + DefaultStrategy,
|
||||
C: core::fmt::Debug + DefaultStrategy,
|
||||
>(
|
||||
f: &dyn Fn(A, B, C) -> proptest::test_runner::TestCaseResult,
|
||||
) {
|
||||
let mut runner = proptest::test_runner::TestRunner::default();
|
||||
runner
|
||||
.run(
|
||||
&(
|
||||
A::default_strategy(),
|
||||
B::default_strategy(),
|
||||
C::default_strategy(),
|
||||
),
|
||||
|(a, b, c)| f(a, b, c),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Test a unary vector function against a unary scalar function, applied elementwise.
|
||||
#[inline(never)]
|
||||
pub fn test_unary_elementwise<Scalar, ScalarResult, Vector, VectorResult, const LANES: usize>(
|
||||
|
|
Loading…
Reference in a new issue