[flang] Fix MAXLOC/MINLOC when MASK is scalar .FALSE.

When passing a scalar .FALSE. as the MASK argument to MAXLOC, we were getting
bad memory references.  We were falling into the code intended when the MASK
argument was missing.

I fixed this by checking for a scalar MASK with a .FALSE. value and
setting the result to all zeroes in that case.  I also added tests for
MAXLOC and MINLOC with scalar values of .TRUE. and .FALSE. for the MASK
argument.

I also special cased situations where the MASK argument is a scalar with
a .TRUE. value and passed along a nullptr in such cases.

Along the way, I eliminated the unused "chars" argument from the constructor
for ExtremumLocAccumulator.

Differential Revision: https://reviews.llvm.org/D124484
This commit is contained in:
Peter Steinfeld 2022-04-26 15:21:03 -07:00
parent 1747a93b28
commit 9df99d8ac2
2 changed files with 91 additions and 4 deletions

View file

@ -60,7 +60,7 @@ private:
template <typename COMPARE> class ExtremumLocAccumulator {
public:
using Type = typename COMPARE::Type;
ExtremumLocAccumulator(const Descriptor &array, std::size_t chars = 0)
ExtremumLocAccumulator(const Descriptor &array)
: array_{array}, argRank_{array.rank()}, compare_{array.ElementBytes()} {
Reinitialize();
}
@ -241,24 +241,39 @@ inline void TypedPartialMaxOrMinLoc(const char *intrinsic, Descriptor &result,
CheckIntegerKind(terminator, kind, intrinsic);
auto catKind{x.type().GetCategoryAndKind()};
RUNTIME_CHECK(terminator, catKind.has_value());
const Descriptor *maskToUse{mask};
SubscriptValue maskAt[maxRank]; // contents unused
if (mask && mask->rank() == 0) {
if (IsLogicalElementTrue(*mask, maskAt)) {
// A scalar MASK that's .TRUE. In this case, just get rid of the MASK.
maskToUse = nullptr;
} else {
// For scalar MASK arguments that are .FALSE., return all zeroes
CreatePartialReductionResult(result, x, dim, terminator, intrinsic,
TypeCode{TypeCategory::Integer, kind});
std::memset(
result.OffsetElement(), 0, result.Elements() * result.ElementBytes());
return;
}
}
switch (catKind->first) {
case TypeCategory::Integer:
ApplyIntegerKind<DoPartialMaxOrMinLocHelper<TypeCategory::Integer, IS_MAX,
NumericCompare>::template Functor,
void>(catKind->second, terminator, intrinsic, result, x, kind, dim,
mask, back, terminator);
maskToUse, back, terminator);
break;
case TypeCategory::Real:
ApplyFloatingPointKind<DoPartialMaxOrMinLocHelper<TypeCategory::Real,
IS_MAX, NumericCompare>::template Functor,
void>(catKind->second, terminator, intrinsic, result, x, kind, dim,
mask, back, terminator);
maskToUse, back, terminator);
break;
case TypeCategory::Character:
ApplyCharacterKind<DoPartialMaxOrMinLocHelper<TypeCategory::Character,
IS_MAX, CharacterCompare>::template Functor,
void>(catKind->second, terminator, intrinsic, result, x, kind, dim,
mask, back, terminator);
maskToUse, back, terminator);
break;
default:
terminator.Crash(

View file

@ -158,12 +158,84 @@ TEST(Reductions, DoubleMaxMinNorm2) {
EXPECT_EQ(scalarResult.rank(), 0);
EXPECT_EQ(*scalarResult.ZeroBasedIndexedElement<std::int16_t>(0), 23);
scalarResult.Destroy();
// Test .FALSE. scalar MASK argument
auto falseMask{MakeArray<TypeCategory::Logical, 4>(
std::vector<int>{}, std::vector<std::int32_t>{0})};
RTNAME(MaxlocDim)
(loc, *array, /*KIND=*/4, /*DIM=*/2, __FILE__, __LINE__,
/*MASK=*/&*falseMask, /*BACK=*/false);
EXPECT_EQ(loc.rank(), 2);
EXPECT_EQ(loc.type().raw(), (TypeCode{TypeCategory::Integer, 4}.raw()));
EXPECT_EQ(loc.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(0).Extent(), 3);
EXPECT_EQ(loc.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(1).Extent(), 2);
for (int i{0}; i < 6; ++i) {
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(0), 0);
}
loc.Destroy();
// Test .TRUE. scalar MASK argument
auto trueMask{MakeArray<TypeCategory::Logical, 4>(
std::vector<int>{}, std::vector<std::int32_t>{1})};
RTNAME(MaxlocDim)
(loc, *array, /*KIND=*/4, /*DIM=*/2, __FILE__, __LINE__,
/*MASK=*/&*trueMask, /*BACK=*/false);
EXPECT_EQ(loc.rank(), 2);
EXPECT_EQ(loc.type().raw(), (TypeCode{TypeCategory::Integer, 4}.raw()));
EXPECT_EQ(loc.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(0).Extent(), 3);
EXPECT_EQ(loc.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(1).Extent(), 2);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(0), 3);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(1), 4);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(2), 3);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(3), 3);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(4), 4);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(5), 4);
loc.Destroy();
RTNAME(MinlocDim)
(scalarResult, *array1, /*KIND=*/2, /*DIM=*/1, __FILE__, __LINE__,
/*MASK=*/nullptr, /*BACK=*/true);
EXPECT_EQ(scalarResult.rank(), 0);
EXPECT_EQ(*scalarResult.ZeroBasedIndexedElement<std::int16_t>(0), 22);
scalarResult.Destroy();
// Test .FALSE. scalar MASK argument
RTNAME(MinlocDim)
(loc, *array, /*KIND=*/4, /*DIM=*/2, __FILE__, __LINE__,
/*MASK=*/&*falseMask, /*BACK=*/false);
EXPECT_EQ(loc.rank(), 2);
EXPECT_EQ(loc.type().raw(), (TypeCode{TypeCategory::Integer, 4}.raw()));
EXPECT_EQ(loc.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(0).Extent(), 3);
EXPECT_EQ(loc.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(1).Extent(), 2);
for (int i{0}; i < 6; ++i) {
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(0), 0);
}
loc.Destroy();
// Test .TRUE. scalar MASK argument
RTNAME(MinlocDim)
(loc, *array, /*KIND=*/4, /*DIM=*/2, __FILE__, __LINE__,
/*MASK=*/&*trueMask, /*BACK=*/false);
EXPECT_EQ(loc.rank(), 2);
EXPECT_EQ(loc.type().raw(), (TypeCode{TypeCategory::Integer, 4}.raw()));
EXPECT_EQ(loc.GetDimension(0).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(0).Extent(), 3);
EXPECT_EQ(loc.GetDimension(1).LowerBound(), 1);
EXPECT_EQ(loc.GetDimension(1).Extent(), 2);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(0), 4);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(1), 3);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(2), 4);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(3), 4);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(4), 3);
EXPECT_EQ(*loc.ZeroBasedIndexedElement<std::int32_t>(5), 2);
loc.Destroy();
RTNAME(MaxvalDim)
(scalarResult, *array1, /*DIM=*/1, __FILE__, __LINE__, /*MASK=*/nullptr);
EXPECT_EQ(scalarResult.rank(), 0);