[mlir] translate types between MLIR LLVM dialect and LLVM IR

With new LLVM dialect type modeling, the dialect types no longer wrap LLVM IR
types. Therefore, they need to be translated to and from LLVM IR during export
and import. Introduce the relevant functionality for translating types. It is
currently exercised by an ad-hoc type translation roundtripping test that will
be subsumed by the actual translation test when the type system transition is
complete.

Depends On D84339

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85019
This commit is contained in:
Alex Zinenko 2020-08-04 11:37:25 +02:00
parent 8979a9cdf2
commit d4fbbab2e4
9 changed files with 675 additions and 0 deletions

View file

@ -0,0 +1,36 @@
//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the type translation function going from MLIR LLVM dialect
// to LLVM IR and back.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
namespace llvm {
class LLVMContext;
class Type;
} // namespace llvm
namespace mlir {
class MLIRContext;
namespace LLVM {
class LLVMTypeNew;
llvm::Type *translateTypeToLLVMIR(LLVMTypeNew type, llvm::LLVMContext &context);
LLVMTypeNew translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
} // namespace LLVM
} // namespace mlir
#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H

View file

@ -1,6 +1,7 @@
add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
LLVMIR/DebugTranslation.cpp LLVMIR/DebugTranslation.cpp
LLVMIR/ModuleTranslation.cpp LLVMIR/ModuleTranslation.cpp
LLVMIR/TypeTranslation.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR

View file

@ -0,0 +1,309 @@
//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
using namespace mlir;
namespace {
/// Support for translating MLIR LLVM dialect types to LLVM IR.
class TypeToLLVMIRTranslator {
public:
/// Constructs a class creating types in the given LLVM context.
TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {}
/// Translates a single type.
llvm::Type *translateType(LLVM::LLVMTypeNew type) {
// If the conversion is already known, just return it.
if (knownTranslations.count(type))
return knownTranslations.lookup(type);
// Dispatch to an appropriate function.
llvm::Type *translated =
llvm::TypeSwitch<LLVM::LLVMTypeNew, llvm::Type *>(type)
.Case([this](LLVM::LLVMVoidType) {
return llvm::Type::getVoidTy(context);
})
.Case([this](LLVM::LLVMHalfType) {
return llvm::Type::getHalfTy(context);
})
.Case([this](LLVM::LLVMBFloatType) {
return llvm::Type::getBFloatTy(context);
})
.Case([this](LLVM::LLVMFloatType) {
return llvm::Type::getFloatTy(context);
})
.Case([this](LLVM::LLVMDoubleType) {
return llvm::Type::getDoubleTy(context);
})
.Case([this](LLVM::LLVMFP128Type) {
return llvm::Type::getFP128Ty(context);
})
.Case([this](LLVM::LLVMX86FP80Type) {
return llvm::Type::getX86_FP80Ty(context);
})
.Case([this](LLVM::LLVMPPCFP128Type) {
return llvm::Type::getPPC_FP128Ty(context);
})
.Case([this](LLVM::LLVMX86MMXType) {
return llvm::Type::getX86_MMXTy(context);
})
.Case([this](LLVM::LLVMTokenType) {
return llvm::Type::getTokenTy(context);
})
.Case([this](LLVM::LLVMLabelType) {
return llvm::Type::getLabelTy(context);
})
.Case([this](LLVM::LLVMMetadataType) {
return llvm::Type::getMetadataTy(context);
})
.Case<LLVM::LLVMArrayType, LLVM::LLVMIntegerType,
LLVM::LLVMFunctionType, LLVM::LLVMPointerType,
LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
LLVM::LLVMScalableVectorType>(
[this](auto array) { return translate(array); })
.Default([](LLVM::LLVMTypeNew t) -> llvm::Type * {
llvm_unreachable("unknown LLVM dialect type");
});
// Cache the result of the conversion and return.
knownTranslations.try_emplace(type, translated);
return translated;
}
private:
/// Translates the given array type.
llvm::Type *translate(LLVM::LLVMArrayType type) {
return llvm::ArrayType::get(translateType(type.getElementType()),
type.getNumElements());
}
/// Translates the given function type.
llvm::Type *translate(LLVM::LLVMFunctionType type) {
SmallVector<llvm::Type *, 8> paramTypes;
translateTypes(type.getParams(), paramTypes);
return llvm::FunctionType::get(translateType(type.getReturnType()),
paramTypes, type.isVarArg());
}
/// Translates the given integer type.
llvm::Type *translate(LLVM::LLVMIntegerType type) {
return llvm::IntegerType::get(context, type.getBitWidth());
}
/// Translates the given pointer type.
llvm::Type *translate(LLVM::LLVMPointerType type) {
return llvm::PointerType::get(translateType(type.getElementType()),
type.getAddressSpace());
}
/// Translates the given structure type, supports both identified and literal
/// structs. This will _create_ a new identified structure every time, use
/// `convertType` if a structure with the same name must be looked up instead.
llvm::Type *translate(LLVM::LLVMStructType type) {
SmallVector<llvm::Type *, 8> subtypes;
if (!type.isIdentified()) {
translateTypes(type.getBody(), subtypes);
return llvm::StructType::get(context, subtypes, type.isPacked());
}
llvm::StructType *structType =
llvm::StructType::create(context, type.getName());
// Mark the type we just created as known so that recursive calls can pick
// it up and use directly.
knownTranslations.try_emplace(type, structType);
if (type.isOpaque())
return structType;
translateTypes(type.getBody(), subtypes);
structType->setBody(subtypes, type.isPacked());
return structType;
}
/// Translates the given fixed-vector type.
llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
return llvm::FixedVectorType::get(translateType(type.getElementType()),
type.getNumElements());
}
/// Translates the given scalable-vector type.
llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
return llvm::ScalableVectorType::get(translateType(type.getElementType()),
type.getMinNumElements());
}
/// Translates a list of types.
void translateTypes(ArrayRef<LLVM::LLVMTypeNew> types,
SmallVectorImpl<llvm::Type *> &result) {
result.reserve(result.size() + types.size());
for (auto type : types)
result.push_back(translateType(type));
}
/// Reference to the context in which the LLVM IR types are created.
llvm::LLVMContext &context;
/// Map of known translation. This serves a double purpose: caches translation
/// results to avoid repeated recursive calls and makes sure identified
/// structs with the same name (that is, equal) are resolved to an existing
/// type instead of creating a new type.
llvm::DenseMap<LLVM::LLVMTypeNew, llvm::Type *> knownTranslations;
};
} // end namespace
/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain
/// the mapping for identified structs so new structs will be created with
/// auto-renaming on each call. This is intended exclusively for testing.
llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMTypeNew type,
llvm::LLVMContext &context) {
return TypeToLLVMIRTranslator(context).translateType(type);
}
namespace {
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
class TypeFromLLVMIRTranslator {
public:
/// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
/// Translates the given type.
LLVM::LLVMTypeNew translateType(llvm::Type *type) {
if (knownTranslations.count(type))
return knownTranslations.lookup(type);
LLVM::LLVMTypeNew translated =
llvm::TypeSwitch<llvm::Type *, LLVM::LLVMTypeNew>(type)
.Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
llvm::ScalableVectorType>(
[this](auto *type) { return translate(type); })
.Default([this](llvm::Type *type) {
return translatePrimitiveType(type);
});
knownTranslations.try_emplace(type, translated);
return translated;
}
private:
/// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
/// type.
LLVM::LLVMTypeNew translatePrimitiveType(llvm::Type *type) {
if (type->isVoidTy())
return LLVM::LLVMVoidType::get(&context);
if (type->isHalfTy())
return LLVM::LLVMHalfType::get(&context);
if (type->isBFloatTy())
return LLVM::LLVMBFloatType::get(&context);
if (type->isFloatTy())
return LLVM::LLVMFloatType::get(&context);
if (type->isDoubleTy())
return LLVM::LLVMDoubleType::get(&context);
if (type->isFP128Ty())
return LLVM::LLVMFP128Type::get(&context);
if (type->isX86_FP80Ty())
return LLVM::LLVMX86FP80Type::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isX86_MMXTy())
return LLVM::LLVMX86MMXType::get(&context);
if (type->isLabelTy())
return LLVM::LLVMLabelType::get(&context);
if (type->isMetadataTy())
return LLVM::LLVMMetadataType::get(&context);
llvm_unreachable("not a primitive type");
}
/// Translates the given array type.
LLVM::LLVMTypeNew translate(llvm::ArrayType *type) {
return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given function type.
LLVM::LLVMTypeNew translate(llvm::FunctionType *type) {
SmallVector<LLVM::LLVMTypeNew, 8> paramTypes;
translateTypes(type->params(), paramTypes);
return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
paramTypes, type->isVarArg());
}
/// Translates the given integer type.
LLVM::LLVMTypeNew translate(llvm::IntegerType *type) {
return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
}
/// Translates the given pointer type.
LLVM::LLVMTypeNew translate(llvm::PointerType *type) {
return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
type->getAddressSpace());
}
/// Translates the given structure type.
LLVM::LLVMTypeNew translate(llvm::StructType *type) {
SmallVector<LLVM::LLVMTypeNew, 8> subtypes;
if (type->isLiteral()) {
translateTypes(type->subtypes(), subtypes);
return LLVM::LLVMStructType::getLiteral(&context, subtypes,
type->isPacked());
}
if (type->isOpaque())
return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
LLVM::LLVMStructType translated =
LLVM::LLVMStructType::getIdentified(&context, type->getName());
knownTranslations.try_emplace(type, translated);
translateTypes(type->subtypes(), subtypes);
LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
assert(succeeded(bodySet) &&
"could not set the body of an identified struct");
(void)bodySet;
return translated;
}
/// Translates the given fixed-vector type.
LLVM::LLVMTypeNew translate(llvm::FixedVectorType *type) {
return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given scalable-vector type.
LLVM::LLVMTypeNew translate(llvm::ScalableVectorType *type) {
return LLVM::LLVMScalableVectorType::get(
translateType(type->getElementType()), type->getMinNumElements());
}
/// Translates a list of types.
void translateTypes(ArrayRef<llvm::Type *> types,
SmallVectorImpl<LLVM::LLVMTypeNew> &result) {
result.reserve(result.size() + types.size());
for (llvm::Type *type : types)
result.push_back(translateType(type));
}
/// Map of known translations. Serves as a cache and as recursion stopper for
/// translating recursive structs.
llvm::DenseMap<llvm::Type *, LLVM::LLVMTypeNew> knownTranslations;
/// The context in which MLIR types are created.
MLIRContext &context;
};
} // end namespace
/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended
/// exclusively for testing.
LLVM::LLVMTypeNew mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
MLIRContext &context) {
return TypeFromLLVMIRTranslator(context).translateType(type);
}

View file

@ -0,0 +1,228 @@
// RUN: mlir-translate -test-mlir-to-llvmir -split-input-file %s | FileCheck %s
llvm.func @primitives() {
// CHECK: declare void @return_void()
// CHECK: declare void @return_void_round()
"llvm.test_introduce_func"() { name = "return_void", type = !llvm2.void } : () -> ()
// CHECK: declare half @return_half()
// CHECK: declare half @return_half_round()
"llvm.test_introduce_func"() { name = "return_half", type = !llvm2.half } : () -> ()
// CHECK: declare bfloat @return_bfloat()
// CHECK: declare bfloat @return_bfloat_round()
"llvm.test_introduce_func"() { name = "return_bfloat", type = !llvm2.bfloat } : () -> ()
// CHECK: declare float @return_float()
// CHECK: declare float @return_float_round()
"llvm.test_introduce_func"() { name = "return_float", type = !llvm2.float } : () -> ()
// CHECK: declare double @return_double()
// CHECK: declare double @return_double_round()
"llvm.test_introduce_func"() { name = "return_double", type = !llvm2.double } : () -> ()
// CHECK: declare fp128 @return_fp128()
// CHECK: declare fp128 @return_fp128_round()
"llvm.test_introduce_func"() { name = "return_fp128", type = !llvm2.fp128 } : () -> ()
// CHECK: declare x86_fp80 @return_x86_fp80()
// CHECK: declare x86_fp80 @return_x86_fp80_round()
"llvm.test_introduce_func"() { name = "return_x86_fp80", type = !llvm2.x86_fp80 } : () -> ()
// CHECK: declare ppc_fp128 @return_ppc_fp128()
// CHECK: declare ppc_fp128 @return_ppc_fp128_round()
"llvm.test_introduce_func"() { name = "return_ppc_fp128", type = !llvm2.ppc_fp128 } : () -> ()
// CHECK: declare x86_mmx @return_x86_mmx()
// CHECK: declare x86_mmx @return_x86_mmx_round()
"llvm.test_introduce_func"() { name = "return_x86_mmx", type = !llvm2.x86_mmx } : () -> ()
llvm.return
}
llvm.func @funcs() {
// CHECK: declare void @f_void_i32(i32)
// CHECK: declare void @f_void_i32_round(i32)
"llvm.test_introduce_func"() { name ="f_void_i32", type = !llvm2.func<void (i32)> } : () -> ()
// CHECK: declare i32 @f_i32_empty()
// CHECK: declare i32 @f_i32_empty_round()
"llvm.test_introduce_func"() { name ="f_i32_empty", type = !llvm2.func<i32 ()> } : () -> ()
// CHECK: declare i32 @f_i32_half_bfloat_float_double(half, bfloat, float, double)
// CHECK: declare i32 @f_i32_half_bfloat_float_double_round(half, bfloat, float, double)
"llvm.test_introduce_func"() { name ="f_i32_half_bfloat_float_double", type = !llvm2.func<i32 (half, bfloat, float, double)> } : () -> ()
// CHECK: declare i32 @f_i32_i32_i32(i32, i32)
// CHECK: declare i32 @f_i32_i32_i32_round(i32, i32)
"llvm.test_introduce_func"() { name ="f_i32_i32_i32", type = !llvm2.func<i32 (i32, i32)> } : () -> ()
// CHECK: declare void @f_void_variadic(...)
// CHECK: declare void @f_void_variadic_round(...)
"llvm.test_introduce_func"() { name ="f_void_variadic", type = !llvm2.func<void (...)> } : () -> ()
// CHECK: declare void @f_void_i32_i32_variadic(i32, i32, ...)
// CHECK: declare void @f_void_i32_i32_variadic_round(i32, i32, ...)
"llvm.test_introduce_func"() { name ="f_void_i32_i32_variadic", type = !llvm2.func<void (i32, i32, ...)> } : () -> ()
llvm.return
}
llvm.func @ints() {
// CHECK: declare i1 @return_i1()
// CHECK: declare i1 @return_i1_round()
"llvm.test_introduce_func"() { name = "return_i1", type = !llvm2.i1 } : () -> ()
// CHECK: declare i8 @return_i8()
// CHECK: declare i8 @return_i8_round()
"llvm.test_introduce_func"() { name = "return_i8", type = !llvm2.i8 } : () -> ()
// CHECK: declare i16 @return_i16()
// CHECK: declare i16 @return_i16_round()
"llvm.test_introduce_func"() { name = "return_i16", type = !llvm2.i16 } : () -> ()
// CHECK: declare i32 @return_i32()
// CHECK: declare i32 @return_i32_round()
"llvm.test_introduce_func"() { name = "return_i32", type = !llvm2.i32 } : () -> ()
// CHECK: declare i64 @return_i64()
// CHECK: declare i64 @return_i64_round()
"llvm.test_introduce_func"() { name = "return_i64", type = !llvm2.i64 } : () -> ()
// CHECK: declare i57 @return_i57()
// CHECK: declare i57 @return_i57_round()
"llvm.test_introduce_func"() { name = "return_i57", type = !llvm2.i57 } : () -> ()
// CHECK: declare i129 @return_i129()
// CHECK: declare i129 @return_i129_round()
"llvm.test_introduce_func"() { name = "return_i129", type = !llvm2.i129 } : () -> ()
llvm.return
}
llvm.func @pointers() {
// CHECK: declare i8* @return_pi8()
// CHECK: declare i8* @return_pi8_round()
"llvm.test_introduce_func"() { name = "return_pi8", type = !llvm2.ptr<i8> } : () -> ()
// CHECK: declare float* @return_pfloat()
// CHECK: declare float* @return_pfloat_round()
"llvm.test_introduce_func"() { name = "return_pfloat", type = !llvm2.ptr<float> } : () -> ()
// CHECK: declare i8** @return_ppi8()
// CHECK: declare i8** @return_ppi8_round()
"llvm.test_introduce_func"() { name = "return_ppi8", type = !llvm2.ptr<ptr<i8>> } : () -> ()
// CHECK: declare i8***** @return_pppppi8()
// CHECK: declare i8***** @return_pppppi8_round()
"llvm.test_introduce_func"() { name = "return_pppppi8", type = !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>> } : () -> ()
// CHECK: declare i8* @return_pi8_0()
// CHECK: declare i8* @return_pi8_0_round()
"llvm.test_introduce_func"() { name = "return_pi8_0", type = !llvm2.ptr<i8, 0> } : () -> ()
// CHECK: declare i8 addrspace(1)* @return_pi8_1()
// CHECK: declare i8 addrspace(1)* @return_pi8_1_round()
"llvm.test_introduce_func"() { name = "return_pi8_1", type = !llvm2.ptr<i8, 1> } : () -> ()
// CHECK: declare i8 addrspace(42)* @return_pi8_42()
// CHECK: declare i8 addrspace(42)* @return_pi8_42_round()
"llvm.test_introduce_func"() { name = "return_pi8_42", type = !llvm2.ptr<i8, 42> } : () -> ()
// CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9()
// CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9_round()
"llvm.test_introduce_func"() { name = "return_ppi8_42_9", type = !llvm2.ptr<ptr<i8, 42>, 9> } : () -> ()
llvm.return
}
llvm.func @vectors() {
// CHECK: declare <4 x i32> @return_v4_i32()
// CHECK: declare <4 x i32> @return_v4_i32_round()
"llvm.test_introduce_func"() { name = "return_v4_i32", type = !llvm2.vec<4 x i32> } : () -> ()
// CHECK: declare <4 x float> @return_v4_float()
// CHECK: declare <4 x float> @return_v4_float_round()
"llvm.test_introduce_func"() { name = "return_v4_float", type = !llvm2.vec<4 x float> } : () -> ()
// CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
// CHECK: declare <vscale x 4 x i32> @return_vs_4_i32_round()
"llvm.test_introduce_func"() { name = "return_vs_4_i32", type = !llvm2.vec<? x 4 x i32> } : () -> ()
// CHECK: declare <vscale x 8 x half> @return_vs_8_half()
// CHECK: declare <vscale x 8 x half> @return_vs_8_half_round()
"llvm.test_introduce_func"() { name = "return_vs_8_half", type = !llvm2.vec<? x 8 x half> } : () -> ()
// CHECK: declare <4 x i8*> @return_v_4_pi8()
// CHECK: declare <4 x i8*> @return_v_4_pi8_round()
"llvm.test_introduce_func"() { name = "return_v_4_pi8", type = !llvm2.vec<4 x ptr<i8>> } : () -> ()
llvm.return
}
llvm.func @arrays() {
// CHECK: declare [10 x i32] @return_a10_i32()
// CHECK: declare [10 x i32] @return_a10_i32_round()
"llvm.test_introduce_func"() { name = "return_a10_i32", type = !llvm2.array<10 x i32> } : () -> ()
// CHECK: declare [8 x float] @return_a8_float()
// CHECK: declare [8 x float] @return_a8_float_round()
"llvm.test_introduce_func"() { name = "return_a8_float", type = !llvm2.array<8 x float> } : () -> ()
// CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4()
// CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4_round()
"llvm.test_introduce_func"() { name = "return_a10_pi32_4", type = !llvm2.array<10 x ptr<i32, 4>> } : () -> ()
// CHECK: declare [10 x [4 x float]] @return_a10_a4_float()
// CHECK: declare [10 x [4 x float]] @return_a10_a4_float_round()
"llvm.test_introduce_func"() { name = "return_a10_a4_float", type = !llvm2.array<10 x array<4 x float>> } : () -> ()
llvm.return
}
llvm.func @literal_structs() {
// CHECK: declare {} @return_struct_empty()
// CHECK: declare {} @return_struct_empty_round()
"llvm.test_introduce_func"() { name = "return_struct_empty", type = !llvm2.struct<()> } : () -> ()
// CHECK: declare { i32 } @return_s_i32()
// CHECK: declare { i32 } @return_s_i32_round()
"llvm.test_introduce_func"() { name = "return_s_i32", type = !llvm2.struct<(i32)> } : () -> ()
// CHECK: declare { float, i32 } @return_s_float_i32()
// CHECK: declare { float, i32 } @return_s_float_i32_round()
"llvm.test_introduce_func"() { name = "return_s_float_i32", type = !llvm2.struct<(float, i32)> } : () -> ()
// CHECK: declare { { i32 } } @return_s_s_i32()
// CHECK: declare { { i32 } } @return_s_s_i32_round()
"llvm.test_introduce_func"() { name = "return_s_s_i32", type = !llvm2.struct<(struct<(i32)>)> } : () -> ()
// CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float()
// CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float_round()
"llvm.test_introduce_func"() { name = "return_s_i32_s_i32_float", type = !llvm2.struct<(i32, struct<(i32)>, float)> } : () -> ()
// CHECK: declare <{}> @return_sp_empty()
// CHECK: declare <{}> @return_sp_empty_round()
"llvm.test_introduce_func"() { name = "return_sp_empty", type = !llvm2.struct<packed ()> } : () -> ()
// CHECK: declare <{ i32 }> @return_sp_i32()
// CHECK: declare <{ i32 }> @return_sp_i32_round()
"llvm.test_introduce_func"() { name = "return_sp_i32", type = !llvm2.struct<packed (i32)> } : () -> ()
// CHECK: declare <{ float, i32 }> @return_sp_float_i32()
// CHECK: declare <{ float, i32 }> @return_sp_float_i32_round()
"llvm.test_introduce_func"() { name = "return_sp_float_i32", type = !llvm2.struct<packed (float, i32)> } : () -> ()
// CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float()
// CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float_round()
"llvm.test_introduce_func"() { name = "return_sp_i32_s_i31_1_float", type = !llvm2.struct<packed (i32, struct<(i32, i1)>, float)> } : () -> ()
// CHECK: declare { <{ i32 }> } @return_s_sp_i32()
// CHECK: declare { <{ i32 }> } @return_s_sp_i32_round()
"llvm.test_introduce_func"() { name = "return_s_sp_i32", type = !llvm2.struct<(struct<packed (i32)>)> } : () -> ()
// CHECK: declare <{ { i32 } }> @return_sp_s_i32()
// CHECK: declare <{ { i32 } }> @return_sp_s_i32_round()
"llvm.test_introduce_func"() { name = "return_sp_s_i32", type = !llvm2.struct<packed (struct<(i32)>)> } : () -> ()
llvm.return
}
// -----
// Put structs into a separate split so that we can match their declarations
// locally.
// CHECK: %empty = type {}
// CHECK: %opaque = type opaque
// CHECK: %long = type { i32, { i32, i1 }, float, void ()* }
// CHECK: %self-recursive = type { %self-recursive* }
// CHECK: %unpacked = type { i32 }
// CHECK: %packed = type <{ i32 }>
// CHECK: %"name with spaces and !^$@$#" = type <{ i32 }>
// CHECK: %mutually-a = type { %mutually-b* }
// CHECK: %mutually-b = type { %mutually-a addrspace(3)* }
// CHECK: %struct-of-arrays = type { [10 x i32] }
// CHECK: %array-of-structs = type { i32 }
// CHECK: %ptr-to-struct = type { i8 }
llvm.func @identified_structs() {
// CHECK: declare %empty
"llvm.test_introduce_func"() { name = "return_s_empty", type = !llvm2.struct<"empty", ()> } : () -> ()
// CHECK: declare %opaque
"llvm.test_introduce_func"() { name = "return_s_opaque", type = !llvm2.struct<"opaque", opaque> } : () -> ()
// CHECK: declare %long
"llvm.test_introduce_func"() { name = "return_s_long", type = !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)> } : () -> ()
// CHECK: declare %self-recursive
"llvm.test_introduce_func"() { name = "return_s_self_recurisve", type = !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)> } : () -> ()
// CHECK: declare %unpacked
"llvm.test_introduce_func"() { name = "return_s_unpacked", type = !llvm2.struct<"unpacked", (i32)> } : () -> ()
// CHECK: declare %packed
"llvm.test_introduce_func"() { name = "return_s_packed", type = !llvm2.struct<"packed", packed (i32)> } : () -> ()
// CHECK: declare %"name with spaces and !^$@$#"
"llvm.test_introduce_func"() { name = "return_s_symbols", type = !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> } : () -> ()
// CHECK: declare %mutually-a
"llvm.test_introduce_func"() { name = "return_s_mutually_a", type = !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)> } : () -> ()
// CHECK: declare %mutually-b
"llvm.test_introduce_func"() { name = "return_s_mutually_b", type = !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)> } : () -> ()
// CHECK: declare %struct-of-arrays
"llvm.test_introduce_func"() { name = "return_s_struct_of_arrays", type = !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> } : () -> ()
// CHECK: declare [10 x %array-of-structs]
"llvm.test_introduce_func"() { name = "return_s_array_of_structs", type = !llvm2.array<10 x struct<"array-of-structs", (i32)>> } : () -> ()
// CHECK: declare %ptr-to-struct*
"llvm.test_introduce_func"() { name = "return_s_ptr_to_struct", type = !llvm2.ptr<struct<"ptr-to-struct", (i8)>> } : () -> ()
llvm.return
}

View file

@ -2,4 +2,5 @@ add_subdirectory(Dialect)
add_subdirectory(IR) add_subdirectory(IR)
add_subdirectory(Pass) add_subdirectory(Pass)
add_subdirectory(Reducer) add_subdirectory(Reducer)
add_subdirectory(Target)
add_subdirectory(Transforms) add_subdirectory(Transforms)

View file

@ -0,0 +1,13 @@
add_mlir_translation_library(MLIRTestLLVMTypeTranslation
TestLLVMTypeTranslation.cpp
LINK_COMPONENTS
Core
TransformUtils
LINK_LIBS PUBLIC
MLIRLLVMIR
MLIRTargetLLVMIRModuleTranslation
MLIRTestIR
MLIRTranslation
)

View file

@ -0,0 +1,79 @@
//===- TestLLVMTypeTranslation.cpp - Test MLIR/LLVM IR type translation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Translation.h"
using namespace mlir;
namespace {
class TestLLVMTypeTranslation : public LLVM::ModuleTranslation {
// Allow access to the constructors under MSVC.
friend LLVM::ModuleTranslation;
public:
using LLVM::ModuleTranslation::ModuleTranslation;
protected:
/// Simple test facility for translating types from MLIR LLVM dialect to LLVM
/// IR. This converts the "llvm.test_introduce_func" operation into an LLVM IR
/// function with the name extracted from the `name` attribute that returns
/// the type contained in the `type` attribute if it is a non-function type or
/// that has the signature obtained by converting `type` if it is a function
/// type. This is a temporary check before type translation is substituted
/// into the main translation flow and exercised here.
LogicalResult convertOperation(Operation &op,
llvm::IRBuilder<> &builder) override {
if (op.getName().getStringRef() == "llvm.test_introduce_func") {
auto attr = op.getAttrOfType<TypeAttr>("type");
assert(attr && "expected 'type' attribute");
auto type = attr.getValue().cast<LLVM::LLVMTypeNew>();
auto nameAttr = op.getAttrOfType<StringAttr>("name");
assert(nameAttr && "expected 'name' attributes");
llvm::Type *translated =
LLVM::translateTypeToLLVMIR(type, builder.getContext());
llvm::Module *module = builder.GetInsertBlock()->getModule();
if (auto *funcType = dyn_cast<llvm::FunctionType>(translated))
module->getOrInsertFunction(nameAttr.getValue(), funcType);
else
module->getOrInsertFunction(nameAttr.getValue(), translated);
std::string roundtripName = (Twine(nameAttr.getValue()) + "_round").str();
LLVM::LLVMTypeNew translatedBack =
LLVM::translateTypeFromLLVMIR(translated, *op.getContext());
llvm::Type *translatedBackAndForth =
LLVM::translateTypeToLLVMIR(translatedBack, builder.getContext());
if (auto *funcType = dyn_cast<llvm::FunctionType>(translatedBackAndForth))
module->getOrInsertFunction(roundtripName, funcType);
else
module->getOrInsertFunction(roundtripName, translatedBackAndForth);
return success();
}
return LLVM::ModuleTranslation::convertOperation(op, builder);
}
};
} // namespace
namespace mlir {
void registerTestLLVMTypeTranslation() {
TranslateFromMLIRRegistration reg(
"test-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
std::unique_ptr<llvm::Module> llvmModule =
LLVM::ModuleTranslation::translateModule<TestLLVMTypeTranslation>(
module.getOperation());
llvmModule->print(output, nullptr);
return success();
});
}
} // namespace mlir

View file

@ -13,7 +13,11 @@ target_link_libraries(mlir-translate
PRIVATE PRIVATE
${dialect_libs} ${dialect_libs}
${translation_libs} ${translation_libs}
${test_libs}
MLIRIR MLIRIR
# TODO: remove after LLVM dialect transition is complete; translation uses a
# registration function defined in this library unconditionally.
MLIRLLVMTypeTestDialect
MLIRParser MLIRParser
MLIRPass MLIRPass
MLIRSPIRV MLIRSPIRV

View file

@ -49,17 +49,21 @@ static llvm::cl::opt<bool> verifyDiagnostics(
namespace mlir { namespace mlir {
// Defined in the test directory, no public header. // Defined in the test directory, no public header.
void registerLLVMTypeTestDialect();
void registerTestLLVMTypeTranslation();
void registerTestRoundtripSPIRV(); void registerTestRoundtripSPIRV();
void registerTestRoundtripDebugSPIRV(); void registerTestRoundtripDebugSPIRV();
} // namespace mlir } // namespace mlir
static void registerTestTranslations() { static void registerTestTranslations() {
registerTestLLVMTypeTranslation();
registerTestRoundtripSPIRV(); registerTestRoundtripSPIRV();
registerTestRoundtripDebugSPIRV(); registerTestRoundtripDebugSPIRV();
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
registerAllDialects(); registerAllDialects();
registerLLVMTypeTestDialect();
registerAllTranslations(); registerAllTranslations();
registerTestTranslations(); registerTestTranslations();
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);