[flang] Implement GetShape with expression visitor

Original-commit: flang-compiler/f18@d607d02847
Reviewed-on: https://github.com/flang-compiler/f18/pull/611
Tree-same-pre-rewrite: false
This commit is contained in:
peter klausler 2019-07-30 11:33:06 -07:00
parent b72ef0b370
commit a7041f3a78
3 changed files with 64 additions and 180 deletions

View file

@ -298,53 +298,35 @@ MaybeExtentExpr GetUpperBound(FoldingContext &context, MaybeExtentExpr &&lower,
}
}
std::optional<Shape> GetShapeHelper::GetShape(const Symbol &symbol) {
return GetShape(NamedEntity{symbol});
void GetShapeVisitor::Handle(const Symbol &symbol) {
Handle(NamedEntity{symbol});
}
std::optional<Shape> GetShapeHelper::GetShape(const Symbol *symbol) {
if (symbol != nullptr) {
return GetShape(*symbol);
} else {
return std::nullopt;
}
void GetShapeVisitor::Handle(const Component &component) {
Handle(NamedEntity{Component{component}});
}
std::optional<Shape> GetShapeHelper::GetShape(const Component &component) {
return GetShape(NamedEntity{Component{component}});
}
std::optional<Shape> GetShapeHelper::GetShape(const NamedEntity &base) {
void GetShapeVisitor::Handle(const NamedEntity &base) {
const Symbol &symbol{base.GetLastSymbol()};
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
if (IsImpliedShape(symbol)) {
return GetShape(*details->init());
Nested(details->init());
} else {
Shape result;
int n{static_cast<int>(details->shape().size())};
for (int dimension{0}; dimension < n; ++dimension) {
result.emplace_back(GetExtent(context_, base, dimension));
}
return result;
Return(std::move(result));
}
} else if (const auto *details{
symbol.detailsIf<semantics::AssocEntityDetails>()}) {
if (details->expr().has_value()) {
return GetShape(*details->expr());
}
Nested(details->expr());
}
return std::nullopt;
Return();
}
std::optional<Shape> GetShapeHelper::GetShape(const BaseObject &object) {
if (const Symbol * symbol{object.symbol()}) {
return GetShape(*symbol);
} else {
return Shape{};
}
}
std::optional<Shape> GetShapeHelper::GetShape(const ArrayRef &arrayRef) {
void GetShapeVisitor::Handle(const ArrayRef &arrayRef) {
Shape shape;
int dimension{0};
for (const Subscript &ss : arrayRef.subscript()) {
@ -354,13 +336,13 @@ std::optional<Shape> GetShapeHelper::GetShape(const ArrayRef &arrayRef) {
++dimension;
}
if (shape.empty()) {
return GetShape(arrayRef.base());
Nested(arrayRef.base());
} else {
return shape;
Return(std::move(shape));
}
}
std::optional<Shape> GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) {
void GetShapeVisitor::Handle(const CoarrayRef &coarrayRef) {
Shape shape;
NamedEntity base{coarrayRef.GetBase()};
int dimension{0};
@ -371,102 +353,45 @@ std::optional<Shape> GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) {
++dimension;
}
if (shape.empty()) {
return GetShape(base);
Nested(base);
} else {
return shape;
Return(std::move(shape));
}
}
std::optional<Shape> GetShapeHelper::GetShape(const DataRef &dataRef) {
return GetShape(dataRef.u);
}
std::optional<Shape> GetShapeHelper::GetShape(const Substring &substring) {
if (const auto *dataRef{substring.GetParentIf<DataRef>()}) {
return GetShape(*dataRef);
} else {
return std::nullopt;
}
}
std::optional<Shape> GetShapeHelper::GetShape(const ComplexPart &part) {
return GetShape(part.complex());
}
std::optional<Shape> GetShapeHelper::GetShape(const ActualArgument &arg) {
if (const auto *expr{arg.UnwrapExpr()}) {
return GetShape(*expr);
} else if (const Symbol * atDummy{arg.GetAssumedTypeDummy()}) {
return GetShape(*atDummy);
} else {
return std::nullopt;
}
}
std::optional<Shape> GetShapeHelper::GetShape(const ProcedureDesignator &proc) {
if (const Symbol * symbol{proc.GetSymbol()}) {
return GetShape(*symbol);
} else {
return std::nullopt;
}
}
std::optional<Shape> GetShapeHelper::GetShape(const ProcedureRef &call) {
void GetShapeVisitor::Handle(const ProcedureRef &call) {
if (call.Rank() == 0) {
return Shape{};
Scalar();
} else if (call.IsElemental()) {
for (const auto &arg : call.arguments()) {
if (arg.has_value() && arg->Rank() > 0) {
return GetShape(*arg);
Nested(*arg);
return;
}
}
Scalar();
} else if (const Symbol * symbol{call.proc().GetSymbol()}) {
return GetShape(*symbol);
Handle(*symbol);
} else if (const auto *intrinsic{
std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
intrinsic->name == "ubound") {
const auto *expr{call.arguments().front().value().UnwrapExpr()};
CHECK(expr != nullptr);
return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
Return(Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}});
} else if (intrinsic->name == "reshape") {
if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) {
// SHAPE(RESHAPE(array,shape)) -> shape
const auto *shapeExpr{call.arguments().at(1).value().UnwrapExpr()};
CHECK(shapeExpr != nullptr);
Expr<SomeInteger> shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
Return(AsShape(context_, ConvertToType<ExtentType>(std::move(shape))));
}
} else {
// TODO: shapes of other non-elemental intrinsic results
}
}
return std::nullopt;
}
std::optional<Shape> GetShapeHelper::GetShape(
const Relational<SomeType> &relation) {
return GetShape(relation.u);
}
std::optional<Shape> GetShapeHelper::GetShape(const StructureConstructor &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShapeHelper::GetShape(const ImpliedDoIndex &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShapeHelper::GetShape(const DescriptorInquiry &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShapeHelper::GetShape(const BOZLiteralConstant &) {
return Shape{}; // always scalar
}
std::optional<Shape> GetShapeHelper::GetShape(const NullPointer &) {
return {}; // not an object
Return();
}
bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,

View file

@ -20,6 +20,7 @@
#include "expression.h"
#include "tools.h"
#include "traversal.h"
#include "type.h"
#include "variable.h"
#include "../common/indirection.h"
@ -42,6 +43,8 @@ using Shape = std::vector<MaybeExtentExpr>;
bool IsImpliedShape(const Symbol &);
bool IsExplicitShape(const Symbol &);
template<typename A> std::optional<Shape> GetShape(FoldingContext &, const A &);
// Conversions between various representations of shapes.
Shape AsShape(const Constant<ExtentType> &);
std::optional<Shape> AsShape(FoldingContext &, ExtentExpr &&);
@ -87,98 +90,53 @@ bool CheckConformance(parser::ContextualMessages &, const Shape &,
const Shape &, const char * = "left operand",
const char * = "right operand");
// The implementation of GetShape() is wrapped in a helper class
// so that the member functions may mutually recurse without prototypes.
class GetShapeHelper {
class GetShapeVisitor : public virtual VisitorBase<std::optional<Shape>> {
public:
explicit GetShapeHelper(FoldingContext &context) : context_{context} {}
using Result = std::optional<Shape>;
explicit GetShapeVisitor(FoldingContext &c) : context_{c} {}
template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
return GetShape(expr.u);
template<typename T> void Handle(const Constant<T> &c) {
Return(AsShape(c.SHAPE()));
}
std::optional<Shape> GetShape(const Symbol &);
std::optional<Shape> GetShape(const Symbol *);
std::optional<Shape> GetShape(const Component &);
std::optional<Shape> GetShape(const NamedEntity &);
std::optional<Shape> GetShape(const BaseObject &);
std::optional<Shape> GetShape(const ArrayRef &);
std::optional<Shape> GetShape(const CoarrayRef &);
std::optional<Shape> GetShape(const DataRef &);
std::optional<Shape> GetShape(const Substring &);
std::optional<Shape> GetShape(const ComplexPart &);
std::optional<Shape> GetShape(const ActualArgument &);
std::optional<Shape> GetShape(const ProcedureDesignator &);
std::optional<Shape> GetShape(const ProcedureRef &);
std::optional<Shape> GetShape(const ImpliedDoIndex &);
std::optional<Shape> GetShape(const Relational<SomeType> &);
std::optional<Shape> GetShape(const StructureConstructor &);
std::optional<Shape> GetShape(const DescriptorInquiry &);
std::optional<Shape> GetShape(const BOZLiteralConstant &);
std::optional<Shape> GetShape(const NullPointer &);
template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
Constant<ExtentType> shape{c.SHAPE()};
return AsShape(shape);
void Handle(const Symbol &);
void Handle(const Component &);
void Handle(const NamedEntity &);
void Handle(const StaticDataObject::Pointer &) { Scalar(); }
void Handle(const ArrayRef &);
void Handle(const CoarrayRef &);
void Handle(const ProcedureRef &);
void Handle(const StructureConstructor &) { Scalar(); }
template<typename T> void Handle(const ArrayConstructor<T> &aconst) {
Return(Shape{GetArrayConstructorExtent(aconst)});
}
template<typename T>
std::optional<Shape> GetShape(const Designator<T> &designator) {
return GetShape(designator.u);
}
template<typename T>
std::optional<Shape> GetShape(const Variable<T> &variable) {
return GetShape(variable.u);
}
template<typename D, typename R, typename... O>
std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
if constexpr (sizeof...(O) > 1) {
if (operation.right().Rank() > 0) {
return GetShape(operation.right());
}
}
return GetShape(operation.left());
}
template<int KIND>
std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
return Shape{}; // always scalar, even when applied to an array
}
template<typename T>
std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
return Shape{GetArrayConstructorExtent(aconst)};
}
template<typename... A>
std::optional<Shape> GetShape(const std::variant<A...> &u) {
return std::visit([&](const auto &x) { return GetShape(x); }, u);
}
template<typename A, bool COPY>
std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
return GetShape(p.value());
}
template<typename A>
std::optional<Shape> GetShape(const std::optional<A> &x) {
if (x.has_value()) {
return GetShape(*x);
void Handle(const ImpliedDoIndex &) { Scalar(); }
void Handle(const DescriptorInquiry &) { Scalar(); }
template<int KIND> void Handle(const TypeParamInquiry<KIND> &) { Scalar(); }
void Handle(const BOZLiteralConstant &) { Scalar(); }
void Handle(const NullPointer &) { Return(); }
template<typename D, typename R, typename LO, typename RO>
void Handle(const Operation<D, R, LO, RO> &operation) {
if (operation.right().Rank() > 0) {
Nested(operation.right());
} else {
return std::nullopt;
Nested(operation.left());
}
}
private:
void Scalar() { Return(Shape{}); }
template<typename A> void Nested(const A &x) {
Return(GetShape(context_, x));
}
template<typename T>
MaybeExtentExpr GetArrayConstructorValueExtent(
const ArrayConstructorValue<T> &value) {
return std::visit(
common::visitors{
[&](const Expr<T> &x) -> MaybeExtentExpr {
if (std::optional<Shape> xShape{GetShape(x)}) {
if (std::optional<Shape> xShape{GetShape(context_, x)}) {
// Array values in array constructors get linearized.
return GetSize(std::move(*xShape));
} else {
@ -221,7 +179,7 @@ private:
template<typename A>
std::optional<Shape> GetShape(FoldingContext &context, const A &x) {
return GetShapeHelper{context}.GetShape(x);
return Visitor<GetShapeVisitor>{context}.Traverse(x);
}
}
#endif // FORTRAN_EVALUATE_SHAPE_H_

View file

@ -49,7 +49,7 @@
// and call:
// RESULT result{v.Traverse(topLevelExpr)};
// Within the callback routines (Handle, Pre, Post), one may call
// void Return(RESULT &&); // to define the result and end traversal
// void Return(A &&); // to assign to the result and end traversal
// void Return(); // to end traversal with current result
// RESULT &result(); // to reference the result to define or update it
// For any given expression object type T for which a callback is defined
@ -90,7 +90,8 @@ public:
std::nullptr_t Post(std::nullptr_t);
void Return() { done_ = true; }
void Return(RESULT &&x) {
template<typename A> void Return(A &&x) {
result_ = std::move(x);
done_ = true;
}
@ -150,7 +151,6 @@ public:
return std::move(result_);
}
private:
template<typename B> void Visit(const B &x) {
if (!done_) {
if constexpr ((... || HasVisitorHandle<A, B, void>::value)) {
@ -174,6 +174,7 @@ private:
}
}
private:
friend class Descender<Visitor>;
Descender<Visitor> descender_{*this};
};