[mlir] LLVM dialect: use addressof instead of constant to create function pointers

`llvm.mlir.constant` was originally introduced as an LLVM dialect counterpart
to `std.constant`. As such, it was supporting "function pointer" constants
derived from the symbol name. This is different from `std.constant` that allows
for creation of a "function" constant since MLIR, unlike LLVM IR, supports
this. Later, `llvm.mlir.addressof` was introduced as an Op that obtains a
constant pointer to a global in the LLVM dialect. It naturally extends to
functions (in LLVM IR, functions are globals) and should be used for defining
"function pointer" values instead.

Fixes PR46344.

Differential Revision: https://reviews.llvm.org/D82667
This commit is contained in:
Alex Zinenko 2020-06-29 12:16:23 +02:00
parent e503851d80
commit cba733edf5
12 changed files with 153 additions and 35 deletions

View file

@ -313,10 +313,28 @@ Bitwise reinterpretation: `bitcast <value>`.
Selection: `select <condition>, <lhs>, <rhs>`.
### Auxiliary MLIR operations
### Auxiliary MLIR Operations for Constants and Globals
These operations do not have LLVM IR counterparts but are necessary to map LLVM
IR into MLIR. They should be prefixed with `llvm.mlir`.
LLVM IR has broad support for first-class constants, which is not the case for
MLIR. Instead, constants are defined in MLIR as regular SSA values produced by
operations with specific traits. The LLVM dialect provides a set of operations
that model LLVM IR constants. These operations do not correspond to LLVM IR
instructions and are therefore prefixed with `llvm.mlir`.
Inline constants can be created by `llvm.mlir.constant`, which currently
supports integer, float, string or elements attributes (constant sturcts are not
currently supported). LLVM IR constant expressions are expected to be
constructed as sequences of regular operations on SSA values produced by
`llvm.mlir.constant`. Additionally, MLIR provides semantically-charged
operations `llvm.mlir.undef` and `llvm.mlir.null` for the corresponding
constants.
LLVM IR globals can be defined using `llvm.mlir.global` at the module level,
except for functions that are defined with `llvm.func`. Globals, both variables
and functions, can be accessed by taking their address with the
`llvm.mlir.addressof` operation, which produces a pointer to the named global,
unlike the `llvm.mlir.constant` that produces the value of the same type as the
constant.
#### `llvm.mlir.addressof`
@ -328,11 +346,17 @@ Examples:
```mlir
func @foo() {
// Get the address of a global.
// Get the address of a global variable.
%0 = llvm.mlir.addressof @const : !llvm<"i32*">
// Use it as a regular pointer.
%1 = llvm.load %0 : !llvm<"i32*">
// Get the address of a function.
%2 = llvm.mlir.addressof @foo : !llvm<"void ()*">
// The function address can be used for indirect calls.
llvm.call %2() : () -> ()
}
// Define the global.

View file

@ -575,6 +575,8 @@ def Linkage : LLVM_EnumAttr<
def LLVM_AddressOfOp
: LLVM_OneResultOp<"mlir.addressof">,
Arguments<(ins FlatSymbolRefAttr:$global_name)> {
let summary = "Creates a pointer pointing to a global or a function";
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, LLVMType resType, "
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
@ -586,13 +588,21 @@ def LLVM_AddressOfOp
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result,
global.getType().getPointerTo(global.addr_space().getZExtValue()),
global.sym_name(), attrs);}]>
global.sym_name(), attrs);}]>,
OpBuilder<"OpBuilder &builder, OperationState &result, LLVMFuncOp func, "
"ArrayRef<NamedAttribute> attrs = {}", [{
build(builder, result,
func.getType().getPointerTo(), func.getName(), attrs);}]>
];
let extraClassDeclaration = [{
/// Return the llvm.mlir.global operation that defined the value referenced
/// here.
GlobalOp getGlobal();
/// Return the llvm.func operation that is referenced here.
LLVMFuncOp getFunction();
}];
let assemblyFormat = "$global_name attr-dict `:` type($res)";
@ -733,6 +743,7 @@ def LLVM_ConstantOp
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
{
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
}
def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,

View file

@ -1366,8 +1366,6 @@ using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using ConstLLVMOpLowering =
OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp>;
using CopySignOpLowering =
VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
@ -1541,6 +1539,39 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
}
};
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = cast<ConstantOp>(operation);
// If constant refers to a function, convert it to "addressof".
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
auto type = typeConverter.convertType(op.getResult().getType())
.dyn_cast_or_null<LLVM::LLVMType>();
if (!type)
return rewriter.notifyMatchFailure(op, "failed to convert result type");
MutableDictionaryAttr attrs(op.getAttrs());
attrs.remove(rewriter.getIdentifier("value"));
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
attrs.getAttrs());
return success();
}
// Calling into other scopes (non-flat reference) is not supported in LLVM.
if (op.getValue().isa<SymbolRefAttr>())
return rewriter.notifyMatchFailure(
op, "referring to a symbol outside of the current module");
return LLVM::detail::oneToOneRewrite(op,
LLVM::ConstantOp::getOperationName(),
operands, typeConverter, rewriter);
}
};
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
@ -3129,7 +3160,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
CondBranchOpLowering,
CopySignOpLowering,
CosOpLowering,
ConstLLVMOpLowering,
ConstantOpLowering,
CreateComplexOpLowering,
DialectCastOpLowering,
DivFOpLowering,

View file

@ -857,25 +857,40 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
// Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
GlobalOp AddressOfOp::getGlobal() {
Operation *module = getParentOp();
template <typename OpTy>
static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
Operation *module = parent;
while (module && !satisfiesLLVMModule(module))
module = module->getParentOp();
assert(module && "unexpected operation outside of a module");
return dyn_cast_or_null<LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
return dyn_cast_or_null<OpTy>(
mlir::SymbolTable::lookupSymbolIn(module, name));
}
GlobalOp AddressOfOp::getGlobal() {
return lookupSymbolInModule<LLVM::GlobalOp>(getParentOp(), global_name());
}
LLVMFuncOp AddressOfOp::getFunction() {
return lookupSymbolInModule<LLVM::LLVMFuncOp>(getParentOp(), global_name());
}
static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
if (!global)
auto function = op.getFunction();
if (!global && !function)
return op.emitOpError(
"must reference a global defined by 'llvm.mlir.global'");
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
op.getResult().getType())
if (global &&
global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referred global");
"the type must be a pointer to the type of the referenced global");
if (function && function.getType().getPointerTo() != op.getResult().getType())
return op.emitOpError(
"the type must be a pointer to the type of the referenced function");
return success();
}
@ -1395,6 +1410,18 @@ static LogicalResult verify(LLVM::NullOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// Verification for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
static LogicalResult verify(LLVM::ConstantOp op) {
if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() ||
op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>()))
return op.emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();
}
//===----------------------------------------------------------------------===//
// Utility functions for parsing atomic ops
//===----------------------------------------------------------------------===//

View file

@ -405,6 +405,9 @@ Value Importer::processConstant(llvm::Constant *c) {
LLVMType type = processType(c->getType());
if (!type)
return nullptr;
if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
symbolRef.getValue());
return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
}
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {

View file

@ -447,10 +447,15 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
// emit any LLVM instruction.
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
LLVM::GlobalOp global = addressOfOp.getGlobal();
// The verifier should not have allowed this.
assert(global && "referencing an undefined global");
LLVM::LLVMFuncOp function = addressOfOp.getFunction();
valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
// The verifier should not have allowed this.
assert((global || function) &&
"referencing an undefined global or function");
valueMapping[addressOfOp.getResult()] =
global ? globalsMapping.lookup(global)
: functionMapping.lookup(function.getName());
return success();
}

View file

@ -31,7 +31,7 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">)
br ^bb1(%arg0 : () -> ())
//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0
//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">):
^bb1(%bbarg: () -> ()):
// CHECK-NEXT: llvm.return %0 : !llvm<"void ()*">
return %bbarg : () -> ()
@ -40,11 +40,12 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
// CHECK-LABEL: llvm.func @body(!llvm.i32)
func @body(i32)
// CHECK-LABEL: llvm.func @indirect_const_call(%arg0: !llvm.i32) {
// CHECK-LABEL: llvm.func @indirect_const_call
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.i32) {
func @indirect_const_call(%arg0: i32) {
// CHECK-NEXT: %0 = llvm.mlir.constant(@body) : !llvm<"void (i32)*">
// CHECK-NEXT: %[[ADDR:.*]] = llvm.mlir.addressof @body : !llvm<"void (i32)*">
%0 = constant @body : (i32) -> ()
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
// CHECK-NEXT: llvm.call %[[ADDR]](%[[ARG0:.*]]) : (!llvm.i32) -> ()
call_indirect %0(%arg0) : (i32) -> ()
// CHECK-NEXT: llvm.return
return

View file

@ -140,12 +140,21 @@ func @foo() {
llvm.mlir.global internal @foo(0: i32) : !llvm.i32
func @bar() {
// expected-error @+1 {{the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @foo : !llvm<"i64*">
}
// -----
llvm.func @foo()
llvm.func @bar() {
// expected-error @+1 {{the type must be a pointer to the type of the referenced function}}
llvm.mlir.addressof @foo : !llvm<"i8*">
}
// -----
// expected-error @+2 {{'llvm.mlir.global' op expects regions to end with 'llvm.return', found 'llvm.mlir.constant'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'llvm.return'}}
llvm.mlir.global internal @g() : !llvm.i64 {
@ -172,7 +181,7 @@ llvm.mlir.global internal @g(43 : i64) : !llvm.i64 {
llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
func @mismatch_addr_space_implicit_global() {
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm<"i64*">
}
@ -180,6 +189,6 @@ func @mismatch_addr_space_implicit_global() {
llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
func @mismatch_addr_space() {
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
llvm.mlir.addressof @g : !llvm<"i64 addrspace(4)*">
}

View file

@ -153,6 +153,13 @@ func @call_non_llvm_input(%callee : (i32) -> (), %arg : i32) {
// -----
func @constant_wrong_type() {
// expected-error@+1 {{only supports integer, float, string or elements attributes}}
llvm.mlir.constant(@constant_wrong_type) : !llvm<"void ()*">
}
// -----
func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error@+1 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : i32

View file

@ -55,12 +55,12 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
// CHECK: %[[STRUCT:.*]] = llvm.call @foo(%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0] : !llvm<"{ i32, double, i32 }">
// CHECK: %[[NEW_STRUCT:.*]] = llvm.insertvalue %[[VALUE]], %[[STRUCT]][2] : !llvm<"{ i32, double, i32 }">
// CHECK: %[[FUNC:.*]] = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK: %[[FUNC:.*]] = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK: %{{.*}} = llvm.call %[[FUNC]](%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
%20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
%20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
@ -130,8 +130,8 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
}
// An larger self-contained function.
// CHECK-LABEL: func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
// CHECK-LABEL: llvm.func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
// CHECK: %[[V0:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
// CHECK: %[[V1:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
// CHECK: %[[V2:.*]] = llvm.mlir.constant(4.200000e+01 : f64) : !llvm.double

View file

@ -234,7 +234,7 @@ define void @FPArithmetic(float %a, float %b, double %c, double %d) {
; CHECK-LABEL: @precaller
define i32 @precaller() {
%1 = alloca i32 ()*
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
; CHECK: llvm.store %[[func]], %[[loc:.*]]
store i32 ()* @callee, i32 ()** %1
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
@ -252,7 +252,7 @@ define i32 @callee() {
; CHECK-LABEL: @postcaller
define i32 @postcaller() {
%1 = alloca i32 ()*
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
; CHECK: llvm.store %[[func]], %[[loc:.*]]
store i32 ()* @callee, i32 ()** %1
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
@ -317,4 +317,4 @@ define i32 @useFenceInst() {
;CHECK: llvm.fence seq_cst
fence syncscope("") seq_cst
ret i32 0
}
}

View file

@ -886,7 +886,7 @@ llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3:
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}})
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
// CHECK-NEXT: call void @body(i64 %0)
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
%0 = llvm.mlir.addressof @body : !llvm<"void (i64)*">
llvm.call %0(%arg0) : (!llvm.i64) -> ()
// CHECK-NEXT: ret void
llvm.return