[mlir][NVVM] Add support for nvvm mma.sync ops

This patch adds MLIR NVVM support for the various NVPTX `mma.sync`
operations. There are a number of possible data type, shape,
and other attribute combinations supported by the operation, so a
custom assembly format is added and attributes are inferred where possible.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D122410
This commit is contained in:
Christopher Bate 2022-03-25 17:20:07 +00:00 committed by Thomas Raoux
parent 5bc9ee1b78
commit 3be7c28917
7 changed files with 1047 additions and 101 deletions

View file

@ -35,6 +35,8 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(NVVMOpsStructs.h.inc -gen-struct-attr-decls)
mlir_tablegen(NVVMOpsStructs.cpp.inc -gen-struct-attr-defs)
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)

View file

@ -21,6 +21,7 @@
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.h.inc"
namespace mlir {
namespace NVVM {

View file

@ -195,18 +195,6 @@ def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
let assemblyFormat = "$n attr-dict";
def NVVM_MmaOp :
Results<(outs LLVM_Type:$res)>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
string llvmBuilder = [{
$res = createIntrinsicCall(
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let hasVerifier = 1;
/// Helpers to instantiate different versions of wmma intrinsics.
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
/// combinations of the intrinsics.
@ -296,6 +284,7 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
// Creates list of valid combinations of fragments. This is a subset of what
// llvm supports and can be extended as needed.
class NVVM_MMA_OPS {
// "wmma" operations
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
[GEOM<16, 16, 8>],
["tf32"], [], ["f32"], []>.ret;
@ -324,6 +313,32 @@ class NVVM_MMA_OPS {
// Separate A/B/C fragments (loads) from D (stores).
list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
// "mma_sync" operations
list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
[GEOM<16,8,4>, GEOM<16,8,8>],
["tf32"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
[GEOM<16,8,16>, GEOM<16,8,8>],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
[GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
[GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
[GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
[GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
@ -405,6 +420,150 @@ class MMA_MMA_INTR<string opName> {
string id = !foldl("", f, acc, el, acc # "\n" # el);
/// Enum attribute for binary (b1) MMA operation type
def MMAB1OpNone : I32EnumAttrCase<"none", 0>;
def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
[MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
let assemblyFormat = "`<` $value `>`";
/// Enum attribute type for the overflow behavior of MMA integer operations
def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
[MMAIntOverflowSat, MMAIntOverflowWrap]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
let assemblyFormat = "`<` $value `>`";
/// Attribute to hold the MMA shape
def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
StructFieldAttr<"m", I32Attr>,
StructFieldAttr<"n", I32Attr>,
StructFieldAttr<"k", I32Attr>
]> {
let summary = "Attribute for MMA operation shape.";
// Returns true if this combination of layout/satf for MMA ops is supported;
// false otherwise.
// E.g.
// if NVVM_MMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
// MMA ops check both layouts.
string layout = layout_a # ":" # layout_b;
string a_type = frags[0].ptx_elt_type;
string b_type = frags[1].ptx_elt_type;
string c_type = frags[2].ptx_elt_type;
string d_type = frags[3].ptx_elt_type;
string geom = frags[0].geom;
// gcd is a shortcut used to identify instructions that depend on
// geom+frag_c+frag_d.
string gcd = geom # ":" # c_type # d_type;
bit ret = !cond(
// Limit satf to valid types
!and(!eq(satf, 1),
!ne(a_type, "s8"),
!ne(a_type, "u8"),
!ne(a_type, "s4"),
!ne(a_type, "u4")): false,
// m8n8k4 has no C=f32 D=f16 variant.
!eq(gcd, "m8n8k4:f32f16"): false,
// only m8n8k4 for f16 does not require row:col layout
!and(!ne(layout, "row:col"),
!or(!ne(geom, "m8n8k4"),
!ne(a_type, "f16"))) : false,
// m16n8k8 requires A and B to be the same type and C and D to be the same
// type.
!and(!eq(geom, "m16n8k8"),
!or(!ne(a_type, b_type),
!ne(c_type, d_type))): false,
// m16n8k8 requires C and D to be the same type.
!and(!eq(geom, "m16n8k8"),
!ne(c_type, d_type)): false,
// All other are OK.
true: true
// Returns the list of operation suffixes for the possible b1
// multiply-and-accumulate operations when the fragments have a
// b1 element type. For all other fragments, the returned list
// contains only the empty string.
class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
list<string> ret = !cond(
!eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
true: [""]
/// Generate enum value of the mma.sync intrinsic.
class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string id = "llvm::Intrinsic::nvvm_mma"
# !if(!ne(b1op, ""), "_" # b1op, "")
# "_" # A.geom
# "_" # ALayout
# "_" # BLayout
# !if(Satfinite, "_satfinite", "")
# signature;
/// Helper to create the mapping between the configuration and the mma.sync
/// intrinsic enum value.
list<list<list<list<list<string>>>>> cond0 =
!foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
!foreach(layoutA, ["row", "col"],
!foreach(layoutB, ["row", "col"],
!foreach (sat, [0, 1],
!foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
!if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
layoutA, layoutB, sat>.ret,
"if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
" m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
" && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
# op[1].ptx_elt_type # "\" == eltypeB && "
# " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
# " \"" # op[3].ptx_elt_type # "\" == eltypeD "
# " && (sat.hasValue() ? " # sat # " == static_cast<int>(*sat) : true)"
# !if(!ne(b1op, ""), " && (b1Op.hasValue() ? MMAB1Op::" # b1op # " == b1Op.getValue() : true)", "") # ")\n"
# " return " #
MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
"") // if supported
) // b1op
) // sat
) // layoutB
) // layoutA
); // all_mma_sync_ops
list<list<list<string>>> f1 = !foldl([[[""]]],
!foldl([[[[""]]]], cond0, acc, el,
!listconcat(acc, el)),
acc1, el1, !listconcat(acc1, el1));
list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
string id = !foldl("", f3, acc, el, acc # "\n" # el);
def MMALayoutRow : I32EnumAttrCase<"row", 0>;
def MMALayoutCol : I32EnumAttrCase<"col", 1>;
@ -418,13 +577,24 @@ def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> {
let assemblyFormat = "`<` $value `>`";
/// Enum attribute of the different PTX element types used for MMA operands.
def MMATypeF16 : I32EnumAttrCase<"f16", 0>;
def MMATypeF32 : I32EnumAttrCase<"f32", 1>;
def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
/// Enum attribute of the different matrix types.
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
[MMATypeF16, MMATypeF32, MMATypeTF32]> {
[MMATypeF16, MMATypeF32, MMATypeTF32,
MMATypeBF16, MMATypeS8, MMATypeU8,
MMATypeS32, MMATypeS4, MMATypeU4,
MMATypeB1, MMATypeF64]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
@ -678,4 +848,141 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
let hasVerifier = 1;
def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let summary = "cooperative matrix-multiply and accumulate";
let description = [{
The `nvvm.mma.sync` operation collectively performs the operation
`D = matmul(A, B) + C` using all threads in a warp.
All the threads in the warp must execute the same `mma.sync` operation.
For each possible multiplicand PTX data type, there are one or more possible
instruction shapes given as "mMnNkK". The below table describes the possibilities
as well as the types required for the operands. Note that the data type for
C (the accumulator) and D (the result) can vary independently when there are
multiple possibilities in the "C/D Type" column.
When an optional attribute cannot be immediately inferred from the types of
the operands and the result during parsing or validation, an error will be
raised.
`b1Op` is only relevant when the binary (b1) type is given to
`multiplicandDataType`. It specifies how the multiply-and-accumulate is
performed and is either `xor_popc` or `and_popc`. The default is `xor_popc`.
`intOverflowBehavior` is only relevant when the `multiplicandType` attribute
is one of `u8, s8, u4, s4`. This attribute describes how overflow is handled
in the accumulator. When the attribute is `satfinite`, the accumulator values
are clamped in the int32 range on overflow. This is the default behavior.
Alternatively, accumulator behavior `wrapped` can also be specified, in
which case overflow wraps from one end of the range to the other.
`layoutA` and `layoutB` are required and should generally be set to
`#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
combinations are possible for certain layouts according to the table below.
| A/B Type | Shape | ALayout | BLayout | A Type | B Type | C/D Type |
| f64 | .m8n8k4 | row | col | 1x f64 | 1x f64 | 2x f64 |
| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
| bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 |
| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 |
| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 |
| | .m16n8k16 | row | col | 2x i32 | 1x i32 | 4x i32 |
| | .m16n8k32 | row | col | 4x i32 | 2x i32 | 4x i32 |
| u4/s4 | .m8n8k32 | row | col | 1x i32 | 1x i32 | 2x i32 |
| | m16n8k32 | row | col | 2x i32 | 1x i32 | 4x i32 |
| | m16n8k64 | row | col | 4x i32 | 2x i32 | 4x i32 |
| b1 | m8n8k128 | row | col | 1x i32 | 1x i32 | 2x i32 |
| | m16n8k128 | row | col | 2x i32 | 1x i32 | 4x i32 |
%128 = nvvm.mma.sync A[%120, %121, %122, %123]
B[%124, %125]
C[%126, %127]
{layoutA = #nvvm.mma_layout<row>,
layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>)
-> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
let results = (outs LLVM_AnyStruct:$res);
let arguments = (ins NVVM_MMAShapeAttr:$shape,
let extraClassDeclaration = !strconcat([{
static llvm::Intrinsic::ID getIntrinsicID(
int64_t m, int64_t n, uint64_t k,
llvm::Optional<MMAB1Op> b1Op,
llvm::Optional<MMAIntOverflow> sat,
mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
MMA_SYNC_INTR<>.id, [{
return 0;
static Optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
bool isAccumulator);
MMATypes accumPtxType();
MMATypes resultPtxType();
let builders = [
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
"ValueRange":$operandB, "ValueRange":$operandC,
"ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
"Optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
"Optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
$shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
$b1Op, $intOverflowBehavior,
$layoutA, $layoutB,
$res = createIntrinsicCall(
builder, intId, operands);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
#endif // NVVMIR_OPS

View file

@ -34,6 +34,7 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.cpp.inc"
// Printing/parsing for NVVM ops
@ -69,47 +70,455 @@ LogicalResult CpAsyncOp::verify() {
return success();
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
bool isAccumulator) {
auto half2Type =
LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
return NVVM::MMATypes::f16;
if (operandElType.isF32())
return NVVM::MMATypes::f32;
if (operandElType.isa<IntegerType>()) {
if (isAccumulator)
return NVVM::MMATypes::s32;
return llvm::None;
if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
if (structType.getBody().empty())
return llvm::None;
return inferOperandMMAType(structType.getBody()[0], isAccumulator);
return llvm::None;
static bool isInt4PtxType(MMATypes type) {
return (type == MMATypes::u4 || type == MMATypes::s4);
static bool isInt8PtxType(MMATypes type) {
return (type == MMATypes::u8 || type == MMATypes::s8);
static bool isIntegerPtxType(MMATypes type) {
return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
type == MMATypes::s32;
MMATypes MmaOp::accumPtxType() {
Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
getODSOperands(2).getTypes().front(), /*isAccum=*/true);
assert(val.hasValue() && "accumulator PTX type should always be inferrable");
return val.getValue();
MMATypes MmaOp::resultPtxType() {
Optional<mlir::NVVM::MMATypes> val =
inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
assert(val.hasValue() && "result PTX type should always be inferrable");
return val.getValue();
void MmaOp::print(OpAsmPrinter &p) {
SmallVector<Type, 4> regTypes;
struct OperandFragment {
StringRef operandName;
StringRef ptxTypeAttr;
SmallVector<Value, 4> regs;
explicit OperandFragment(StringRef name, StringRef ptxTypeName)
: operandName(name), ptxTypeAttr(ptxTypeName) {}
std::array<OperandFragment, 3> frags{
OperandFragment("A", multiplicandAPtxTypeAttrName()),
OperandFragment("B", multiplicandBPtxTypeAttrName()),
OperandFragment("C", "")};
SmallVector<StringRef, 4> ignoreAttrNames{
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
auto &frag = frags[fragIdx];
auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
for (auto operandIdx = varOperandSpec.first;
operandIdx < varOperandSpec.first + varOperandSpec.second;
operandIdx++) {
if (operandIdx == 0) {
Optional<MMATypes> inferredType =
inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
if (inferredType)
auto printMmaOperand = [&](const OperandFragment &frag) -> void {
p << " " << frag.operandName;
p << "[";
p << "] ";
for (const auto &frag : frags) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
p << " : "
<< "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
p << ")";
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
ValueRange operandA, ValueRange operandB, ValueRange operandC,
ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
Optional<MMAIntOverflow> intOverflow,
Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
MLIRContext *ctx = builder.getContext();
Type i32 = builder.getIntegerType(32);
"shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
builder.getIntegerAttr(i32, shape[1]),
builder.getIntegerAttr(i32, shape[2]), ctx));
if (multiplicandPtxTypes.hasValue()) {
MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
} else {
if (auto res = inferOperandMMAType(operandA[0].getType(), false))
result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
if (auto res = inferOperandMMAType(operandB[0].getType(), false))
result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
if (multiplicandLayouts.hasValue()) {
MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
} else {
result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
if (intOverflow.hasValue())
MMAIntOverflowAttr::get(ctx, *intOverflow));
if (b1Op.hasValue())
result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
// <operation> :=
// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
// `->` type($res)
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
struct OperandFragment {
Optional<MMATypes> elemtype;
SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
SmallVector<Type> regTypes;
Builder &builder = parser.getBuilder();
std::array<OperandFragment, 4> frags;
NamedAttrList namedAttributes;
// A helper to parse the operand segments.
auto parseMmaOperand = [&](StringRef operandName,
OperandFragment &frag) -> LogicalResult {
if (parser.parseKeyword(operandName).failed())
return failure();
if (parser
.parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
return failure();
return success();
// Parse the operand segments.
if (parseMmaOperand("A", frags[0]).failed())
return failure();
if (parseMmaOperand("B", frags[1]).failed())
return failure();
if (parseMmaOperand("C", frags[2]).failed())
return failure();
if (parser.parseOptionalAttrDict(namedAttributes).failed())
return failure();
// Parse the type specification and resolve operands.
SmallVector<Type, 3> operandTypes;
if (failed(parser.parseColon()))
return failure();
if (failed(parser.parseLParen()))
return failure();
if (failed(parser.parseTypeList(operandTypes)))
return failure();
if (failed(parser.parseRParen()))
if (operandTypes.size() != 3)
return parser.emitError(
"expected one type for each operand segment but got " +
Twine(operandTypes.size()) + " types");
for (auto iter : llvm::enumerate(operandTypes)) {
auto &frag = frags[iter.index()];
frag.regTypes.resize(frag.regs.size(), iter.value());
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
parser.getNameLoc(), result.operands)))
return failure();
frag.elemtype =
inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
Type resultType;
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
std::array<StringRef, 2> names{"multiplicandAPtxType",
for (unsigned idx = 0; idx < names.size(); idx++) {
const auto &frag = frags[idx];
Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
if (!frag.elemtype.hasValue() && !attr.hasValue()) {
return parser.emitError(
"attribute " + names[idx] +
" is not provided explicitly and cannot be inferred");
if (!attr.hasValue())
names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
if (!namedAttributes.empty())
return success();
LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
auto operandTypes = getOperandTypes();
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty}) {
return emitOpError("expected operands to be 4 <halfx2>s followed by either "
"4 <halfx2>s or 8 floats");
if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
return emitOpError("expected result type to be a struct of either 4 "
"<halfx2>s or 8 floats");
auto s32x4StructTy =
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
auto f32x8StructTy =
LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
auto f16x2x2StructTy =
LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
auto f32x4StructTy =
LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
auto s32x2StructTy =
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
// These variables define the set of allowed data types for matrices A, B, C,
// and result.
using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
AllowedShapes allowedShapes;
AllowedTypes expectedA;
AllowedTypes expectedB;
AllowedTypes expectedC;
SmallVector<Type> expectedResult;
// When M = 16, we just need to calculate the number of 8xk tiles, where
// k is a factor that depends on the data type.
if (mmaShape[0] == 16) {
int64_t kFactor;
Type multiplicandFragType;
switch (multiplicandAPtxType().getValue()) {
case MMATypes::tf32:
kFactor = 4;
context, {i32Ty, i32Ty, i32Ty, i32Ty}));
case MMATypes::f16:
case MMATypes::bf16:
kFactor = 8;
multiplicandFragType = f16x2Ty;
case MMATypes::s4:
case MMATypes::u4:
kFactor = 32;
case MMATypes::b1:
kFactor = 128;
case MMATypes::s8:
case MMATypes::u8:
kFactor = 16;
return emitError("invalid shape or multiplicand type: " +
if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
expectedC.emplace_back(4, i32Ty);
multiplicandFragType = i32Ty;
} else {
expectedC.emplace_back(2, f16x2Ty);
expectedC.emplace_back(4, f32Ty);
int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
expectedA.emplace_back(unitA, multiplicandFragType);
expectedB.emplace_back(unitB, multiplicandFragType);
allowedShapes.push_back({16, 8, kFactor});
allowedShapes.push_back({16, 8, kFactor * 2});
auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
if (!(alayout && blayout) ||
!(alayout.getValue() == "row" || alayout.getValue() == "col") ||
!(blayout.getValue() == "row" || blayout.getValue() == "col")) {
return emitOpError("alayout and blayout attributes must be set to either "
"\"row\" or \"col\"");
// In the M=8 case, there is only 1 possible case per data type.
if (mmaShape[0] == 8) {
if (multiplicandAPtxType().getValue() == MMATypes::f16) {
expectedA.emplace_back(2, f16x2Ty);
expectedB.emplace_back(2, f16x2Ty);
expectedC.emplace_back(4, f16x2Ty);
expectedC.emplace_back(8, f32Ty);
allowedShapes.push_back({8, 8, 4});
if (multiplicandAPtxType().getValue() == MMATypes::f64) {
Type f64Ty = Float64Type::get(context);
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
// expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
expectedC.push_back({i32Ty, i32Ty});
if (isInt4PtxType(multiplicandAPtxType().getValue()))
allowedShapes.push_back({8, 8, 32});
if (isInt8PtxType(multiplicandAPtxType().getValue()))
allowedShapes.push_back({8, 8, 16});
if (multiplicandAPtxType().getValue() == MMATypes::b1)
allowedShapes.push_back({8, 8, 128});
if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty} &&
getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "col") {
return success();
std::string errorMessage;
llvm::raw_string_ostream errorStream(errorMessage);
// Check that we matched an existing shape/dtype combination.
if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
[&](const auto &allowed) { return allowed == mmaShape; })) {
errorStream << "unimplemented variant for MMA shape <";
llvm::interleaveComma(mmaShape, errorStream);
errorStream << ">";
return emitOpError(errorMessage);
return emitOpError("unimplemented mma.sync variant");
// Verify the operand types for segments of A, B, and C operands.
std::array<StringRef, 3> operandNames{"A", "B", "C"};
for (const auto &iter : llvm::enumerate(
SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
auto spec = this->getODSOperandIndexAndLength(iter.index());
SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
operand_type_begin() + spec.first +
bool match =
llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
return typeSet == operandTySeg;
if (!match) {
errorStream << "Could not match types for the "
<< operandNames[iter.index()]
<< " operands; expected one of ";
for (const auto &x : iter.value()) {
errorStream << x.size() << "x" << x[0] << " ";
errorStream << "but got ";
llvm::interleaveComma(operandTySeg, errorStream);
return emitOpError(errorStream.str());
// Check the result type
if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
return expectedResultType == getResult().getType();
})) {
<< "Could not match allowed types for the result; expected one of ";
llvm::interleaveComma(expectedResult, errorStream);
errorStream << " but got " << getResult().getType();
return emitOpError(errorStream.str());
// Ensure that binary MMA variants have a b1 MMA operation defined.
if (multiplicandAPtxType() == MMATypes::b1 && !b1Op().hasValue()) {
return emitOpError("op requires " + b1OpAttrName().strref() + " attribute");
// Ensure int4/int8 MMA variants specify the accum overflow behavior
// attribute.
if (isInt4PtxType(*multiplicandAPtxType()) ||
isInt8PtxType(*multiplicandAPtxType())) {
if (!intOverflowBehavior().hasValue())
return emitOpError("op requires " +
intOverflowBehaviorAttrName().strref() + " attribute");
return success();
LogicalResult ShflOp::verify() {

View file

@ -514,12 +514,13 @@ func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i3
// -----
func @nvvm_invalid_mma_0(%a0 : f16, %a1 : vector<2xf16>,
func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (f16, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@ -529,8 +530,9 @@ func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
// expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
@ -540,8 +542,9 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{alayout and blayout attributes must be set to either "row" or "col"}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// expected-error@+1 {{op requires attribute 'layoutA'}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
@ -549,55 +552,23 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
%c2 : vector<2xf16>, %c3 : vector<2xf16>) {
// expected-error@+1 {{unimplemented mma.sync variant}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
// expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// -----
func @nvvm_invalid_mma_4(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{unimplemented mma.sync variant}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// -----
func @nvvm_invalid_mma_5(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{unimplemented mma.sync variant}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// -----
func @nvvm_invalid_mma_6(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{invalid kind of type specified}}
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// -----
func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{op requires one result}}
%0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, i32)
llvm.return %0#0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// Negative test: a binary (b1 x b1) m16n8k128 mma.sync that deliberately omits
// the b1Op attribute. The diagnostic annotation below is line-relative (+1), so
// nothing may be inserted between it and the op.
func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// expected-error@+1 {{op requires b1Op attribute}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// -----

View file

@ -66,15 +66,164 @@ func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
llvm.return %0 : i32
func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// CHECK: nvvm.mma.sync
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// m8n8k4 mma.sync with two vector<2xf16> registers each for A and B and four
// vector<2xf16> accumulators; the result is a struct of four vector<2xf16>.
// The pattern below only pins the A[...] B[...] C[...] operand-group syntax,
// not the attributes or the types.
func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// Integer (s8 x s8) m8n8k16 mma.sync with one i32 register each for A and B and
// two i32 accumulators. The overflow behavior is written out explicitly so that
// the printed form pinned below (which includes intOverflowBehavior) does not
// depend on the attribute being filled in implicitly.
func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
%c0 : i32, %c1 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
llvm.return %0 : !llvm.struct<(i32, i32)>
// m16n8k8 mma.sync: two vector<2xf16> registers for A, one for B, two
// vector<2xf16> accumulators, result struct of two vector<2xf16>. The pattern
// pins the operand counts and the trailing type signature; the attribute
// dictionary is matched loosely.
func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
// CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// m16n8k16 mma.sync with f16 accumulators and an f16 result: four
// vector<2xf16> registers for A, two for B, two for C. The pattern pins the
// full printed form, including the attribute dictionary (keys appear in
// alphabetical order in the expected output).
func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// m16n8k16 mma.sync with an f32 result and f16 accumulators: C is two
// vector<2xf16> registers while the result is a struct of four f32.
func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
// m16n8k16 mma.sync with an f16 result and f32 accumulators: C is four f32
// values while the result is a struct of two vector<2xf16>.
func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// m16n8k16 mma.sync with f32 accumulators and an f32 result: A and B are
// vector<2xf16> registers, C is four f32 values, and the result is a struct of
// four f32.
func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
// Integer (s8 x s8) m16n8k16 mma.sync: two i32 registers for A, one for B and
// four i32 accumulators. The overflow behavior is written out explicitly so
// that the printed form pinned below (which includes intOverflowBehavior) does
// not depend on the attribute being filled in implicitly.
func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Mixed-signedness integer (s8 x u8) m16n8k16 mma.sync using saturating
// accumulation. The expected output pins intOverflowBehavior = satfinite, so
// the op must request it explicitly; sibling integer tests that omit the
// attribute are pinned to `wrapped` instead.
func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Binary (b1 x b1) m16n8k256 mma.sync with the xor_popc b1 operation: four i32
// registers for A, two for B and four i32 accumulators. The B group in the
// pattern lists both operands so that the expected operand count is actually
// pinned rather than absorbed by a single greedy wildcard.
func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
%b0 : i32, %b1 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Binary (b1 x b1) m16n8k128 mma.sync with the xor_popc b1 operation: two i32
// registers for A, one for B and four i32 accumulators. The pattern pins the
// full printed form, including b1Op.
func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
b1Op = #nvvm.mma_b1op<xor_popc>,
shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Binary (b1 x b1) m8n8k128 mma.sync with the xor_popc b1 operation: one i32
// register each for A and B and two i32 accumulators, producing a struct of two
// i32.
// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1
func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32)>
// Integer (s4 x s4) m16n8k32 mma.sync: two i32 registers for A, one for B and
// four i32 accumulators. The overflow behavior is written out explicitly so
// that the printed form pinned below (which includes intOverflowBehavior) does
// not depend on the attribute being filled in implicitly.
// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// CHECK-LABEL: @nvvm_wmma_load_tf32
func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type<tf32>, frag = #nvvm.mma_frag<a>, k = 8 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
%0 = nvvm.wmma.load %arg0, %arg1

View file

@ -88,17 +88,124 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
llvm.return %3 : i32
llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// f16 return type, f16 accumulate type.
// Expects the m16n8k16 row/col op below to become a call to the f16.f16
// intrinsic named in the pattern, returning a pair of <2 x half>.
llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
// CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16
%0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// f32 return type, f16 accumulate type
// Expects a call to the f32.f16 intrinsic named in the pattern, returning four
// floats while the C operands are vector<2xf16>.
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
// f16 return type, f32 accumulate type
// Expects a call to the f16.f32 intrinsic named in the pattern, returning a
// pair of <2 x half> while the C operands are four f32 values.
llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
// CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// f32 return type, f32 accumulate type
// Expects a call to the f32.f32 intrinsic named in the pattern, returning four
// floats with four f32 C operands.
llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
// Integer (s8 x s8) m16n8k16 mma.sync. The expected intrinsic name carries no
// `.satfinite` component, so the wrapping overflow behavior is requested
// explicitly instead of being left unspecified.
llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Mixed-signedness integer (s8 x u8) m16n8k16 mma.sync. The expected intrinsic
// is the `.satfinite` variant, so the op has to request saturating overflow
// behavior explicitly; without it nothing in the op selects that variant.
llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Binary (b1 x b1) m16n8k128 mma.sync with b1Op = xor_popc. Expects a call to
// the xor.popc intrinsic named in the pattern, returning four i32.
llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Integer (s4 x s4) m16n8k32 mma.sync. The expected intrinsic is the
// `.satfinite` variant, so the op has to request saturating overflow behavior
// explicitly; without it nothing in the op selects that variant.
llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
%b0 : i32,
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
// Double-precision m8n8k4 mma.sync: one f64 each for A and B and two f64
// accumulators. Expects a call to the f64 intrinsic named in the pattern,
// returning a pair of doubles.
llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
%b0 : f64,
%c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
// CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
llvm.return %0 : !llvm.struct<(f64, f64)>
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
// CHECK-LABEL: @gpu_wmma_load_op
llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
%0 = nvvm.wmma.load %arg0, %arg1