[Flang] Add a factory class for creating Complex Ops

Use the factory class in the FIRBuilder.
Add unit tests for the factory class function and the convert function
of the Complex class.

Reviewed By: clementval, rovka

Differential Revision: https://reviews.llvm.org/D114125

Co-authored-by: Jean Perier <jperier@nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz@nvidia.com>
This commit is contained in:
Kiran Chandramohan 2021-11-18 16:43:16 +00:00
parent 45e102a173
commit a1f9bd32c5
7 changed files with 264 additions and 0 deletions

View file

@ -0,0 +1,89 @@
//===-- Complex.h -- lowering of complex values -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#ifndef FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
#define FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H
#include "flang/Optimizer/Builder/FIRBuilder.h"
namespace fir::factory {
/// Helper to facilitate lowering of COMPLEX manipulations in FIR.
class Complex {
public:
explicit Complex(FirOpBuilder &builder, mlir::Location loc)
: builder(builder), loc(loc) {}
Complex(const Complex &) = delete;
// The values of part enum members are meaningful for
// InsertValueOp and ExtractValueOp so they are explicit.
enum class Part { Real = 0, Imag = 1 };
/// Get the Complex Type. Determine the type. Do not create MLIR operations.
mlir::Type getComplexPartType(mlir::Value cplx) const;
mlir::Type getComplexPartType(mlir::Type complexType) const;
/// Complex operation creation. They create MLIR operations.
mlir::Value createComplex(fir::KindTy kind, mlir::Value real,
mlir::Value imag);
/// Create a complex value.
mlir::Value createComplex(mlir::Type complexType, mlir::Value real,
mlir::Value imag);
/// Returns the Real/Imag part of \p cplx
mlir::Value extractComplexPart(mlir::Value cplx, bool isImagPart) {
return isImagPart ? extract<Part::Imag>(cplx) : extract<Part::Real>(cplx);
}
/// Returns (Real, Imag) pair of \p cplx
std::pair<mlir::Value, mlir::Value> extractParts(mlir::Value cplx) {
return {extract<Part::Real>(cplx), extract<Part::Imag>(cplx)};
}
mlir::Value insertComplexPart(mlir::Value cplx, mlir::Value part,
bool isImagPart) {
return isImagPart ? insert<Part::Imag>(cplx, part)
: insert<Part::Real>(cplx, part);
}
protected:
template <Part partId>
mlir::Value extract(mlir::Value cplx) {
return builder.create<fir::ExtractValueOp>(
loc, getComplexPartType(cplx), cplx,
builder.getArrayAttr({builder.getIntegerAttr(
builder.getIndexType(), static_cast<int>(partId))}));
}
template <Part partId>
mlir::Value insert(mlir::Value cplx, mlir::Value part) {
return builder.create<fir::InsertValueOp>(
loc, cplx.getType(), cplx, part,
builder.getArrayAttr({builder.getIntegerAttr(
builder.getIndexType(), static_cast<int>(partId))}));
}
template <Part partId>
mlir::Value createPartId() {
return builder.createIntegerConstant(loc, builder.getIndexType(),
static_cast<int>(partId));
}
private:
FirOpBuilder &builder;
mlir::Location loc;
};
} // namespace fir::factory
#endif // FORTRAN_OPTIMIZER_BUILDER_COMPLEX_H

View file

@ -57,6 +57,15 @@ public:
/// Get a reference to the kind map.
const fir::KindMapping &getKindMap() { return kindMap; }
/// The LHS and RHS are not always in agreement in terms of
/// type. In some cases, the disagreement is between COMPLEX and other scalar
/// types. In that case, the conversion must insert/extract out of a COMPLEX
/// value to have the proper semantics and be strongly typed. For e.g for
/// converting an integer/real to a complex, the real part is filled using
/// the integer/real after type conversion and the imaginary part is zero.
mlir::Value convertWithSemantics(mlir::Location loc, mlir::Type toTy,
mlir::Value val);
/// Get the entry block of the current Function
mlir::Block *getEntryBlock() { return &getFunction().front(); }

View file

@ -3,6 +3,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FIRBuilder
BoxValue.cpp
Character.cpp
Complex.cpp
DoLoopHelper.cpp
FIRBuilder.cpp
MutableBox.cpp

View file

@ -0,0 +1,36 @@
//===-- Complex.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/Complex.h"
//===----------------------------------------------------------------------===//
// Complex Factory implementation
//===----------------------------------------------------------------------===//
mlir::Type
fir::factory::Complex::getComplexPartType(mlir::Type complexType) const {
return builder.getRealType(complexType.cast<fir::ComplexType>().getFKind());
}
mlir::Type fir::factory::Complex::getComplexPartType(mlir::Value cplx) const {
return getComplexPartType(cplx.getType());
}
mlir::Value fir::factory::Complex::createComplex(fir::KindTy kind,
mlir::Value real,
mlir::Value imag) {
auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
return createComplex(complexTy, real, imag);
}
mlir::Value fir::factory::Complex::createComplex(mlir::Type cplxTy,
mlir::Value real,
mlir::Value imag) {
mlir::Value und = builder.create<fir::UndefOp>(loc, cplxTy);
return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
}

View file

@ -9,6 +9,7 @@
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Support/FatalError.h"
@ -257,6 +258,33 @@ fir::GlobalOp fir::FirOpBuilder::createGlobal(
return glob;
}
mlir::Value fir::FirOpBuilder::convertWithSemantics(mlir::Location loc,
mlir::Type toTy,
mlir::Value val) {
assert(toTy && "store location must be typed");
auto fromTy = val.getType();
if (fromTy == toTy)
return val;
fir::factory::Complex helper{*this, loc};
if ((fir::isa_real(fromTy) || fir::isa_integer(fromTy)) &&
fir::isa_complex(toTy)) {
// imaginary part is zero
auto eleTy = helper.getComplexPartType(toTy);
auto cast = createConvert(loc, eleTy, val);
llvm::APFloat zero{
kindMap.getFloatSemantics(toTy.cast<fir::ComplexType>().getFKind()), 0};
auto imag = createRealConstant(loc, eleTy, zero);
return helper.createComplex(toTy, cast, imag);
}
if (fir::isa_complex(fromTy) &&
(fir::isa_integer(toTy) || fir::isa_real(toTy))) {
// drop the imaginary part
auto rp = helper.extractComplexPart(val, /*isImagPart=*/false);
return createConvert(loc, toTy, rp);
}
return createConvert(loc, toTy, val);
}
mlir::Value fir::FirOpBuilder::createConvert(mlir::Location loc,
mlir::Type toTy, mlir::Value val) {
if (val.getType() != toTy) {

View file

@ -0,0 +1,100 @@
//===- ComplexExprTest.cpp -- ComplexExpr unit tests ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/Complex.h"
#include "gtest/gtest.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Support/KindMapping.h"
struct ComplexTest : public testing::Test {
public:
void SetUp() override {
mlir::OpBuilder builder(&context);
auto loc = builder.getUnknownLoc();
// Set up a Module with a dummy function operation inside.
// Set the insertion point in the function entry block.
mlir::ModuleOp mod = builder.create<mlir::ModuleOp>(loc);
mlir::FuncOp func = mlir::FuncOp::create(
loc, "func1", builder.getFunctionType(llvm::None, llvm::None));
auto *entryBlock = func.addEntryBlock();
mod.push_back(mod);
builder.setInsertionPointToStart(entryBlock);
fir::support::loadDialects(context);
kindMap = std::make_unique<fir::KindMapping>(&context);
firBuilder = std::make_unique<fir::FirOpBuilder>(mod, *kindMap);
helper = std::make_unique<fir::factory::Complex>(*firBuilder, loc);
// Init commonly used types
realTy1 = mlir::FloatType::getF32(&context);
complexTy1 = fir::ComplexType::get(&context, 4);
integerTy1 = mlir::IntegerType::get(&context, 32);
// Create commonly used reals
rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
rTwo = firBuilder->createRealConstant(loc, realTy1, 2u);
rThree = firBuilder->createRealConstant(loc, realTy1, 3u);
rFour = firBuilder->createRealConstant(loc, realTy1, 4u);
}
mlir::MLIRContext context;
std::unique_ptr<fir::KindMapping> kindMap;
std::unique_ptr<fir::FirOpBuilder> firBuilder;
std::unique_ptr<fir::factory::Complex> helper;
// Commonly used real/complex/integer types
mlir::FloatType realTy1;
fir::ComplexType complexTy1;
mlir::IntegerType integerTy1;
// Commonly used real numbers
mlir::Value rOne;
mlir::Value rTwo;
mlir::Value rThree;
mlir::Value rFour;
};
TEST_F(ComplexTest, verifyTypes) {
mlir::Value cVal1 = helper->createComplex(complexTy1, rOne, rTwo);
mlir::Value cVal2 = helper->createComplex(4, rOne, rTwo);
EXPECT_TRUE(fir::isa_complex(cVal1.getType()));
EXPECT_TRUE(fir::isa_complex(cVal2.getType()));
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal1)));
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal2)));
mlir::Value real1 = helper->extractComplexPart(cVal1, /*isImagPart=*/false);
mlir::Value imag1 = helper->extractComplexPart(cVal1, /*isImagPart=*/true);
mlir::Value real2 = helper->extractComplexPart(cVal2, /*isImagPart=*/false);
mlir::Value imag2 = helper->extractComplexPart(cVal2, /*isImagPart=*/true);
EXPECT_EQ(realTy1, real1.getType());
EXPECT_EQ(realTy1, imag1.getType());
EXPECT_EQ(realTy1, real2.getType());
EXPECT_EQ(realTy1, imag2.getType());
mlir::Value cVal3 =
helper->insertComplexPart(cVal1, rThree, /*isImagPart=*/false);
mlir::Value cVal4 =
helper->insertComplexPart(cVal3, rFour, /*isImagPart=*/true);
EXPECT_TRUE(fir::isa_complex(cVal4.getType()));
EXPECT_TRUE(fir::isa_real(helper->getComplexPartType(cVal4)));
}
TEST_F(ComplexTest, verifyConvertWithSemantics) {
auto loc = firBuilder->getUnknownLoc();
rOne = firBuilder->createRealConstant(loc, realTy1, 1u);
// Convert real to complex
mlir::Value v1 = firBuilder->convertWithSemantics(loc, complexTy1, rOne);
EXPECT_TRUE(fir::isa_complex(v1.getType()));
// Convert complex to integer
mlir::Value v2 = firBuilder->convertWithSemantics(loc, integerTy1, v1);
EXPECT_TRUE(v2.getType().isa<mlir::IntegerType>());
EXPECT_TRUE(mlir::dyn_cast<fir::ConvertOp>(v2.getDefiningOp()));
}

View file

@ -10,6 +10,7 @@ set(LIBS
add_flang_unittest(FlangOptimizerTests
Builder/CharacterTest.cpp
Builder/ComplexTest.cpp
Builder/DoLoopHelperTest.cpp
Builder/FIRBuilderTest.cpp
FIRContextTest.cpp