[flang] Implement GetShape with expression visitor
Original-commit: flang-compiler/f18@d607d02847 Reviewed-on: https://github.com/flang-compiler/f18/pull/611 Tree-same-pre-rewrite: false
This commit is contained in:
parent
b72ef0b370
commit
a7041f3a78
|
@ -298,53 +298,35 @@ MaybeExtentExpr GetUpperBound(FoldingContext &context, MaybeExtentExpr &&lower,
|
|||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const Symbol &symbol) {
|
||||
return GetShape(NamedEntity{symbol});
|
||||
void GetShapeVisitor::Handle(const Symbol &symbol) {
|
||||
Handle(NamedEntity{symbol});
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const Symbol *symbol) {
|
||||
if (symbol != nullptr) {
|
||||
return GetShape(*symbol);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
void GetShapeVisitor::Handle(const Component &component) {
|
||||
Handle(NamedEntity{Component{component}});
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const Component &component) {
|
||||
return GetShape(NamedEntity{Component{component}});
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const NamedEntity &base) {
|
||||
void GetShapeVisitor::Handle(const NamedEntity &base) {
|
||||
const Symbol &symbol{base.GetLastSymbol()};
|
||||
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
|
||||
if (IsImpliedShape(symbol)) {
|
||||
return GetShape(*details->init());
|
||||
Nested(details->init());
|
||||
} else {
|
||||
Shape result;
|
||||
int n{static_cast<int>(details->shape().size())};
|
||||
for (int dimension{0}; dimension < n; ++dimension) {
|
||||
result.emplace_back(GetExtent(context_, base, dimension));
|
||||
}
|
||||
return result;
|
||||
Return(std::move(result));
|
||||
}
|
||||
} else if (const auto *details{
|
||||
symbol.detailsIf<semantics::AssocEntityDetails>()}) {
|
||||
if (details->expr().has_value()) {
|
||||
return GetShape(*details->expr());
|
||||
}
|
||||
Nested(details->expr());
|
||||
}
|
||||
return std::nullopt;
|
||||
Return();
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const BaseObject &object) {
|
||||
if (const Symbol * symbol{object.symbol()}) {
|
||||
return GetShape(*symbol);
|
||||
} else {
|
||||
return Shape{};
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ArrayRef &arrayRef) {
|
||||
void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
|
||||
Shape shape;
|
||||
int dimension{0};
|
||||
for (const Subscript &ss : arrayRef.subscript()) {
|
||||
|
@ -354,13 +336,13 @@ std::optional<Shape> GetShapeHelper::GetShape(const ArrayRef &arrayRef) {
|
|||
++dimension;
|
||||
}
|
||||
if (shape.empty()) {
|
||||
return GetShape(arrayRef.base());
|
||||
Nested(arrayRef.base());
|
||||
} else {
|
||||
return shape;
|
||||
Return(std::move(shape));
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) {
|
||||
void GetShapeVisitor::Handle(const CoarrayRef &coarrayRef) {
|
||||
Shape shape;
|
||||
NamedEntity base{coarrayRef.GetBase()};
|
||||
int dimension{0};
|
||||
|
@ -371,102 +353,45 @@ std::optional<Shape> GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) {
|
|||
++dimension;
|
||||
}
|
||||
if (shape.empty()) {
|
||||
return GetShape(base);
|
||||
Nested(base);
|
||||
} else {
|
||||
return shape;
|
||||
Return(std::move(shape));
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const DataRef &dataRef) {
|
||||
return GetShape(dataRef.u);
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const Substring &substring) {
|
||||
if (const auto *dataRef{substring.GetParentIf<DataRef>()}) {
|
||||
return GetShape(*dataRef);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ComplexPart &part) {
|
||||
return GetShape(part.complex());
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ActualArgument &arg) {
|
||||
if (const auto *expr{arg.UnwrapExpr()}) {
|
||||
return GetShape(*expr);
|
||||
} else if (const Symbol * atDummy{arg.GetAssumedTypeDummy()}) {
|
||||
return GetShape(*atDummy);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ProcedureDesignator &proc) {
|
||||
if (const Symbol * symbol{proc.GetSymbol()}) {
|
||||
return GetShape(*symbol);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ProcedureRef &call) {
|
||||
void GetShapeVisitor::Handle(const ProcedureRef &call) {
|
||||
if (call.Rank() == 0) {
|
||||
return Shape{};
|
||||
Scalar();
|
||||
} else if (call.IsElemental()) {
|
||||
for (const auto &arg : call.arguments()) {
|
||||
if (arg.has_value() && arg->Rank() > 0) {
|
||||
return GetShape(*arg);
|
||||
Nested(*arg);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Scalar();
|
||||
} else if (const Symbol * symbol{call.proc().GetSymbol()}) {
|
||||
return GetShape(*symbol);
|
||||
Handle(*symbol);
|
||||
} else if (const auto *intrinsic{
|
||||
std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
|
||||
if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
|
||||
intrinsic->name == "ubound") {
|
||||
const auto *expr{call.arguments().front().value().UnwrapExpr()};
|
||||
CHECK(expr != nullptr);
|
||||
return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
|
||||
Return(Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}});
|
||||
} else if (intrinsic->name == "reshape") {
|
||||
if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) {
|
||||
// SHAPE(RESHAPE(array,shape)) -> shape
|
||||
const auto *shapeExpr{call.arguments().at(1).value().UnwrapExpr()};
|
||||
CHECK(shapeExpr != nullptr);
|
||||
Expr<SomeInteger> shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
|
||||
return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
|
||||
Return(AsShape(context_, ConvertToType<ExtentType>(std::move(shape))));
|
||||
}
|
||||
} else {
|
||||
// TODO: shapes of other non-elemental intrinsic results
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(
|
||||
const Relational<SomeType> &relation) {
|
||||
return GetShape(relation.u);
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const StructureConstructor &) {
|
||||
return Shape{}; // always scalar
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const ImpliedDoIndex &) {
|
||||
return Shape{}; // always scalar
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const DescriptorInquiry &) {
|
||||
return Shape{}; // always scalar
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const BOZLiteralConstant &) {
|
||||
return Shape{}; // always scalar
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShapeHelper::GetShape(const NullPointer &) {
|
||||
return {}; // not an object
|
||||
Return();
|
||||
}
|
||||
|
||||
bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "expression.h"
|
||||
#include "tools.h"
|
||||
#include "traversal.h"
|
||||
#include "type.h"
|
||||
#include "variable.h"
|
||||
#include "../common/indirection.h"
|
||||
|
@ -42,6 +43,8 @@ using Shape = std::vector<MaybeExtentExpr>;
|
|||
bool IsImpliedShape(const Symbol &);
|
||||
bool IsExplicitShape(const Symbol &);
|
||||
|
||||
template<typename A> std::optional<Shape> GetShape(FoldingContext &, const A &);
|
||||
|
||||
// Conversions between various representations of shapes.
|
||||
Shape AsShape(const Constant<ExtentType> &);
|
||||
std::optional<Shape> AsShape(FoldingContext &, ExtentExpr &&);
|
||||
|
@ -87,98 +90,53 @@ bool CheckConformance(parser::ContextualMessages &, const Shape &,
|
|||
const Shape &, const char * = "left operand",
|
||||
const char * = "right operand");
|
||||
|
||||
// The implementation of GetShape() is wrapped in a helper class
|
||||
// so that the member functions may mutually recurse without prototypes.
|
||||
class GetShapeHelper {
|
||||
class GetShapeVisitor : public virtual VisitorBase<std::optional<Shape>> {
|
||||
public:
|
||||
explicit GetShapeHelper(FoldingContext &context) : context_{context} {}
|
||||
using Result = std::optional<Shape>;
|
||||
explicit GetShapeVisitor(FoldingContext &c) : context_{c} {}
|
||||
|
||||
template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
|
||||
return GetShape(expr.u);
|
||||
template<typename T> void Handle(const Constant<T> &c) {
|
||||
Return(AsShape(c.SHAPE()));
|
||||
}
|
||||
|
||||
std::optional<Shape> GetShape(const Symbol &);
|
||||
std::optional<Shape> GetShape(const Symbol *);
|
||||
std::optional<Shape> GetShape(const Component &);
|
||||
std::optional<Shape> GetShape(const NamedEntity &);
|
||||
std::optional<Shape> GetShape(const BaseObject &);
|
||||
std::optional<Shape> GetShape(const ArrayRef &);
|
||||
std::optional<Shape> GetShape(const CoarrayRef &);
|
||||
std::optional<Shape> GetShape(const DataRef &);
|
||||
std::optional<Shape> GetShape(const Substring &);
|
||||
std::optional<Shape> GetShape(const ComplexPart &);
|
||||
std::optional<Shape> GetShape(const ActualArgument &);
|
||||
std::optional<Shape> GetShape(const ProcedureDesignator &);
|
||||
std::optional<Shape> GetShape(const ProcedureRef &);
|
||||
std::optional<Shape> GetShape(const ImpliedDoIndex &);
|
||||
std::optional<Shape> GetShape(const Relational<SomeType> &);
|
||||
std::optional<Shape> GetShape(const StructureConstructor &);
|
||||
std::optional<Shape> GetShape(const DescriptorInquiry &);
|
||||
std::optional<Shape> GetShape(const BOZLiteralConstant &);
|
||||
std::optional<Shape> GetShape(const NullPointer &);
|
||||
|
||||
template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
|
||||
Constant<ExtentType> shape{c.SHAPE()};
|
||||
return AsShape(shape);
|
||||
void Handle(const Symbol &);
|
||||
void Handle(const Component &);
|
||||
void Handle(const NamedEntity &);
|
||||
void Handle(const StaticDataObject::Pointer &) { Scalar(); }
|
||||
void Handle(const ArrayRef &);
|
||||
void Handle(const CoarrayRef &);
|
||||
void Handle(const ProcedureRef &);
|
||||
void Handle(const StructureConstructor &) { Scalar(); }
|
||||
template<typename T> void Handle(const ArrayConstructor<T> &aconst) {
|
||||
Return(Shape{GetArrayConstructorExtent(aconst)});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::optional<Shape> GetShape(const Designator<T> &designator) {
|
||||
return GetShape(designator.u);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::optional<Shape> GetShape(const Variable<T> &variable) {
|
||||
return GetShape(variable.u);
|
||||
}
|
||||
|
||||
template<typename D, typename R, typename... O>
|
||||
std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
|
||||
if constexpr (sizeof...(O) > 1) {
|
||||
if (operation.right().Rank() > 0) {
|
||||
return GetShape(operation.right());
|
||||
}
|
||||
}
|
||||
return GetShape(operation.left());
|
||||
}
|
||||
|
||||
template<int KIND>
|
||||
std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
|
||||
return Shape{}; // always scalar, even when applied to an array
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
|
||||
return Shape{GetArrayConstructorExtent(aconst)};
|
||||
}
|
||||
|
||||
template<typename... A>
|
||||
std::optional<Shape> GetShape(const std::variant<A...> &u) {
|
||||
return std::visit([&](const auto &x) { return GetShape(x); }, u);
|
||||
}
|
||||
|
||||
template<typename A, bool COPY>
|
||||
std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
|
||||
return GetShape(p.value());
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
std::optional<Shape> GetShape(const std::optional<A> &x) {
|
||||
if (x.has_value()) {
|
||||
return GetShape(*x);
|
||||
void Handle(const ImpliedDoIndex &) { Scalar(); }
|
||||
void Handle(const DescriptorInquiry &) { Scalar(); }
|
||||
template<int KIND> void Handle(const TypeParamInquiry<KIND> &) { Scalar(); }
|
||||
void Handle(const BOZLiteralConstant &) { Scalar(); }
|
||||
void Handle(const NullPointer &) { Return(); }
|
||||
template<typename D, typename R, typename LO, typename RO>
|
||||
void Handle(const Operation<D, R, LO, RO> &operation) {
|
||||
if (operation.right().Rank() > 0) {
|
||||
Nested(operation.right());
|
||||
} else {
|
||||
return std::nullopt;
|
||||
Nested(operation.left());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Scalar() { Return(Shape{}); }
|
||||
|
||||
template<typename A> void Nested(const A &x) {
|
||||
Return(GetShape(context_, x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
MaybeExtentExpr GetArrayConstructorValueExtent(
|
||||
const ArrayConstructorValue<T> &value) {
|
||||
return std::visit(
|
||||
common::visitors{
|
||||
[&](const Expr<T> &x) -> MaybeExtentExpr {
|
||||
if (std::optional<Shape> xShape{GetShape(x)}) {
|
||||
if (std::optional<Shape> xShape{GetShape(context_, x)}) {
|
||||
// Array values in array constructors get linearized.
|
||||
return GetSize(std::move(*xShape));
|
||||
} else {
|
||||
|
@ -221,7 +179,7 @@ private:
|
|||
|
||||
template<typename A>
|
||||
std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
|
||||
return GetShapeHelper{context}.GetShape(x);
|
||||
return Visitor<GetShapeVisitor>{context}.Traverse(x);
|
||||
}
|
||||
}
|
||||
#endif // FORTRAN_EVALUATE_SHAPE_H_
|
||||
|
|
|
@ -49,7 +49,7 @@
|
|||
// and call:
|
||||
// RESULT result{v.Traverse(topLevelExpr)};
|
||||
// Within the callback routines (Handle, Pre, Post), one may call
|
||||
// void Return(RESULT &&); // to define the result and end traversal
|
||||
// void Return(A &&); // to assign to the result and end traversal
|
||||
// void Return(); // to end traversal with current result
|
||||
// RESULT &result(); // to reference the result to define or update it
|
||||
// For any given expression object type T for which a callback is defined
|
||||
|
@ -90,7 +90,8 @@ public:
|
|||
std::nullptr_t Post(std::nullptr_t);
|
||||
|
||||
void Return() { done_ = true; }
|
||||
void Return(RESULT &&x) {
|
||||
|
||||
template<typename A> void Return(A &&x) {
|
||||
result_ = std::move(x);
|
||||
done_ = true;
|
||||
}
|
||||
|
@ -150,7 +151,6 @@ public:
|
|||
return std::move(result_);
|
||||
}
|
||||
|
||||
private:
|
||||
template<typename B> void Visit(const B &x) {
|
||||
if (!done_) {
|
||||
if constexpr ((... || HasVisitorHandle<A, B, void>::value)) {
|
||||
|
@ -174,6 +174,7 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
friend class Descender<Visitor>;
|
||||
Descender<Visitor> descender_{*this};
|
||||
};
|
||||
|
|
Loading…
Reference in a new issue