[mlir][spirv] Add support for converting memref of vector to SPIR-V
This allow declaring buffers and alloc of vectors so that we can support vector load/store. Differential Revision: https://reviews.llvm.org/D84982
This commit is contained in:
parent
c23ae3f18e
commit
59156bad03
|
@ -217,11 +217,15 @@ CHECK_UNSIGNED_OP(spirv::UModOp)
|
|||
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
|
||||
static bool isAllocationSupported(MemRefType t) {
|
||||
// Currently only support workgroup local memory allocations with static
|
||||
// shape and int or float element type.
|
||||
return t.hasStaticShape() &&
|
||||
SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
||||
spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
|
||||
t.getElementType().isIntOrFloat();
|
||||
// shape and int or float or vector of int or float element type.
|
||||
if (!(t.hasStaticShape() &&
|
||||
SPIRVTypeConverter::getMemorySpaceForStorageClass(
|
||||
spirv::StorageClass::Workgroup) == t.getMemorySpace()))
|
||||
return false;
|
||||
Type elementType = t.getElementType();
|
||||
if (auto vecType = elementType.dyn_cast<VectorType>())
|
||||
elementType = vecType.getElementType();
|
||||
return elementType.isIntOrFloat();
|
||||
}
|
||||
|
||||
/// Returns the scope to use for atomic operations use for emulating store
|
||||
|
|
|
@ -170,7 +170,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
|
|||
return llvm::None;
|
||||
}
|
||||
return bitWidth / 8;
|
||||
} else if (auto memRefType = t.dyn_cast<MemRefType>()) {
|
||||
}
|
||||
if (auto vecType = t.dyn_cast<VectorType>()) {
|
||||
auto elementSize = getTypeNumBytes(vecType.getElementType());
|
||||
if (!elementSize)
|
||||
return llvm::None;
|
||||
return vecType.getNumElements() * *elementSize;
|
||||
}
|
||||
if (auto memRefType = t.dyn_cast<MemRefType>()) {
|
||||
// TODO: Layout should also be controlled by the ABI attributes. For now
|
||||
// using the layout from MemRef.
|
||||
int64_t offset;
|
||||
|
@ -343,26 +350,31 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
|
||||
if (!scalarType) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot convert non-scalar element type\n");
|
||||
Optional<Type> arrayElemType;
|
||||
Type elementType = type.getElementType();
|
||||
if (auto vecType = elementType.dyn_cast<VectorType>()) {
|
||||
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
|
||||
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
|
||||
arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
|
||||
} else {
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs()
|
||||
<< type
|
||||
<< " unhandled: can only convert scalar or vector element type\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
|
||||
if (!arrayElemType)
|
||||
return llvm::None;
|
||||
|
||||
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
|
||||
if (!scalarSize) {
|
||||
Optional<int64_t> elementSize = getTypeNumBytes(elementType);
|
||||
if (!elementSize) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< type << " illegal: cannot deduce element size\n");
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (!type.hasStaticShape()) {
|
||||
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
|
||||
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
|
||||
// Wrap in a struct to satisfy Vulkan interface requirements.
|
||||
auto structType = spirv::StructType::get(arrayType, 0);
|
||||
return spirv::PointerType::get(structType, *storageClass);
|
||||
|
@ -375,7 +387,7 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
auto arrayElemCount = *memrefSize / *scalarSize;
|
||||
auto arrayElemCount = *memrefSize / *elementSize;
|
||||
|
||||
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
|
||||
if (!arrayElemSize) {
|
||||
|
|
|
@ -75,6 +75,30 @@ module attributes {
|
|||
// CHECK: spv.func @two_allocs()
|
||||
// CHECK: spv.Return
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
}
|
||||
{
|
||||
func @two_allocs_vector() {
|
||||
%0 = alloc() : memref<4xvector<4xf32>, 3>
|
||||
%1 = alloc() : memref<2xvector<2xi32>, 3>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
|
||||
// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
|
||||
// CHECK: spv.func @two_allocs_vector()
|
||||
// CHECK: spv.Return
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {
|
||||
|
|
|
@ -510,6 +510,51 @@ func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// Vector types
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [], []>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @memref_vector
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
|
||||
func @memref_vector(
|
||||
%arg0: memref<4xvector<2xf32>, 0>,
|
||||
%arg1: memref<4xvector<4xf32>, 4>)
|
||||
{ return }
|
||||
|
||||
// CHECK-LABEL: func @dynamic_dim_memref_vector
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
|
||||
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
|
||||
func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
|
||||
%arg1: memref<?x?xvector<2xf32>>)
|
||||
{ return }
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
// Vector types, check that sizes not available in SPIR-V are not transformed.
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [], []>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: func @memref_vector_wrong_size
|
||||
// CHECK-SAME: memref<4xvector<5xf32>>
|
||||
func @memref_vector_wrong_size(
|
||||
%arg0: memref<4xvector<5xf32>, 0>)
|
||||
{ return }
|
||||
|
||||
} // end module
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in a new issue