[mlir][spirv] Add support for converting memref of vector to SPIR-V

This allow declaring buffers and alloc of vectors so that we can support vector
load/store.

Differential Revision: https://reviews.llvm.org/D84982
This commit is contained in:
Thomas Raoux 2020-07-30 14:56:50 -07:00
parent c23ae3f18e
commit 59156bad03
4 changed files with 101 additions and 16 deletions

View file

@ -217,11 +217,15 @@ CHECK_UNSIGNED_OP(spirv::UModOp)
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
// Currently only support workgroup local memory allocations with static
// shape and int or float element type.
return t.hasStaticShape() &&
SPIRVTypeConverter::getMemorySpaceForStorageClass(
spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
t.getElementType().isIntOrFloat();
// shape and int or float or vector of int or float element type.
if (!(t.hasStaticShape() &&
SPIRVTypeConverter::getMemorySpaceForStorageClass(
spirv::StorageClass::Workgroup) == t.getMemorySpace()))
return false;
Type elementType = t.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>())
elementType = vecType.getElementType();
return elementType.isIntOrFloat();
}
/// Returns the scope to use for atomic operations use for emulating store

View file

@ -170,7 +170,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}
return bitWidth / 8;
} else if (auto memRefType = t.dyn_cast<MemRefType>()) {
}
if (auto vecType = t.dyn_cast<VectorType>()) {
auto elementSize = getTypeNumBytes(vecType.getElementType());
if (!elementSize)
return llvm::None;
return vecType.getNumElements() * *elementSize;
}
if (auto memRefType = t.dyn_cast<MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
@ -343,26 +350,31 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return llvm::None;
}
auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot convert non-scalar element type\n");
Optional<Type> arrayElemType;
Type elementType = type.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>()) {
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
} else {
LLVM_DEBUG(
llvm::dbgs()
<< type
<< " unhandled: can only convert scalar or vector element type\n");
return llvm::None;
}
auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
if (!arrayElemType)
return llvm::None;
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
if (!scalarSize) {
Optional<int64_t> elementSize = getTypeNumBytes(elementType);
if (!elementSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
}
if (!type.hasStaticShape()) {
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
@ -375,7 +387,7 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return llvm::None;
}
auto arrayElemCount = *memrefSize / *scalarSize;
auto arrayElemCount = *memrefSize / *elementSize;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
if (!arrayElemSize) {

View file

@ -75,6 +75,30 @@ module attributes {
// CHECK: spv.func @two_allocs()
// CHECK: spv.Return
// -----
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
}
{
func @two_allocs_vector() {
%0 = alloc() : memref<4xvector<4xf32>, 3>
%1 = alloc() : memref<2xvector<2xi32>, 3>
return
}
}
// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
// CHECK: spv.func @two_allocs_vector()
// CHECK: spv.Return
// -----
module attributes {

View file

@ -510,6 +510,51 @@ func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
// -----
// Vector types
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [], []>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
// CHECK-LABEL: func @memref_vector
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
func @memref_vector(
%arg0: memref<4xvector<2xf32>, 0>,
%arg1: memref<4xvector<4xf32>, 4>)
{ return }
// CHECK-LABEL: func @dynamic_dim_memref_vector
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
%arg1: memref<?x?xvector<2xf32>>)
{ return }
} // end module
// -----
// Vector types, check that sizes not available in SPIR-V are not transformed.
module attributes {
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [], []>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
} {
// CHECK-LABEL: func @memref_vector_wrong_size
// CHECK-SAME: memref<4xvector<5xf32>>
func @memref_vector_wrong_size(
%arg0: memref<4xvector<5xf32>, 0>)
{ return }
} // end module
// -----
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//