[mlir] LLVM dialect: use addressof instead of constant to create function pointers
`llvm.mlir.constant` was originally introduced as an LLVM dialect counterpart to `std.constant`. As such, it was supporting "function pointer" constants derived from the symbol name. This is different from `std.constant` that allows for creation of a "function" constant since MLIR, unlike LLVM IR, supports this. Later, `llvm.mlir.addressof` was introduced as an Op that obtains a constant pointer to a global in the LLVM dialect. It naturally extends to functions (in LLVM IR, functions are globals) and should be used for defining "function pointer" values instead. Fixes PR46344. Differential Revision: https://reviews.llvm.org/D82667
This commit is contained in:
parent
e503851d80
commit
cba733edf5
|
@ -313,10 +313,28 @@ Bitwise reinterpretation: `bitcast <value>`.
|
|||
|
||||
Selection: `select <condition>, <lhs>, <rhs>`.
|
||||
|
||||
### Auxiliary MLIR operations
|
||||
### Auxiliary MLIR Operations for Constants and Globals
|
||||
|
||||
These operations do not have LLVM IR counterparts but are necessary to map LLVM
|
||||
IR into MLIR. They should be prefixed with `llvm.mlir`.
|
||||
LLVM IR has broad support for first-class constants, which is not the case for
|
||||
MLIR. Instead, constants are defined in MLIR as regular SSA values produced by
|
||||
operations with specific traits. The LLVM dialect provides a set of operations
|
||||
that model LLVM IR constants. These operations do not correspond to LLVM IR
|
||||
instructions and are therefore prefixed with `llvm.mlir`.
|
||||
|
||||
Inline constants can be created by `llvm.mlir.constant`, which currently
|
||||
supports integer, float, string or elements attributes (constant sturcts are not
|
||||
currently supported). LLVM IR constant expressions are expected to be
|
||||
constructed as sequences of regular operations on SSA values produced by
|
||||
`llvm.mlir.constant`. Additionally, MLIR provides semantically-charged
|
||||
operations `llvm.mlir.undef` and `llvm.mlir.null` for the corresponding
|
||||
constants.
|
||||
|
||||
LLVM IR globals can be defined using `llvm.mlir.global` at the module level,
|
||||
except for functions that are defined with `llvm.func`. Globals, both variables
|
||||
and functions, can be accessed by taking their address with the
|
||||
`llvm.mlir.addressof` operation, which produces a pointer to the named global,
|
||||
unlike the `llvm.mlir.constant` that produces the value of the same type as the
|
||||
constant.
|
||||
|
||||
#### `llvm.mlir.addressof`
|
||||
|
||||
|
@ -328,11 +346,17 @@ Examples:
|
|||
|
||||
```mlir
|
||||
func @foo() {
|
||||
// Get the address of a global.
|
||||
// Get the address of a global variable.
|
||||
%0 = llvm.mlir.addressof @const : !llvm<"i32*">
|
||||
|
||||
// Use it as a regular pointer.
|
||||
%1 = llvm.load %0 : !llvm<"i32*">
|
||||
|
||||
// Get the address of a function.
|
||||
%2 = llvm.mlir.addressof @foo : !llvm<"void ()*">
|
||||
|
||||
// The function address can be used for indirect calls.
|
||||
llvm.call %2() : () -> ()
|
||||
}
|
||||
|
||||
// Define the global.
|
||||
|
|
|
@ -575,6 +575,8 @@ def Linkage : LLVM_EnumAttr<
|
|||
def LLVM_AddressOfOp
|
||||
: LLVM_OneResultOp<"mlir.addressof">,
|
||||
Arguments<(ins FlatSymbolRefAttr:$global_name)> {
|
||||
let summary = "Creates a pointer pointing to a global or a function";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, LLVMType resType, "
|
||||
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
|
||||
|
@ -586,13 +588,21 @@ def LLVM_AddressOfOp
|
|||
"ArrayRef<NamedAttribute> attrs = {}", [{
|
||||
build(builder, result,
|
||||
global.getType().getPointerTo(global.addr_space().getZExtValue()),
|
||||
global.sym_name(), attrs);}]>
|
||||
global.sym_name(), attrs);}]>,
|
||||
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, LLVMFuncOp func, "
|
||||
"ArrayRef<NamedAttribute> attrs = {}", [{
|
||||
build(builder, result,
|
||||
func.getType().getPointerTo(), func.getName(), attrs);}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Return the llvm.mlir.global operation that defined the value referenced
|
||||
/// here.
|
||||
GlobalOp getGlobal();
|
||||
|
||||
/// Return the llvm.func operation that is referenced here.
|
||||
LLVMFuncOp getFunction();
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$global_name attr-dict `:` type($res)";
|
||||
|
@ -733,6 +743,7 @@ def LLVM_ConstantOp
|
|||
LLVM_Builder<"$res = getLLVMConstant($_resultType, $value, $_location);">
|
||||
{
|
||||
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def LLVM_DialectCastOp : LLVM_Op<"mlir.cast", [NoSideEffect]>,
|
||||
|
|
|
@ -1366,8 +1366,6 @@ using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
|
|||
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
|
||||
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
|
||||
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
|
||||
using ConstLLVMOpLowering =
|
||||
OneToOneConvertToLLVMPattern<ConstantOp, LLVM::ConstantOp>;
|
||||
using CopySignOpLowering =
|
||||
VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
|
||||
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
|
||||
|
@ -1541,6 +1539,39 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
|
|||
}
|
||||
};
|
||||
|
||||
struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
|
||||
using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = cast<ConstantOp>(operation);
|
||||
// If constant refers to a function, convert it to "addressof".
|
||||
if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
|
||||
auto type = typeConverter.convertType(op.getResult().getType())
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!type)
|
||||
return rewriter.notifyMatchFailure(op, "failed to convert result type");
|
||||
|
||||
MutableDictionaryAttr attrs(op.getAttrs());
|
||||
attrs.remove(rewriter.getIdentifier("value"));
|
||||
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
|
||||
op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
|
||||
attrs.getAttrs());
|
||||
return success();
|
||||
}
|
||||
|
||||
// Calling into other scopes (non-flat reference) is not supported in LLVM.
|
||||
if (op.getValue().isa<SymbolRefAttr>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "referring to a symbol outside of the current module");
|
||||
|
||||
return LLVM::detail::oneToOneRewrite(op,
|
||||
LLVM::ConstantOp::getOperationName(),
|
||||
operands, typeConverter, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the MemRefType `type` is supported by the lowering. We currently
|
||||
// only support memrefs with identity maps.
|
||||
static bool isSupportedMemRefType(MemRefType type) {
|
||||
|
@ -3129,7 +3160,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
|
|||
CondBranchOpLowering,
|
||||
CopySignOpLowering,
|
||||
CosOpLowering,
|
||||
ConstLLVMOpLowering,
|
||||
ConstantOpLowering,
|
||||
CreateComplexOpLowering,
|
||||
DialectCastOpLowering,
|
||||
DivFOpLowering,
|
||||
|
|
|
@ -857,25 +857,40 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
|||
// Verifier for LLVM::AddressOfOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp AddressOfOp::getGlobal() {
|
||||
Operation *module = getParentOp();
|
||||
template <typename OpTy>
|
||||
static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
|
||||
Operation *module = parent;
|
||||
while (module && !satisfiesLLVMModule(module))
|
||||
module = module->getParentOp();
|
||||
assert(module && "unexpected operation outside of a module");
|
||||
return dyn_cast_or_null<LLVM::GlobalOp>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, global_name()));
|
||||
return dyn_cast_or_null<OpTy>(
|
||||
mlir::SymbolTable::lookupSymbolIn(module, name));
|
||||
}
|
||||
|
||||
GlobalOp AddressOfOp::getGlobal() {
|
||||
return lookupSymbolInModule<LLVM::GlobalOp>(getParentOp(), global_name());
|
||||
}
|
||||
|
||||
LLVMFuncOp AddressOfOp::getFunction() {
|
||||
return lookupSymbolInModule<LLVM::LLVMFuncOp>(getParentOp(), global_name());
|
||||
}
|
||||
|
||||
static LogicalResult verify(AddressOfOp op) {
|
||||
auto global = op.getGlobal();
|
||||
if (!global)
|
||||
auto function = op.getFunction();
|
||||
if (!global && !function)
|
||||
return op.emitOpError(
|
||||
"must reference a global defined by 'llvm.mlir.global'");
|
||||
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
|
||||
|
||||
if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
|
||||
op.getResult().getType())
|
||||
if (global &&
|
||||
global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
|
||||
op.getResult().getType())
|
||||
return op.emitOpError(
|
||||
"the type must be a pointer to the type of the referred global");
|
||||
"the type must be a pointer to the type of the referenced global");
|
||||
|
||||
if (function && function.getType().getPointerTo() != op.getResult().getType())
|
||||
return op.emitOpError(
|
||||
"the type must be a pointer to the type of the referenced function");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -1395,6 +1410,18 @@ static LogicalResult verify(LLVM::NullOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verification for LLVM::ConstantOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(LLVM::ConstantOp op) {
|
||||
if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() ||
|
||||
op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>()))
|
||||
return op.emitOpError()
|
||||
<< "only supports integer, float, string or elements attributes";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions for parsing atomic ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -405,6 +405,9 @@ Value Importer::processConstant(llvm::Constant *c) {
|
|||
LLVMType type = processType(c->getType());
|
||||
if (!type)
|
||||
return nullptr;
|
||||
if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
|
||||
return instMap[c] = bEntry.create<AddressOfOp>(unknownLoc, type,
|
||||
symbolRef.getValue());
|
||||
return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
|
||||
}
|
||||
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
|
||||
|
|
|
@ -447,10 +447,15 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
|
|||
// emit any LLVM instruction.
|
||||
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
|
||||
LLVM::GlobalOp global = addressOfOp.getGlobal();
|
||||
// The verifier should not have allowed this.
|
||||
assert(global && "referencing an undefined global");
|
||||
LLVM::LLVMFuncOp function = addressOfOp.getFunction();
|
||||
|
||||
valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global);
|
||||
// The verifier should not have allowed this.
|
||||
assert((global || function) &&
|
||||
"referencing an undefined global or function");
|
||||
|
||||
valueMapping[addressOfOp.getResult()] =
|
||||
global ? globalsMapping.lookup(global)
|
||||
: functionMapping.lookup(function.getName());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
|||
// CHECK-NEXT: llvm.br ^bb1(%arg0 : !llvm<"void ()*">)
|
||||
br ^bb1(%arg0 : () -> ())
|
||||
|
||||
//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">): // pred: ^bb0
|
||||
//CHECK-NEXT: ^bb1(%0: !llvm<"void ()*">):
|
||||
^bb1(%bbarg: () -> ()):
|
||||
// CHECK-NEXT: llvm.return %0 : !llvm<"void ()*">
|
||||
return %bbarg : () -> ()
|
||||
|
@ -40,11 +40,12 @@ func @pass_through(%arg0: () -> ()) -> (() -> ()) {
|
|||
// CHECK-LABEL: llvm.func @body(!llvm.i32)
|
||||
func @body(i32)
|
||||
|
||||
// CHECK-LABEL: llvm.func @indirect_const_call(%arg0: !llvm.i32) {
|
||||
// CHECK-LABEL: llvm.func @indirect_const_call
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.i32) {
|
||||
func @indirect_const_call(%arg0: i32) {
|
||||
// CHECK-NEXT: %0 = llvm.mlir.constant(@body) : !llvm<"void (i32)*">
|
||||
// CHECK-NEXT: %[[ADDR:.*]] = llvm.mlir.addressof @body : !llvm<"void (i32)*">
|
||||
%0 = constant @body : (i32) -> ()
|
||||
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
|
||||
// CHECK-NEXT: llvm.call %[[ADDR]](%[[ARG0:.*]]) : (!llvm.i32) -> ()
|
||||
call_indirect %0(%arg0) : (i32) -> ()
|
||||
// CHECK-NEXT: llvm.return
|
||||
return
|
||||
|
|
|
@ -140,12 +140,21 @@ func @foo() {
|
|||
llvm.mlir.global internal @foo(0: i32) : !llvm.i32
|
||||
|
||||
func @bar() {
|
||||
// expected-error @+1 {{the type must be a pointer to the type of the referred global}}
|
||||
// expected-error @+1 {{the type must be a pointer to the type of the referenced global}}
|
||||
llvm.mlir.addressof @foo : !llvm<"i64*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @foo()
|
||||
|
||||
llvm.func @bar() {
|
||||
// expected-error @+1 {{the type must be a pointer to the type of the referenced function}}
|
||||
llvm.mlir.addressof @foo : !llvm<"i8*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+2 {{'llvm.mlir.global' op expects regions to end with 'llvm.return', found 'llvm.mlir.constant'}}
|
||||
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'llvm.return'}}
|
||||
llvm.mlir.global internal @g() : !llvm.i64 {
|
||||
|
@ -172,7 +181,7 @@ llvm.mlir.global internal @g(43 : i64) : !llvm.i64 {
|
|||
|
||||
llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
|
||||
func @mismatch_addr_space_implicit_global() {
|
||||
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
|
||||
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
|
||||
llvm.mlir.addressof @g : !llvm<"i64*">
|
||||
}
|
||||
|
||||
|
@ -180,6 +189,6 @@ func @mismatch_addr_space_implicit_global() {
|
|||
|
||||
llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : !llvm.i64
|
||||
func @mismatch_addr_space() {
|
||||
// expected-error @+1 {{op the type must be a pointer to the type of the referred global}}
|
||||
// expected-error @+1 {{op the type must be a pointer to the type of the referenced global}}
|
||||
llvm.mlir.addressof @g : !llvm<"i64 addrspace(4)*">
|
||||
}
|
||||
|
|
|
@ -153,6 +153,13 @@ func @call_non_llvm_input(%callee : (i32) -> (), %arg : i32) {
|
|||
|
||||
// -----
|
||||
|
||||
func @constant_wrong_type() {
|
||||
// expected-error@+1 {{only supports integer, float, string or elements attributes}}
|
||||
llvm.mlir.constant(@constant_wrong_type) : !llvm<"void ()*">
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
|
||||
// expected-error@+1 {{expected LLVM IR Dialect type}}
|
||||
llvm.insertvalue %a, %b[0] : i32
|
||||
|
|
|
@ -55,12 +55,12 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
|
|||
// CHECK: %[[STRUCT:.*]] = llvm.call @foo(%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0] : !llvm<"{ i32, double, i32 }">
|
||||
// CHECK: %[[NEW_STRUCT:.*]] = llvm.insertvalue %[[VALUE]], %[[STRUCT]][2] : !llvm<"{ i32, double, i32 }">
|
||||
// CHECK: %[[FUNC:.*]] = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
// CHECK: %[[FUNC:.*]] = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
// CHECK: %{{.*}} = llvm.call %[[FUNC]](%[[I32]]) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
|
||||
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
|
||||
%20 = llvm.mlir.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
%20 = llvm.mlir.addressof @foo : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
|
||||
|
||||
|
@ -130,8 +130,8 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
|
|||
}
|
||||
|
||||
// An larger self-contained function.
|
||||
// CHECK-LABEL: func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
|
||||
func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
|
||||
// CHECK-LABEL: llvm.func @foo(%{{.*}}: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
|
||||
llvm.func @foo(%arg0: !llvm.i32) -> !llvm<"{ i32, double, i32 }"> {
|
||||
// CHECK: %[[V0:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
|
||||
// CHECK: %[[V1:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i32
|
||||
// CHECK: %[[V2:.*]] = llvm.mlir.constant(4.200000e+01 : f64) : !llvm.double
|
||||
|
|
|
@ -234,7 +234,7 @@ define void @FPArithmetic(float %a, float %b, double %c, double %d) {
|
|||
; CHECK-LABEL: @precaller
|
||||
define i32 @precaller() {
|
||||
%1 = alloca i32 ()*
|
||||
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
|
||||
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
|
||||
; CHECK: llvm.store %[[func]], %[[loc:.*]]
|
||||
store i32 ()* @callee, i32 ()** %1
|
||||
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
|
||||
|
@ -252,7 +252,7 @@ define i32 @callee() {
|
|||
; CHECK-LABEL: @postcaller
|
||||
define i32 @postcaller() {
|
||||
%1 = alloca i32 ()*
|
||||
; CHECK: %[[func:.*]] = llvm.mlir.constant(@callee) : !llvm<"i32 ()*">
|
||||
; CHECK: %[[func:.*]] = llvm.mlir.addressof @callee : !llvm<"i32 ()*">
|
||||
; CHECK: llvm.store %[[func]], %[[loc:.*]]
|
||||
store i32 ()* @callee, i32 ()** %1
|
||||
; CHECK: %[[indir:.*]] = llvm.load %[[loc]]
|
||||
|
@ -317,4 +317,4 @@ define i32 @useFenceInst() {
|
|||
;CHECK: llvm.fence seq_cst
|
||||
fence syncscope("") seq_cst
|
||||
ret i32 0
|
||||
}
|
||||
}
|
||||
|
|
|
@ -886,7 +886,7 @@ llvm.func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3:
|
|||
// CHECK-LABEL: define void @indirect_const_call(i64 {{%.*}})
|
||||
llvm.func @indirect_const_call(%arg0: !llvm.i64) {
|
||||
// CHECK-NEXT: call void @body(i64 %0)
|
||||
%0 = llvm.mlir.constant(@body) : !llvm<"void (i64)*">
|
||||
%0 = llvm.mlir.addressof @body : !llvm<"void (i64)*">
|
||||
llvm.call %0(%arg0) : (!llvm.i64) -> ()
|
||||
// CHECK-NEXT: ret void
|
||||
llvm.return
|
||||
|
|
Loading…
Reference in a new issue