[flang] Fix LBOUND() folding for constant arrays

Previously constant folding uses 'dim' without checks which leads to ICE if we
do not have DIM= parameter. And for inputs without DIM= we need to form an
array of rank size with computed bounds instead of single value.

Add additional PackageConstant function to simplify 'if (dim)' handling since we
need to distinguish between scalar initialization in case of DIM= argument and
rank=1 array.

Also add a few more tests with 'parameter' type to verify folding for constant
arrays.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D123237
This commit is contained in:
Mike Kashkarov 2022-04-20 19:47:21 +03:00
parent c8f822ad51
commit bd5371e4fc
2 changed files with 53 additions and 16 deletions

View file

@ -12,30 +12,52 @@
namespace Fortran::evaluate {
// Given a collection of ConstantSubscripts values, package them as a Constant.
// Return scalar value if asScalar == true and shape-dim array otherwise.
template <typename T>
Expr<T> PackageConstantBounds(
const ConstantSubscripts &&bounds, bool asScalar = false) {
if (asScalar) {
return Expr<T>{Constant<T>{bounds.at(0)}};
} else {
// As rank-dim array
const int rank{GetRank(bounds)};
std::vector<Scalar<T>> packed(rank);
std::transform(bounds.begin(), bounds.end(), packed.begin(),
[](ConstantSubscript x) { return Scalar<T>(x); });
return Expr<T>{Constant<T>{std::move(packed), ConstantSubscripts{rank}}};
}
}
// Class to retrieve the constant lower bound of an expression which is an
// array that devolves to a type of Constant<T>
class GetConstantArrayLboundHelper {
public:
GetConstantArrayLboundHelper(ConstantSubscript dim) : dim_{dim} {}
GetConstantArrayLboundHelper(std::optional<ConstantSubscript> dim)
: dim_{dim} {}
template <typename T> ConstantSubscript GetLbound(const T &) {
template <typename T> ConstantSubscripts GetLbound(const T &) {
// The method is needed for template expansion, but we should never get
// here in practice.
CHECK(false);
return 0;
return {0};
}
template <typename T> ConstantSubscript GetLbound(const Constant<T> &x) {
template <typename T> ConstantSubscripts GetLbound(const Constant<T> &x) {
// Return the lower bound
return x.lbounds()[dim_];
if (dim_) {
return {x.lbounds().at(*dim_)};
} else {
return x.lbounds();
}
}
template <typename T> ConstantSubscript GetLbound(const Parentheses<T> &x) {
template <typename T> ConstantSubscripts GetLbound(const Parentheses<T> &x) {
// Strip off the parentheses
return GetLbound(x.left());
}
template <typename T> ConstantSubscript GetLbound(const Expr<T> &x) {
template <typename T> ConstantSubscripts GetLbound(const Expr<T> &x) {
// recurse through Expr<T>'a until we hit a constant
return common::visit([&](const auto &inner) { return GetLbound(inner); },
// [&](const auto &) { return 0; },
@ -43,7 +65,7 @@ public:
}
private:
ConstantSubscript dim_;
std::optional<ConstantSubscript> dim_;
};
template <int KIND>
@ -89,16 +111,13 @@ Expr<Type<TypeCategory::Integer, KIND>> LBOUND(FoldingContext &context,
}
}
if (IsActuallyConstant(*array)) {
return Expr<T>{GetConstantArrayLboundHelper{*dim}.GetLbound(*array)};
const ConstantSubscripts bounds{
GetConstantArrayLboundHelper{dim}.GetLbound(*array)};
return PackageConstantBounds<T>(std::move(bounds), dim.has_value());
}
if (lowerBoundsAreOne) {
if (dim) {
return Expr<T>{1};
} else {
std::vector<Scalar<T>> ones(rank, Scalar<T>{1});
return Expr<T>{
Constant<T>{std::move(ones), ConstantSubscripts{rank}}};
}
ConstantSubscripts ones(rank, ConstantSubscript{1});
return PackageConstantBounds<T>(std::move(ones), dim.has_value());
}
}
}

View file

@ -77,4 +77,22 @@ module m
end block
end associate
end subroutine
subroutine test3_lbound_parameter
! Test lbound with constant arrays
integer, parameter :: a1(1) = 0
integer, parameter :: lba1(*) = lbound(a1)
logical, parameter :: test_lba1 = all(lba1 == [1])
integer, parameter :: a2(0:1) = 0
integer, parameter :: lba2(*) = lbound(a2)
logical, parameter :: test_lba2 = all(lba2 == [0])
integer, parameter :: a3(-10:-5,1,4:6) = 0
integer, parameter :: lba3(*) = lbound(a3)
logical, parameter :: test_lba3 = all(lba3 == [-10, 1, 4])
! Exercise with DIM=
logical, parameter :: test_lba3_dim = lbound(a3, 1) == -10 .and. &
lbound(a3, 2) == 1 .and. &
lbound(a3, 3) == 4
end subroutine
end