[mlir][GPUToVulkan] Fix signature of bindMemRef function for f16

Binding MemRefs of f16 needs special handling as the type is not supported on
CPU. There was a bug in the type used.

Differential Revision: https://reviews.llvm.org/D86328
This commit is contained in:
Thomas Raoux 2020-08-21 10:34:12 -07:00
parent 08249d7f72
commit 36ee9a322a
2 changed files with 3 additions and 1 deletions

View file

@ -328,7 +328,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
std::string(stringifyType(type));
if (type.isHalfTy())
type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext()));
type = LLVM::LLVMType::getInt16Ty(&getContext());
if (!module.lookupSymbol(fnName)) {
auto fnType = LLVM::LLVMType::getFunctionTy(
getVoidType(),

View file

@ -15,6 +15,8 @@
// CHECK: llvm.call @runOnVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.call @deinitVulkan(%[[Vulkan_Runtime_ptr]]) : (!llvm.ptr<i8>) -> !llvm.void
// CHECK: llvm.func @bindMemRef1DHalf(!llvm.ptr<i8>, !llvm.i32, !llvm.i32, !llvm.ptr<struct<(ptr<i16>, ptr<i16>, i64, array<1 x i64>, array<1 x i64>)>>)
module attributes {gpu.container_module} {
llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
llvm.func @foo() {