llvm/flang/lib/Evaluate/fold.cpp
peter klausler a50bb84ec0 [flang] Fix classification of shape inquiries in specification exprs
In some contexts, including the motivating case of determining whether
the expressions that define the shape of a variable are "constant expressions"
in the sense of the Fortran standard, expression rewriting via Fold()
is not necessary, and should not be required.  The inquiry intrinsics LBOUND,
UBOUND, and SIZE work correctly now in specification expressions and are
classified correctly as being constant expressions (or not).  Getting this right
led to a fair amount of API clean-up as a consequence, including the
folding of shapes and TypeAndShape objects, and new APIs for shapes
that do not fold for those cases where folding isn't needed.  Further,
the symbol-testing predicate APIs in Evaluate/tools.h now all resolve any
associations of their symbols and work transparently on use-, host-, and
construct-association symbols; the tools used to resolve those associations have
been defined and documented more precisely, and their clients adjusted as needed.

Differential Revision: https://reviews.llvm.org/D94561
2021-01-13 10:05:14 -08:00

221 lines
7.8 KiB
C++

//===-- lib/Evaluate/fold.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Evaluate/fold.h"
#include "fold-implementation.h"
#include "flang/Evaluate/characteristics.h"
namespace Fortran::evaluate {
// Folds the expressions held by a TypeAndShape by running its Rewrite()
// member against `context`, and returns the rewritten value.  The argument
// is consumed: it is moved into the result.
characteristics::TypeAndShape Fold(
    FoldingContext &context, characteristics::TypeAndShape &&x) {
  characteristics::TypeAndShape rewritten{std::move(x)};
  rewritten.Rewrite(context);
  return rewritten;
}
// Folds the subscript `ss` in place (the caller's Subscript is overwritten
// with its folded form).  It then tries to produce the subscript's values as
// a Constant<SubscriptInteger>:
//  - an index expression yields a copy of its constant value, if the folded
//    expression is a constant;
//  - a triplet yields a rank-1 constant listing lower, lower+stride, ... for
//    as long as the value stays within the upper bound (in the direction of
//    the stride).  A missing lower bound defaults to
//    GetLowerBound(context, base, dim).  A missing upper bound is computed by
//    ComputeUpperBound from that lower bound and GetExtent(context, base, dim).
// Returns std::nullopt when the index expression is not a constant, when
// ToInt64 cannot reduce a triplet bound or the stride, or when the stride is
// zero.
std::optional<Constant<SubscriptInteger>> GetConstantSubscript(
    FoldingContext &context, Subscript &ss, const NamedEntity &base, int dim) {
  using Result = std::optional<Constant<SubscriptInteger>>;
  ss = FoldOperation(context, std::move(ss));

  auto fromIndexExpr{[](IndirectSubscriptIntegerExpr &index) -> Result {
    const auto *value{UnwrapConstantValue<SubscriptInteger>(index.value())};
    if (!value) {
      return std::nullopt;
    }
    return *value;
  }};

  auto fromTriplet{[&](Triplet &triplet) -> Result {
    auto first{triplet.lower()};
    auto last{triplet.upper()};
    std::optional<ConstantSubscript> step{ToInt64(triplet.stride())};
    if (!first) {
      first = GetLowerBound(context, base, dim);
    }
    if (!last) {
      last = ComputeUpperBound(context, GetLowerBound(context, base, dim),
          GetExtent(context, base, dim));
    }
    auto from{ToInt64(first)};
    auto to{ToInt64(last)};
    if (!from || !to || !step || *step == 0) {
      return std::nullopt;
    }
    std::vector<SubscriptInteger::Scalar> elements;
    // The stride is nonzero here, so its sign alone selects the direction
    // of the bound test.
    // NOTE(review): `j += *step` can overflow std::int64_t when the upper
    // bound lies within |stride| of the type's limits.
    for (auto j{*from}; *step > 0 ? j <= *to : j >= *to; j += *step) {
      elements.emplace_back(j);
    }
    auto count{static_cast<ConstantSubscript>(elements.size())};
    return Constant<SubscriptInteger>{
        std::move(elements), ConstantSubscripts{count}};
  }};

  return std::visit(common::visitors{fromIndexExpr, fromTriplet}, ss.u);
}
// Folds every component value of a structure constructor and rebuilds it.
// For each non-pointer component, the folded value's constant extents are
// compared with the component's constant extents.  A scalar value given for
// a component of rank > 0 is first expanded to the component's shape with
// ScalarConstantExpander.  Each value is folded once more after that
// optional expansion.  The result is a Constant<SomeDerived> only when every
// non-pointer component conforms and IsConstantExpr() accepts the rebuilt
// constructor.  Otherwise the folded StructureConstructor is returned as an
// ordinary expression.
Expr<SomeDerived> FoldOperation(
    FoldingContext &context, StructureConstructor &&structure) {
  StructureConstructor rebuilt{structure.derivedTypeSpec()};
  bool allExtentsConstant{true};
  for (auto &&[component, valueRef] : std::move(structure)) {
    auto value{Fold(context, std::move(valueRef.value()))};
    if (!IsPointer(component)) {
      bool conforms{false};
      auto valueExtents{GetConstantExtents(context, value)};
      if (valueExtents) {
        if (auto componentExtents{GetConstantExtents(context, component)}) {
          bool broadcastScalar{GetRank(*componentExtents) > 0 &&
              GetRank(*valueExtents) == 0};
          if (broadcastScalar) {
            value = ScalarConstantExpander{std::move(*componentExtents)}.Expand(
                std::move(value));
            conforms = value.Rank() > 0;
          } else {
            conforms = *valueExtents == *componentExtents;
          }
        }
      }
      allExtentsConstant = allExtentsConstant && conforms;
    }
    rebuilt.Add(component, Fold(context, std::move(value)));
  }
  // IsConstantExpr() is consulted only when all extents are constant.
  if (!allExtentsConstant || !IsConstantExpr(rebuilt)) {
    return Expr<SomeDerived>{std::move(rebuilt)};
  }
  return Expr<SomeDerived>{Constant<SomeDerived>{std::move(rebuilt)}};
}
// Folds the base data reference of a component reference.  The referenced
// component symbol is carried over unchanged.
Component FoldOperation(FoldingContext &context, Component &&component) {
  const auto &lastSymbol{component.GetLastSymbol()};
  DataRef foldedBase{FoldOperation(context, std::move(component.base()))};
  return Component{std::move(foldedBase), lastSymbol};
}
// Folds a NamedEntity that wraps a component reference.  An entity for
// which UnwrapComponent() yields null is returned unchanged.
NamedEntity FoldOperation(FoldingContext &context, NamedEntity &&x) {
  Component *component{x.UnwrapComponent()};
  if (!component) {
    return std::move(x);
  }
  return NamedEntity{FoldOperation(context, std::move(*component))};
}
// Folds the lower bound, upper bound, and stride of a subscript triplet, in
// that order.  The bounds are optional (MaybeExtentExpr); an absent bound
// stays absent.
Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) {
  MaybeExtentExpr first{triplet.lower()};
  MaybeExtentExpr last{triplet.upper()};
  MaybeExtentExpr foldedFirst{Fold(context, std::move(first))};
  MaybeExtentExpr foldedLast{Fold(context, std::move(last))};
  return {std::move(foldedFirst), std::move(foldedLast),
      Fold(context, triplet.stride())};
}
// Folds one subscript.  A triplet is folded part by part.  Otherwise the
// subscript holds an index expression, which is folded in place and
// re-wrapped.
Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) {
  if (auto *triplet{std::get_if<Triplet>(&subscript.u)}) {
    return Subscript{FoldOperation(context, std::move(*triplet))};
  }
  auto &index{std::get<IndirectSubscriptIntegerExpr>(subscript.u)};
  index.value() = Fold(context, std::move(index.value()));
  return Subscript{std::move(index)};
}
// Folds the base entity of an array reference, then folds each subscript in
// place, and rebuilds the reference from the folded parts.
ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
  NamedEntity foldedBase{FoldOperation(context, std::move(arrayRef.base()))};
  auto &subscripts{arrayRef.subscript()};
  for (auto &each : subscripts) {
    each = FoldOperation(context, std::move(each));
  }
  return ArrayRef{std::move(foldedBase), std::move(subscripts)};
}
// Folds the subscripts, the cosubscripts, and the optional STAT= and TEAM=
// values of a coarray reference.  The base symbols are carried over without
// folding, and the team-is-team-number flag is preserved.
//
// `coarrayRef` is an rvalue, so its subscript and cosubscript vectors are
// moved out and folded in place, as in the ArrayRef overload.  Iterating them
// by value would deep-copy every element's expression tree first.
CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
  auto subscript{std::move(coarrayRef.subscript())};
  for (Subscript &x : subscript) {
    x = FoldOperation(context, std::move(x));
  }
  auto cosubscript{std::move(coarrayRef.cosubscript())};
  for (Expr<SubscriptInteger> &x : cosubscript) {
    x = Fold(context, std::move(x));
  }
  CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
      std::move(cosubscript)};
  // stat()/team() are read after the vectors were moved out; they are
  // independent of them.
  if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
    folded.set_stat(Fold(context, std::move(*stat)));
  }
  if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
    folded.set_team(
        Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
  }
  return folded;
}
// Folds a data reference.  A bare symbol reference is rebuilt as-is.  Every
// other alternative is folded through the matching FoldOperation overload
// and re-wrapped in a DataRef.
DataRef FoldOperation(FoldingContext &context, DataRef &&dataRef) {
  auto keepSymbol{[](SymbolRef sym) { return DataRef{*sym}; }};
  auto foldPart{[&context](auto &&part) {
    return DataRef{FoldOperation(context, std::move(part))};
  }};
  return std::visit(
      common::visitors{keepSymbol, foldPart}, std::move(dataRef.u));
}
// Folds the bounds of a substring.  When the parent is a DataRef, a copy of
// it is folded too.  Otherwise the parent is a StaticDataObject::Pointer,
// which is carried over unchanged.
Substring FoldOperation(FoldingContext &context, Substring &&substring) {
  auto foldedLower{Fold(context, substring.lower())};
  auto foldedUpper{Fold(context, substring.upper())};
  const DataRef *parentRef{substring.GetParentIf<DataRef>()};
  if (!parentRef) {
    auto staticParent{*substring.GetParentIf<StaticDataObject::Pointer>()};
    return Substring{std::move(staticParent), std::move(foldedLower),
        std::move(foldedUpper)};
  }
  DataRef parentCopy{*parentRef};
  return Substring{FoldOperation(context, std::move(parentCopy)),
      std::move(foldedLower), std::move(foldedUpper)};
}
// Folds (a copy of) the data reference underlying a ComplexPart and keeps
// its part() selector.
ComplexPart FoldOperation(FoldingContext &context, ComplexPart &&complexPart) {
  auto selector{complexPart.part()};
  DataRef foldedRef{FoldOperation(context, DataRef{complexPart.complex()})};
  return ComplexPart{std::move(foldedRef), selector};
}
// Returns ToInt64() of the argument when UnwrapExpr finds an integer
// expression (Expr<SomeInteger>) in it.  Otherwise returns std::nullopt.
std::optional<std::int64_t> GetInt64Arg(
    const std::optional<ActualArgument> &arg) {
  const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)};
  if (!intExpr) {
    return std::nullopt;
  }
  return ToInt64(*intExpr);
}
std::optional<std::int64_t> GetInt64ArgOr(
const std::optional<ActualArgument> &arg, std::int64_t defaultValue) {
if (!arg) {
return defaultValue;
} else if (const auto *intExpr{UnwrapExpr<Expr<SomeInteger>>(arg)}) {
return ToInt64(*intExpr);
} else {
return std::nullopt;
}
}
// Replaces an implied-DO index with its current value when the folding
// context has a value for `iDo.name`.  Otherwise the index is returned
// unchanged as an expression.
Expr<ImpliedDoIndex::Result> FoldOperation(
    FoldingContext &context, ImpliedDoIndex &&iDo) {
  std::optional<ConstantSubscript> current{context.GetImpliedDo(iDo.name)};
  if (!current) {
    return Expr<ImpliedDoIndex::Result>{std::move(iDo)};
  }
  return Expr<ImpliedDoIndex::Result>{*current};
}
// Explicit instantiation definitions: ExpressionBase is instantiated here,
// in this translation unit, for SomeDerived and for SomeType.
template class ExpressionBase<SomeDerived>;
template class ExpressionBase<SomeType>;
} // namespace Fortran::evaluate