[mlir][nvvm] Fix support for tf32 data type in mma.sync

The NVVM dialect test coverage for all possible type/shape combinations
in the `nvvm.mma.sync` op is mostly complete. However, there were tests
missing for TF32 datatype support. This change adds tests for the one
relevant shape/type combination. This uncovered a small bug in the op
verifier, which this change also fixes.

Differential Revision: https://reviews.llvm.org/D124975
This commit is contained in:
Christopher Bate 2022-05-04 14:12:07 -06:00
parent 6385c039b8
commit 22c6e7b277
3 changed files with 28 additions and 3 deletions

View file

@ -81,8 +81,10 @@ Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
return NVVM::MMATypes::f16;
if (operandElType.isF32())
if (operandElType.isF32() && isAccumulator)
return NVVM::MMATypes::f32;
if (operandElType.isF32() && !isAccumulator)
return NVVM::MMATypes::tf32;
if (operandElType.isa<IntegerType>()) {
if (isAccumulator)
return NVVM::MMATypes::s32;
@ -291,7 +293,7 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getNameLoc(),
"expected one type for each operand segment but got " +
Twine(operandTypes.size()) + " types");
for (const auto& iter : llvm::enumerate(operandTypes)) {
for (const auto &iter : llvm::enumerate(operandTypes)) {
auto &frag = frags[iter.index()];
frag.regTypes.resize(frag.regs.size(), iter.value());
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
@ -376,8 +378,9 @@ LogicalResult MmaOp::verify() {
switch (multiplicandAPtxType().getValue()) {
case MMATypes::tf32:
kFactor = 4;
multiplicandFragType = i32Ty;
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
context, {i32Ty, i32Ty, i32Ty, i32Ty}));
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
break;
case MMATypes::f16:
case MMATypes::bf16:

View file

@ -152,6 +152,17 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>

View file

@ -203,6 +203,17 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
llvm.return %0 : !llvm.struct<(f64, f64)>
}
llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
// CHECK-LABEL: @gpu_wmma_load_op