// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "shape.h" #include "fold.h" #include "tools.h" #include "traversal.h" #include "../common/idioms.h" #include "../common/template.h" #include "../semantics/symbol.h" namespace Fortran::evaluate { Shape AsShape(const Constant &arrayConstant) { CHECK(arrayConstant.Rank() == 1); Shape result; std::size_t dimensions{arrayConstant.size()}; for (std::size_t j{0}; j < dimensions; ++j) { Scalar extent{arrayConstant.values().at(j)}; result.emplace_back(MaybeExtent{ExtentExpr{extent}}); } return result; } std::optional AsShape(ExtentExpr &&arrayExpr) { if (auto *constArray{UnwrapExpr>(arrayExpr)}) { return AsShape(*constArray); } if (auto *constructor{UnwrapExpr>(arrayExpr)}) { Shape result; for (auto &value : constructor->values()) { if (auto *expr{std::get_if(&value.u)}) { if (expr->Rank() == 0) { result.emplace_back(std::move(*expr)); continue; } } return std::nullopt; } return result; } // TODO: linearize other array-valued expressions of known shape, e.g. A+B // as well as conversions of arrays; this will be easier given a // general-purpose array expression flattener (pmk) return std::nullopt; } std::optional AsShapeArrayExpr(const Shape &shape) { ArrayConstructorValues values; for (const auto &dim : shape) { if (dim.has_value()) { values.Push(common::Clone(*dim)); } else { return std::nullopt; } } return ExtentExpr{ArrayConstructor{std::move(values)}}; } std::optional> AsConstantShape(const Shape &shape) { if (auto shapeArray{AsShapeArrayExpr(shape)}) { FoldingContext noFoldingContext; auto folded{Fold(noFoldingContext, std::move(*shapeArray))}; if (auto *p{UnwrapExpr>(folded)}) { return std::move(*p); } } return std::nullopt; } static ExtentExpr ComputeTripCount( ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { ExtentExpr strideCopy{common::Clone(stride)}; ExtentExpr span{ (std::move(upper) - std::move(lower) + std::move(strideCopy)) / std::move(stride)}; ExtentExpr extent{ Extremum{std::move(span), ExtentExpr{0}, Ordering::Greater}}; FoldingContext noFoldingContext; return Fold(noFoldingContext, std::move(extent)); } ExtentExpr CountTrips( ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) { return ComputeTripCount( std::move(lower), std::move(upper), std::move(stride)); } ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride) { return ComputeTripCount( common::Clone(lower), common::Clone(upper), common::Clone(stride)); } MaybeExtent CountTrips( MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride) { return common::MapOptional( ComputeTripCount, std::move(lower), std::move(upper), std::move(stride)); } MaybeExtent GetSize(Shape &&shape) { ExtentExpr extent{1}; for (auto &&dim : std::move(shape)) { if (dim.has_value()) { extent = std::move(extent) * std::move(*dim); } else { return std::nullopt; } } return extent; } static MaybeExtent GetLowerBound( const Symbol &symbol, const Component *component, int dimension) { if (const auto *details{symbol.detailsIf()}) { int j{0}; for (const auto &shapeSpec : details->shape()) { if (j++ == dimension) { if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { return *bound; } else if (component != nullptr) { return ExtentExpr{DescriptorInquiry{ *component, DescriptorInquiry::Field::LowerBound, dimension}}; } else { return ExtentExpr{DescriptorInquiry{ symbol, DescriptorInquiry::Field::LowerBound, dimension}}; } } } } return std::nullopt; } static MaybeExtent GetExtent( const Symbol &symbol, const Component *component, int dimension) { if (const auto *details{symbol.detailsIf()}) { int j{0}; for (const auto &shapeSpec : details->shape()) { if (j++ == dimension) { if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) { if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) { FoldingContext noFoldingContext; return Fold(noFoldingContext, common::Clone(ubound.value()) - common::Clone(lbound.value()) + ExtentExpr{1}); } } if (component != nullptr) { return ExtentExpr{DescriptorInquiry{ *component, DescriptorInquiry::Field::Extent, dimension}}; } else { return ExtentExpr{DescriptorInquiry{ &symbol, DescriptorInquiry::Field::Extent, dimension}}; } } } } return std::nullopt; } static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol, const Component *component, int dimension) { return std::visit( common::visitors{ [&](const Triplet &triplet) -> MaybeExtent { MaybeExtent upper{triplet.upper()}; if (!upper.has_value()) { upper = GetExtent(symbol, component, dimension); } MaybeExtent lower{triplet.lower()}; if (!lower.has_value()) { lower = GetLowerBound(symbol, component, dimension); } return CountTrips(std::move(lower), std::move(upper), MaybeExtent{triplet.stride()}); }, [](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent { if (auto shape{GetShape(subs.value())}) { if (shape->size() > 0) { CHECK(shape->size() == 1); // vector-valued subscript return std::move(shape->at(0)); } } return std::nullopt; }, }, subscript.u); } bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) { struct MyVisitor : public virtual VisitorBase { using Result = bool; explicit MyVisitor(int) { result() = false; } void Handle(const ImpliedDoIndex &) { Return(true); } }; return Visitor{0}.Traverse(expr); } std::optional GetShape( const Symbol &symbol, const Component *component) { if (const auto *details{symbol.detailsIf()}) { Shape result; int n = details->shape().size(); for (int dimension{0}; dimension < n; ++dimension) { result.emplace_back(GetExtent(symbol, component, dimension++)); } return result; } else { return std::nullopt; } } std::optional GetShape(const Symbol *symbol) { if (symbol != nullptr) { return GetShape(*symbol); } else { return std::nullopt; } } std::optional GetShape(const BaseObject &object) { if (const Symbol * symbol{object.symbol()}) { return GetShape(*symbol); } else { return Shape{}; } } std::optional GetShape(const Component &component) { const Symbol &symbol{component.GetLastSymbol()}; if (symbol.Rank() > 0) { return GetShape(symbol, &component); } else { return GetShape(component.base()); } } std::optional GetShape(const ArrayRef &arrayRef) { Shape shape; const Symbol &symbol{arrayRef.GetLastSymbol()}; const Component *component{std::get_if(&arrayRef.base())}; int dimension{0}; for (const Subscript &ss : arrayRef.subscript()) { if (ss.Rank() > 0) { shape.emplace_back(GetExtent(ss, symbol, component, dimension)); } ++dimension; } if (shape.empty()) { return GetShape(arrayRef.base()); } else { return shape; } } std::optional GetShape(const CoarrayRef &coarrayRef) { Shape shape; SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()}; const Symbol &symbol{coarrayRef.GetLastSymbol()}; const Component *component{std::get_if(&base)}; int dimension{0}; for (const Subscript &ss : coarrayRef.subscript()) { if (ss.Rank() > 0) { shape.emplace_back(GetExtent(ss, symbol, component, dimension)); } ++dimension; } if (shape.empty()) { return GetShape(coarrayRef.GetLastSymbol()); } else { return shape; } } std::optional GetShape(const DataRef &dataRef) { return GetShape(dataRef.u); } std::optional GetShape(const Substring &substring) { if (const auto *dataRef{substring.GetParentIf()}) { return GetShape(*dataRef); } else { return std::nullopt; } } std::optional GetShape(const ComplexPart &part) { return GetShape(part.complex()); } std::optional GetShape(const ActualArgument &arg) { return GetShape(arg.value()); } std::optional GetShape(const ProcedureRef &call) { if (call.Rank() == 0) { return Shape{}; } else if (call.IsElemental()) { for (const auto &arg : call.arguments()) { if (arg.has_value() && arg->Rank() > 0) { return GetShape(*arg); } } } else if (const Symbol * symbol{call.proc().GetSymbol()}) { return GetShape(*symbol); } else if (const auto *intrinsic{ std::get_if(&call.proc().u)}) { if (intrinsic->name == "shape" || intrinsic->name == "lbound" || intrinsic->name == "ubound") { return Shape{MaybeExtent{ ExtentExpr{call.arguments().front().value().value().Rank()}}}; } else if (intrinsic->name == "reshape") { if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) { // SHAPE(RESHAPE(array,shape)) -> shape const Expr &shapeExpr{call.arguments().at(1)->value()}; Expr shape{std::get>(shapeExpr.u)}; return AsShape(ConvertToType(std::move(shape))); } } else { // TODO: shapes of other non-elemental intrinsic results } } return std::nullopt; } std::optional GetShape(const Relational &relation) { return GetShape(relation.u); } std::optional GetShape(const StructureConstructor &) { return Shape{}; // always scalar } std::optional GetShape(const ImpliedDoIndex &) { return Shape{}; // always scalar } std::optional GetShape(const DescriptorInquiry &) { return Shape{}; // always scalar } std::optional GetShape(const BOZLiteralConstant &) { return Shape{}; // always scalar } std::optional GetShape(const NullPointer &) { return {}; // not an object } }