[mlir][NVVM] Add support for nvvm mma.sync ops
This patch adds MLIR NVVM support for the various NVPTX `mma.sync` operations. There are a number of possible data type, shape, and other attribute combinations supported by the operation, so a custom assembly format is added and attributes are inferred where possible. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D122410
This commit is contained in:
parent
5bc9ee1b78
commit
3be7c28917
|
@ -35,6 +35,8 @@ set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
|
|||
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
|
||||
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(NVVMOpsStructs.h.inc -gen-struct-attr-decls)
|
||||
mlir_tablegen(NVVMOpsStructs.cpp.inc -gen-struct-attr-defs)
|
||||
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
|
||||
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
|
||||
add_public_tablegen_target(MLIRNVVMConversionsIncGen)
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace NVVM {
|
||||
|
|
|
@ -195,18 +195,6 @@ def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
|
|||
let assemblyFormat = "$n attr-dict";
|
||||
}
|
||||
|
||||
def NVVM_MmaOp :
|
||||
NVVM_Op<"mma.sync">,
|
||||
Results<(outs LLVM_Type:$res)>,
|
||||
Arguments<(ins Variadic<LLVM_Type>:$args)> {
|
||||
string llvmBuilder = [{
|
||||
$res = createIntrinsicCall(
|
||||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
|
||||
}];
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
/// Helpers to instantiate different version of wmma intrinsics.
|
||||
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
|
||||
/// combinations of the intrinsics.
|
||||
|
@ -296,6 +284,7 @@ class MMA_LDST_OPS<list<GEOM> Geom, list<string> Frags, list<string> Types> {
|
|||
// Creates list of valid combinations of fragments. This is a subset of what
|
||||
// llvm supports and can be extended as needed.
|
||||
class NVVM_MMA_OPS {
|
||||
// "wmma" operations
|
||||
list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
|
||||
[GEOM<16, 16, 8>],
|
||||
["tf32"], [], ["f32"], []>.ret;
|
||||
|
@ -324,6 +313,32 @@ class NVVM_MMA_OPS {
|
|||
// Separate A/B/C fragments (loads) from D (stores).
|
||||
list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
|
||||
list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
|
||||
|
||||
// "mma_sync" operations
|
||||
list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
|
||||
[GEOM<16,8,4>, GEOM<16,8,8>],
|
||||
["tf32"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
|
||||
[GEOM<16,8,16>, GEOM<16,8,8>],
|
||||
["bf16"], [], ["f32"], []>.ret;
|
||||
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
|
||||
[GEOM<8,8,4>],
|
||||
["f64"], [], ["f64"], []>.ret;
|
||||
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
|
||||
[GEOM<8,8,4>, GEOM<16,8,8>, GEOM<16,8,16>],
|
||||
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
|
||||
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
|
||||
[GEOM<8,8,16>, GEOM<16,8,16>, GEOM<16,8,32>],
|
||||
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
|
||||
[GEOM<8,8,32>, GEOM<16,8,32>, GEOM<16,8,64>],
|
||||
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
|
||||
[GEOM<8,8,128>, GEOM<16,8,128>, GEOM<16,8,256>],
|
||||
["b1"], [], ["s32"], []>.ret;
|
||||
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
|
||||
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
|
||||
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
|
||||
}
|
||||
|
||||
def NVVM_MMA_OPS : NVVM_MMA_OPS;
|
||||
|
@ -405,6 +420,150 @@ class MMA_MMA_INTR<string opName> {
|
|||
string id = !foldl("", f, acc, el, acc # "\n" # el);
|
||||
}
|
||||
|
||||
/// Enum attribute for binary (b1) MMA operation type
|
||||
def MMAB1OpNone : I32EnumAttrCase<"none", 0>;
|
||||
def MMAB1OpXorPopc : I32EnumAttrCase<"xor_popc", 1>;
|
||||
def MMAB1OpAndPopc : I32EnumAttrCase<"and_popc", 2>;
|
||||
def MMAB1Op : I32EnumAttr<"MMAB1Op", "MMA binary operations",
|
||||
[MMAB1OpNone, MMAB1OpXorPopc, MMAB1OpAndPopc]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::NVVM";
|
||||
}
|
||||
def MMAB1OpAttr : EnumAttr<NVVM_Dialect, MMAB1Op, "mma_b1op"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
/// Enum attribute type for the overflow behavior of MMA integer operations
|
||||
def MMAIntOverflowWrap : I32EnumAttrCase<"wrapped", 0>;
|
||||
def MMAIntOverflowSat : I32EnumAttrCase<"satfinite", 1>;
|
||||
def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
|
||||
[MMAIntOverflowSat, MMAIntOverflowWrap]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::NVVM";
|
||||
}
|
||||
def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
/// Attribute to hold the MMA shape
|
||||
def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
|
||||
StructFieldAttr<"m", I32Attr>,
|
||||
StructFieldAttr<"n", I32Attr>,
|
||||
StructFieldAttr<"k", I32Attr>
|
||||
]> {
|
||||
let summary = "Attribute for MMA operation shape.";
|
||||
}
|
||||
|
||||
// Returns true if this combination of layout/satf for MMA ops is supported;
|
||||
// false otherwise.
|
||||
// E.g.
|
||||
// if NVVM_MMA_SUPPORTED<...>.ret then
|
||||
// def : FOO<>; // The record will only be defined for supported ops.
|
||||
//
|
||||
class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
|
||||
// MMA ops check both layouts.
|
||||
string layout = layout_a # ":" # layout_b;
|
||||
string a_type = frags[0].ptx_elt_type;
|
||||
string b_type = frags[1].ptx_elt_type;
|
||||
string c_type = frags[2].ptx_elt_type;
|
||||
string d_type = frags[3].ptx_elt_type;
|
||||
string geom = frags[0].geom;
|
||||
|
||||
// gcd is a shortcut used to identify instructions that depend on
|
||||
// geom+frag_c+frag_d.
|
||||
string gcd = geom # ":" # c_type # d_type;
|
||||
bit ret = !cond(
|
||||
|
||||
// Limit satf to valid types
|
||||
!and(!eq(satf, 1),
|
||||
!ne(a_type, "s8"),
|
||||
!ne(a_type, "u8"),
|
||||
!ne(a_type, "s4"),
|
||||
!ne(a_type, "u4")): false,
|
||||
|
||||
// m8n8k4 has no C=f32 D=f16 variant.
|
||||
!eq(gcd, "m8n8k4:f32f16"): false,
|
||||
|
||||
// only m8n8k4 for f16 does not require row:col layout
|
||||
!and(!ne(layout, "row:col"),
|
||||
!or(!ne(geom, "m8n8k4"),
|
||||
!ne(a_type, "f16"))) : false,
|
||||
|
||||
// m16n8k8 requires A and B to be the same type and C and D to be the same
|
||||
// type.
|
||||
!and(!eq(geom, "m16n8k8"),
|
||||
!or(!ne(a_type, b_type),
|
||||
!ne(c_type, d_type))): false,
|
||||
|
||||
// m16n8k8 requires C and D to be the same type.
|
||||
!and(!eq(geom, "m16n8k8"),
|
||||
!ne(c_type, d_type)): false,
|
||||
|
||||
// All other are OK.
|
||||
true: true
|
||||
);
|
||||
}
|
||||
|
||||
// Returns a list of operation suffixes corresponding to possible b1
|
||||
// multiply-and-accumulate operations for all fragments which have a
|
||||
// b1 type. For all other fragments, the list returned holds a list
|
||||
// containing the empty string.
|
||||
class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
|
||||
list<string> ret = !cond(
|
||||
!eq(frags[0].ptx_elt_type, "b1") : ["xor_popc", "and_popc"],
|
||||
true: [""]
|
||||
);
|
||||
}
|
||||
|
||||
/// Generate enum value of the mma.sync intrinsic.
|
||||
class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
|
||||
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
|
||||
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
|
||||
string id = "llvm::Intrinsic::nvvm_mma"
|
||||
# !if(!ne(b1op, ""), "_" # b1op, "")
|
||||
# "_" # A.geom
|
||||
# "_" # ALayout
|
||||
# "_" # BLayout
|
||||
# !if(Satfinite, "_satfinite", "")
|
||||
# signature;
|
||||
}
|
||||
|
||||
/// Helper to create the mapping between the configuration and the mma.sync
|
||||
/// intrinsic enum value.
|
||||
class MMA_SYNC_INTR {
|
||||
list<list<list<list<list<string>>>>> cond0 =
|
||||
!foreach(op, NVVM_MMA_OPS.all_mma_sync_ops,
|
||||
!foreach(layoutA, ["row", "col"],
|
||||
!foreach(layoutB, ["row", "col"],
|
||||
!foreach (sat, [0, 1],
|
||||
!foreach (b1op, NVVM_MMA_B1OPS<op>.ret,
|
||||
!if(NVVM_MMA_SUPPORTED<[op[0], op[1], op[2], op[3]],
|
||||
layoutA, layoutB, sat>.ret,
|
||||
"if (layoutA == \"" # layoutA # "\" && layoutB == \"" # layoutB # "\" && "
|
||||
" m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
|
||||
" && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
|
||||
# op[1].ptx_elt_type # "\" == eltypeB && "
|
||||
# " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
|
||||
# " \"" # op[3].ptx_elt_type # "\" == eltypeD "
|
||||
# " && (sat.hasValue() ? " # sat # " == static_cast<int>(*sat) : true)"
|
||||
# !if(!ne(b1op, ""), " && (b1Op.hasValue() ? MMAB1Op::" # b1op # " == b1Op.getValue() : true)", "") # ")\n"
|
||||
# " return " #
|
||||
MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
|
||||
"") // if supported
|
||||
) // b1op
|
||||
) // sat
|
||||
) // layoutB
|
||||
) // layoutA
|
||||
); // all_mma_sync_ops
|
||||
list<list<list<string>>> f1 = !foldl([[[""]]],
|
||||
!foldl([[[[""]]]], cond0, acc, el,
|
||||
!listconcat(acc, el)),
|
||||
acc1, el1, !listconcat(acc1, el1));
|
||||
list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
|
||||
list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
|
||||
string id = !foldl("", f3, acc, el, acc # "\n" # el);
|
||||
}
|
||||
|
||||
def MMALayoutRow : I32EnumAttrCase<"row", 0>;
|
||||
def MMALayoutCol : I32EnumAttrCase<"col", 1>;
|
||||
|
||||
|
@ -418,13 +577,24 @@ def MMALayoutAttr : EnumAttr<NVVM_Dialect, MMALayout, "mma_layout"> {
|
|||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
/// Enum attribute of the different PTX element types used for MMA operands.
|
||||
def MMATypeF16 : I32EnumAttrCase<"f16", 0>;
|
||||
def MMATypeF32 : I32EnumAttrCase<"f32", 1>;
|
||||
def MMATypeTF32 : I32EnumAttrCase<"tf32", 2>;
|
||||
def MMATypeU8 : I32EnumAttrCase<"u8", 3>;
|
||||
def MMATypeS8 : I32EnumAttrCase<"s8", 4>;
|
||||
def MMATypeS32 : I32EnumAttrCase<"s32", 5>;
|
||||
def MMATypeB1 : I32EnumAttrCase<"b1", 6>;
|
||||
def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
|
||||
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
|
||||
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
|
||||
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
|
||||
|
||||
/// Enum attribute of the different matrix types.
|
||||
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
|
||||
[MMATypeF16, MMATypeF32, MMATypeTF32]> {
|
||||
[MMATypeF16, MMATypeF32, MMATypeTF32,
|
||||
MMATypeBF16, MMATypeS8, MMATypeU8,
|
||||
MMATypeS32, MMATypeS4, MMATypeU4,
|
||||
MMATypeB1, MMATypeF64]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::NVVM";
|
||||
}
|
||||
|
@ -678,4 +848,141 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
|
||||
|
||||
let summary = "cooperative matrix-multiply and accumulate";
|
||||
|
||||
let description = [{
|
||||
The `nvvm.mma.sync` operation collectively performs the operation
|
||||
`D = matmul(A, B) + C` using all threads in a warp.
|
||||
|
||||
All the threads in the warp must execute the same `mma.sync` operation.
|
||||
|
||||
For each possible multiplicand PTX data type, there are one or more possible
|
||||
instruction shapes given as "mMnNkK". The below table describes the possibilities
|
||||
as well as the types required for the operands. Note that the data type for
|
||||
C (the accumulator) and D (the result) can vary independently when there are
|
||||
multiple possibilities in the "C/D Type" column.
|
||||
|
||||
When an optional attribute cannot be immediately inferred from the types of
|
||||
the operands and the result during parsing or validation, an error will be
|
||||
raised.
|
||||
|
||||
`b1Op` is only relevant when the binary (b1) type is given to
|
||||
`multiplicandDataType`. It specifies how the multiply-and-accumulate is
|
||||
performed and is either `xor_popc` or `and_popc`. The default is `xor_popc`.
|
||||
|
||||
`intOverflowBehavior` is only relevant when the `multiplicandType` attribute
|
||||
is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled
|
||||
in the accumulator. When the attribute is `satfinite`, the accumulator values
|
||||
are clamped in the int32 range on overflow. This is the default behavior.
|
||||
Alternatively, accumulator behavior `wrapped` can also be specified, in
|
||||
which case overflow wraps from one end of the range to the other.
|
||||
|
||||
`layoutA` and `layoutB` are required and should generally be set to
|
||||
`#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
|
||||
combinations are possible for certain layouts according to the table below.
|
||||
|
||||
```
|
||||
| A/B Type | Shape | ALayout | BLayout | A Type | B Type | C/D Type |
|
||||
|----------|-----------|---------|---------|----------|----------|-------------------|
|
||||
| f64 | .m8n8k4 | row | col | 1x f64 | 1x f64 | 2x f64 |
|
||||
| f16 | .m8n8k4 | row/col | row/col | 2x f16x2 | 2x f16x2 | 4x f16x2 or 8xf32 |
|
||||
| | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
|
||||
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
|
||||
| bf16 | .m16n8k8 | row | col | 2x f16x2 | 1x f16x2 | 2x f16x2 or 4 f32 |
|
||||
| | .m16n8k16 | row | col | 4x f16x2 | 2x f16x2 | 2x f16x2 or 4 f32 |
|
||||
| tf32 | .m16n8k4 | row | col | 2x i32 | 1x i32 | 4x f32 |
|
||||
| | .m16n8k8 | row | col | 4x i32 | 2x i32 | 2x f16x2 or 4 f32 |
|
||||
| u8/s8 | .m8n8k16 | row | col | 1x i32 | 1x i32 | 2x i32 |
|
||||
| | .m16n8k16 | row | col | 2x i32 | 1x i32 | 4x i32 |
|
||||
| | .m16n8k32 | row | col | 4x i32 | 2x i32 | 4x i32 |
|
||||
| u4/s4 | .m8n8k32 | row | col | 1x i32 | 1x i32 | 2x i32 |
|
||||
| | m16n8k32 | row | col | 2x i32 | 1x i32 | 4x i32 |
|
||||
| | m16n8k64 | row | col | 4x i32 | 2x i32 | 4x i32 |
|
||||
| b1 | m8n8k128 | row | col | 1x i32 | 1x i32 | 2x i32 |
|
||||
| | m16n8k128 | row | col | 2x i32 | 1x i32 | 4x i32 |
|
||||
```
|
||||
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
|
||||
%128 = nvvm.mma.sync A[%120, %121, %122, %123]
|
||||
B[%124, %125]
|
||||
C[%126, %127]
|
||||
{layoutA = #nvvm.mma_layout<row>,
|
||||
layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
|
||||
: (vector<2xf16>, vector<2xf16>, vector<2xf16>)
|
||||
-> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
```
|
||||
}];
|
||||
|
||||
let results = (outs LLVM_AnyStruct:$res);
|
||||
let arguments = (ins NVVM_MMAShapeAttr:$shape,
|
||||
OptionalAttr<MMAB1OpAttr>:$b1Op,
|
||||
OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
|
||||
MMALayoutAttr:$layoutA,
|
||||
MMALayoutAttr:$layoutB,
|
||||
OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
|
||||
OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
|
||||
Variadic<LLVM_Type>:$operandA,
|
||||
Variadic<LLVM_Type>:$operandB,
|
||||
Variadic<LLVM_Type>:$operandC);
|
||||
|
||||
let extraClassDeclaration = !strconcat([{
|
||||
static llvm::Intrinsic::ID getIntrinsicID(
|
||||
int64_t m, int64_t n, uint64_t k,
|
||||
llvm::Optional<MMAB1Op> b1Op,
|
||||
llvm::Optional<MMAIntOverflow> sat,
|
||||
mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
|
||||
mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
|
||||
mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
|
||||
llvm::StringRef layoutA = stringifyEnum(layoutAEnum);
|
||||
llvm::StringRef layoutB = stringifyEnum(layoutBEnum);
|
||||
llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
|
||||
llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
|
||||
llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
|
||||
llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
|
||||
}],
|
||||
MMA_SYNC_INTR<>.id, [{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
|
||||
bool isAccumulator);
|
||||
|
||||
MMATypes accumPtxType();
|
||||
MMATypes resultPtxType();
|
||||
}]);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
|
||||
"ValueRange":$operandB, "ValueRange":$operandC,
|
||||
"ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
|
||||
"Optional<MMAIntOverflow>":$intOverflow,
|
||||
"Optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
|
||||
"Optional<std::array<MMALayout, 2>>":$multiplicandLayouts)>
|
||||
];
|
||||
|
||||
string llvmBuilder = [{
|
||||
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
|
||||
auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
|
||||
$shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
|
||||
$b1Op, $intOverflowBehavior,
|
||||
$layoutA, $layoutB,
|
||||
$multiplicandAPtxType.getValue(),
|
||||
$multiplicandBPtxType.getValue(),
|
||||
op.accumPtxType(),
|
||||
op.resultPtxType());
|
||||
|
||||
$res = createIntrinsicCall(
|
||||
builder, intId, operands);
|
||||
}];
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // NVVMIR_OPS
|
||||
|
|
|
@ -34,6 +34,7 @@ using namespace NVVM;
|
|||
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsStructs.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for NVVM ops
|
||||
|
@ -69,47 +70,455 @@ LogicalResult CpAsyncOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Given the element type of an operand and whether or not it is an accumulator,
|
||||
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
|
||||
// operand's element type.
|
||||
Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
|
||||
bool isAccumulator) {
|
||||
auto half2Type =
|
||||
LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
|
||||
if (operandElType.isF64())
|
||||
return NVVM::MMATypes::f64;
|
||||
if (operandElType.isF16() || operandElType == half2Type)
|
||||
return NVVM::MMATypes::f16;
|
||||
if (operandElType.isF32())
|
||||
return NVVM::MMATypes::f32;
|
||||
if (operandElType.isa<IntegerType>()) {
|
||||
if (isAccumulator)
|
||||
return NVVM::MMATypes::s32;
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
|
||||
if (structType.getBody().empty())
|
||||
return llvm::None;
|
||||
return inferOperandMMAType(structType.getBody()[0], isAccumulator);
|
||||
}
|
||||
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
static bool isInt4PtxType(MMATypes type) {
|
||||
return (type == MMATypes::u4 || type == MMATypes::s4);
|
||||
}
|
||||
|
||||
static bool isInt8PtxType(MMATypes type) {
|
||||
return (type == MMATypes::u8 || type == MMATypes::s8);
|
||||
}
|
||||
|
||||
static bool isIntegerPtxType(MMATypes type) {
|
||||
return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
|
||||
type == MMATypes::s32;
|
||||
}
|
||||
|
||||
MMATypes MmaOp::accumPtxType() {
|
||||
Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
|
||||
getODSOperands(2).getTypes().front(), /*isAccum=*/true);
|
||||
assert(val.hasValue() && "accumulator PTX type should always be inferrable");
|
||||
return val.getValue();
|
||||
}
|
||||
|
||||
MMATypes MmaOp::resultPtxType() {
|
||||
Optional<mlir::NVVM::MMATypes> val =
|
||||
inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
|
||||
assert(val.hasValue() && "result PTX type should always be inferrable");
|
||||
return val.getValue();
|
||||
}
|
||||
|
||||
void MmaOp::print(OpAsmPrinter &p) {
|
||||
SmallVector<Type, 4> regTypes;
|
||||
struct OperandFragment {
|
||||
StringRef operandName;
|
||||
StringRef ptxTypeAttr;
|
||||
SmallVector<Value, 4> regs;
|
||||
explicit OperandFragment(StringRef name, StringRef ptxTypeName)
|
||||
: operandName(name), ptxTypeAttr(ptxTypeName) {}
|
||||
};
|
||||
|
||||
std::array<OperandFragment, 3> frags{
|
||||
OperandFragment("A", multiplicandAPtxTypeAttrName()),
|
||||
OperandFragment("B", multiplicandBPtxTypeAttrName()),
|
||||
OperandFragment("C", "")};
|
||||
SmallVector<StringRef, 4> ignoreAttrNames{
|
||||
mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};
|
||||
|
||||
for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
|
||||
auto &frag = frags[fragIdx];
|
||||
auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
|
||||
for (auto operandIdx = varOperandSpec.first;
|
||||
operandIdx < varOperandSpec.first + varOperandSpec.second;
|
||||
operandIdx++) {
|
||||
frag.regs.push_back(this->getOperand(operandIdx));
|
||||
if (operandIdx == 0) {
|
||||
regTypes.push_back(this->getOperand(operandIdx).getType());
|
||||
}
|
||||
}
|
||||
Optional<MMATypes> inferredType =
|
||||
inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
|
||||
if (inferredType)
|
||||
ignoreAttrNames.push_back(frag.ptxTypeAttr);
|
||||
}
|
||||
|
||||
auto printMmaOperand = [&](const OperandFragment &frag) -> void {
|
||||
p << " " << frag.operandName;
|
||||
p << "[";
|
||||
p.printOperands(frag.regs);
|
||||
p << "] ";
|
||||
};
|
||||
|
||||
for (const auto &frag : frags) {
|
||||
printMmaOperand(frag);
|
||||
}
|
||||
|
||||
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
|
||||
|
||||
// Print the types of the operands and result.
|
||||
p << " : "
|
||||
<< "(";
|
||||
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
|
||||
frags[1].regs[0].getType(),
|
||||
frags[2].regs[0].getType()},
|
||||
p);
|
||||
p << ")";
|
||||
p.printArrowTypeList(TypeRange{this->res().getType()});
|
||||
}
|
||||
|
||||
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
|
||||
ValueRange operandA, ValueRange operandB, ValueRange operandC,
|
||||
ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
|
||||
Optional<MMAIntOverflow> intOverflow,
|
||||
Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
|
||||
Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
|
||||
|
||||
assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
|
||||
MLIRContext *ctx = builder.getContext();
|
||||
Type i32 = builder.getIntegerType(32);
|
||||
result.addAttribute(
|
||||
"shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
|
||||
builder.getIntegerAttr(i32, shape[1]),
|
||||
builder.getIntegerAttr(i32, shape[2]), ctx));
|
||||
|
||||
result.addOperands(operandA);
|
||||
result.addOperands(operandB);
|
||||
result.addOperands(operandC);
|
||||
|
||||
if (multiplicandPtxTypes.hasValue()) {
|
||||
result.addAttribute("multiplicandAPtxType",
|
||||
MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
|
||||
result.addAttribute("multiplicandBPtxType",
|
||||
MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
|
||||
} else {
|
||||
if (auto res = inferOperandMMAType(operandA[0].getType(), false))
|
||||
result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
|
||||
if (auto res = inferOperandMMAType(operandB[0].getType(), false))
|
||||
result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
|
||||
}
|
||||
|
||||
if (multiplicandLayouts.hasValue()) {
|
||||
result.addAttribute("layoutA",
|
||||
MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
|
||||
result.addAttribute("layoutB",
|
||||
MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
|
||||
} else {
|
||||
result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
|
||||
result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
|
||||
}
|
||||
|
||||
if (intOverflow.hasValue())
|
||||
result.addAttribute("intOverflowBehavior",
|
||||
MMAIntOverflowAttr::get(ctx, *intOverflow));
|
||||
if (b1Op.hasValue())
|
||||
result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
|
||||
|
||||
result.addTypes(resultType);
|
||||
result.addAttribute(
|
||||
MmaOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
|
||||
static_cast<int32_t>(operandB.size()),
|
||||
static_cast<int32_t>(operandC.size())}));
|
||||
}
|
||||
|
||||
// <operation> :=
|
||||
// A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
|
||||
// attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
|
||||
// `->` type($res)
|
||||
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
struct OperandFragment {
|
||||
Optional<MMATypes> elemtype;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
|
||||
SmallVector<Type> regTypes;
|
||||
};
|
||||
|
||||
Builder &builder = parser.getBuilder();
|
||||
std::array<OperandFragment, 4> frags;
|
||||
|
||||
NamedAttrList namedAttributes;
|
||||
|
||||
// A helper to parse the operand segments.
|
||||
auto parseMmaOperand = [&](StringRef operandName,
|
||||
OperandFragment &frag) -> LogicalResult {
|
||||
if (parser.parseKeyword(operandName).failed())
|
||||
return failure();
|
||||
if (parser
|
||||
.parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
|
||||
.failed())
|
||||
return failure();
|
||||
return success();
|
||||
};
|
||||
|
||||
// Parse the operand segments.
|
||||
if (parseMmaOperand("A", frags[0]).failed())
|
||||
return failure();
|
||||
if (parseMmaOperand("B", frags[1]).failed())
|
||||
return failure();
|
||||
if (parseMmaOperand("C", frags[2]).failed())
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(namedAttributes).failed())
|
||||
return failure();
|
||||
|
||||
// Parse the type specification and resolve operands.
|
||||
SmallVector<Type, 3> operandTypes;
|
||||
if (failed(parser.parseColon()))
|
||||
return failure();
|
||||
if (failed(parser.parseLParen()))
|
||||
return failure();
|
||||
if (failed(parser.parseTypeList(operandTypes)))
|
||||
return failure();
|
||||
if (failed(parser.parseRParen()))
|
||||
if (operandTypes.size() != 3)
|
||||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"expected one type for each operand segment but got " +
|
||||
Twine(operandTypes.size()) + " types");
|
||||
for (auto iter : llvm::enumerate(operandTypes)) {
|
||||
auto &frag = frags[iter.index()];
|
||||
frag.regTypes.resize(frag.regs.size(), iter.value());
|
||||
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
|
||||
parser.getNameLoc(), result.operands)))
|
||||
return failure();
|
||||
frag.elemtype =
|
||||
inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
|
||||
}
|
||||
|
||||
Type resultType;
|
||||
parser.parseArrow();
|
||||
parser.parseType(resultType);
|
||||
frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
|
||||
|
||||
std::array<StringRef, 2> names{"multiplicandAPtxType",
|
||||
"multiplicandBPtxType"};
|
||||
for (unsigned idx = 0; idx < names.size(); idx++) {
|
||||
const auto &frag = frags[idx];
|
||||
Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
|
||||
if (!frag.elemtype.hasValue() && !attr.hasValue()) {
|
||||
return parser.emitError(
|
||||
parser.getNameLoc(),
|
||||
"attribute " + names[idx] +
|
||||
" is not provided explicitly and cannot be inferred");
|
||||
}
|
||||
if (!attr.hasValue())
|
||||
result.addAttribute(
|
||||
names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
|
||||
}
|
||||
|
||||
result.addTypes(resultType);
|
||||
if (!namedAttributes.empty())
|
||||
result.addAttributes(namedAttributes);
|
||||
result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
|
||||
builder.getI32VectorAttr({
|
||||
static_cast<int32_t>(frags[0].regs.size()),
|
||||
static_cast<int32_t>(frags[1].regs.size()),
|
||||
static_cast<int32_t>(frags[2].regs.size()),
|
||||
}));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult MmaOp::verify() {
|
||||
MLIRContext *context = getContext();
|
||||
auto f16Ty = Float16Type::get(context);
|
||||
auto i32Ty = IntegerType::get(context, 32);
|
||||
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
|
||||
auto f32Ty = Float32Type::get(context);
|
||||
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
|
||||
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
|
||||
auto operandTypes = getOperandTypes();
|
||||
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
|
||||
operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty}) {
|
||||
return emitOpError("expected operands to be 4 <halfx2>s followed by either "
|
||||
"4 <halfx2>s or 8 floats");
|
||||
}
|
||||
if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
|
||||
return emitOpError("expected result type to be a struct of either 4 "
|
||||
"<halfx2>s or 8 floats");
|
||||
auto s32x4StructTy =
|
||||
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
|
||||
auto f32x8StructTy =
|
||||
LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
|
||||
auto f16x2x2StructTy =
|
||||
LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
|
||||
auto f32x4StructTy =
|
||||
LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
auto s32x2StructTy =
|
||||
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});
|
||||
|
||||
std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
|
||||
shapeAttr().n().getInt(),
|
||||
shapeAttr().k().getInt()};
|
||||
|
||||
// These variables define the set of allowed data types for matrices A, B, C,
|
||||
// and result.
|
||||
using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
|
||||
using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
|
||||
AllowedShapes allowedShapes;
|
||||
AllowedTypes expectedA;
|
||||
AllowedTypes expectedB;
|
||||
AllowedTypes expectedC;
|
||||
SmallVector<Type> expectedResult;
|
||||
|
||||
// When M = 16, we just need to calculate the number of 8xk tiles, where
|
||||
// k is a factor that depends on the data type.
|
||||
if (mmaShape[0] == 16) {
|
||||
int64_t kFactor;
|
||||
Type multiplicandFragType;
|
||||
switch (multiplicandAPtxType().getValue()) {
|
||||
case MMATypes::tf32:
|
||||
kFactor = 4;
|
||||
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
|
||||
context, {i32Ty, i32Ty, i32Ty, i32Ty}));
|
||||
break;
|
||||
case MMATypes::f16:
|
||||
case MMATypes::bf16:
|
||||
kFactor = 8;
|
||||
multiplicandFragType = f16x2Ty;
|
||||
expectedResult.push_back(f16x2x2StructTy);
|
||||
expectedResult.push_back(f32x4StructTy);
|
||||
break;
|
||||
case MMATypes::s4:
|
||||
case MMATypes::u4:
|
||||
kFactor = 32;
|
||||
break;
|
||||
case MMATypes::b1:
|
||||
kFactor = 128;
|
||||
break;
|
||||
case MMATypes::s8:
|
||||
case MMATypes::u8:
|
||||
kFactor = 16;
|
||||
break;
|
||||
default:
|
||||
return emitError("invalid shape or multiplicand type: " +
|
||||
stringifyEnum(multiplicandAPtxType().getValue()));
|
||||
}
|
||||
|
||||
if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
|
||||
expectedResult.push_back(s32x4StructTy);
|
||||
expectedC.emplace_back(4, i32Ty);
|
||||
multiplicandFragType = i32Ty;
|
||||
} else {
|
||||
expectedC.emplace_back(2, f16x2Ty);
|
||||
expectedC.emplace_back(4, f32Ty);
|
||||
}
|
||||
|
||||
int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
|
||||
int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
|
||||
expectedA.emplace_back(unitA, multiplicandFragType);
|
||||
expectedB.emplace_back(unitB, multiplicandFragType);
|
||||
allowedShapes.push_back({16, 8, kFactor});
|
||||
allowedShapes.push_back({16, 8, kFactor * 2});
|
||||
}
|
||||
|
||||
auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
|
||||
auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
|
||||
|
||||
if (!(alayout && blayout) ||
|
||||
!(alayout.getValue() == "row" || alayout.getValue() == "col") ||
|
||||
!(blayout.getValue() == "row" || blayout.getValue() == "col")) {
|
||||
return emitOpError("alayout and blayout attributes must be set to either "
|
||||
"\"row\" or \"col\"");
|
||||
// In the M=8 case, there is only 1 possible case per data type.
|
||||
if (mmaShape[0] == 8) {
|
||||
if (multiplicandAPtxType().getValue() == MMATypes::f16) {
|
||||
expectedA.emplace_back(2, f16x2Ty);
|
||||
expectedB.emplace_back(2, f16x2Ty);
|
||||
expectedResult.push_back(f16x2x4StructTy);
|
||||
expectedResult.push_back(f32x8StructTy);
|
||||
expectedC.emplace_back(4, f16x2Ty);
|
||||
expectedC.emplace_back(8, f32Ty);
|
||||
allowedShapes.push_back({8, 8, 4});
|
||||
}
|
||||
if (multiplicandAPtxType().getValue() == MMATypes::f64) {
|
||||
Type f64Ty = Float64Type::get(context);
|
||||
expectedA.emplace_back(1, f64Ty);
|
||||
expectedB.emplace_back(1, f64Ty);
|
||||
expectedC.emplace_back(2, f64Ty);
|
||||
// expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
|
||||
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
|
||||
context, SmallVector<Type>(2, f64Ty)));
|
||||
allowedShapes.push_back({8, 8, 4});
|
||||
}
|
||||
if (isIntegerPtxType(multiplicandAPtxType().getValue())) {
|
||||
expectedA.push_back({i32Ty});
|
||||
expectedB.push_back({i32Ty});
|
||||
expectedC.push_back({i32Ty, i32Ty});
|
||||
expectedResult.push_back(s32x2StructTy);
|
||||
if (isInt4PtxType(multiplicandAPtxType().getValue()))
|
||||
allowedShapes.push_back({8, 8, 32});
|
||||
if (isInt8PtxType(multiplicandAPtxType().getValue()))
|
||||
allowedShapes.push_back({8, 8, 16});
|
||||
if (multiplicandAPtxType().getValue() == MMATypes::b1)
|
||||
allowedShapes.push_back({8, 8, 128});
|
||||
}
|
||||
}
|
||||
|
||||
if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty} &&
|
||||
getType() == f32x8StructTy && alayout.getValue() == "row" &&
|
||||
blayout.getValue() == "col") {
|
||||
return success();
|
||||
std::string errorMessage;
|
||||
llvm::raw_string_ostream errorStream(errorMessage);
|
||||
|
||||
// Check that we matched an existing shape/dtype combination.
|
||||
if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
|
||||
!llvm::any_of(allowedShapes,
|
||||
[&](const auto &allowed) { return allowed == mmaShape; })) {
|
||||
errorStream << "unimplemented variant for MMA shape <";
|
||||
llvm::interleaveComma(mmaShape, errorStream);
|
||||
errorStream << ">";
|
||||
return emitOpError(errorMessage);
|
||||
}
|
||||
return emitOpError("unimplemented mma.sync variant");
|
||||
|
||||
// Verify the operand types for segments of A, B, and C operands.
|
||||
std::array<StringRef, 3> operandNames{"A", "B", "C"};
|
||||
for (const auto &iter : llvm::enumerate(
|
||||
SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
|
||||
auto spec = this->getODSOperandIndexAndLength(iter.index());
|
||||
SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
|
||||
operand_type_begin() + spec.first +
|
||||
spec.second);
|
||||
bool match =
|
||||
llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
|
||||
return typeSet == operandTySeg;
|
||||
});
|
||||
|
||||
if (!match) {
|
||||
errorStream << "Could not match types for the "
|
||||
<< operandNames[iter.index()]
|
||||
<< " operands; expected one of ";
|
||||
for (const auto &x : iter.value()) {
|
||||
errorStream << x.size() << "x" << x[0] << " ";
|
||||
}
|
||||
errorStream << "but got ";
|
||||
llvm::interleaveComma(operandTySeg, errorStream);
|
||||
return emitOpError(errorStream.str());
|
||||
}
|
||||
}
|
||||
|
||||
// Check the result type
|
||||
if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
|
||||
return expectedResultType == getResult().getType();
|
||||
})) {
|
||||
errorStream
|
||||
<< "Could not match allowed types for the result; expected one of ";
|
||||
llvm::interleaveComma(expectedResult, errorStream);
|
||||
errorStream << " but got " << getResult().getType();
|
||||
return emitOpError(errorStream.str());
|
||||
}
|
||||
|
||||
// Ensure that binary MMA variants have a b1 MMA operation defined.
|
||||
if (multiplicandAPtxType() == MMATypes::b1 && !b1Op().hasValue()) {
|
||||
return emitOpError("op requires " + b1OpAttrName().strref() + " attribute");
|
||||
}
|
||||
|
||||
// Ensure int4/int8 MMA variants specify the accum overflow behavior
|
||||
// attribute.
|
||||
if (isInt4PtxType(*multiplicandAPtxType()) ||
|
||||
isInt8PtxType(*multiplicandAPtxType())) {
|
||||
if (!intOverflowBehavior().hasValue())
|
||||
return emitOpError("op requires " +
|
||||
intOverflowBehaviorAttrName().strref() + " attribute");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ShflOp::verify() {
|
||||
|
|
|
@ -514,12 +514,13 @@ func @nvvm_invalid_shfl_pred_3(%arg0 : i32, %arg1 : i32, %arg2 : i32, %arg3 : i3
|
|||
|
||||
// -----
|
||||
|
||||
func @nvvm_invalid_mma_0(%a0 : f16, %a1 : vector<2xf16>,
|
||||
func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{expected operands to be 4 <halfx2>s followed by either 4 <halfx2>s or 8 floats}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (f16, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
|
@ -529,8 +530,9 @@ func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
|||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{expected result type to be a struct of either 4 <halfx2>s or 8 floats}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
|
||||
// expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
|
||||
}
|
||||
|
||||
|
@ -540,8 +542,9 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
|||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{alayout and blayout attributes must be set to either "row" or "col"}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
// expected-error@+1 {{op requires attribute 'layoutA'}}
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||
{shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
|
@ -549,55 +552,23 @@ func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
|||
|
||||
func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>,
|
||||
%c2 : vector<2xf16>, %c3 : vector<2xf16>) {
|
||||
// expected-error@+1 {{unimplemented mma.sync variant}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
|
||||
// expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @nvvm_invalid_mma_4(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{unimplemented mma.sync variant}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @nvvm_invalid_mma_5(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{unimplemented mma.sync variant}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @nvvm_invalid_mma_6(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{invalid kind of type specified}}
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// expected-error@+1 {{op requires one result}}
|
||||
%0:2 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="col", blayout="row"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> (!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, i32)
|
||||
llvm.return %0#0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// expected-error@+1 {{op requires b1Op attribute}}
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
||||
shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -66,15 +66,164 @@ func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
|
|||
llvm.return %0 : i32
|
||||
}
|
||||
|
||||
func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
|
||||
func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// CHECK: nvvm.mma.sync {{.*}} {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout = "row", blayout = "col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, %c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
|
||||
// CHECK: nvvm.mma.sync
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}]
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
|
||||
%c0 : i32, %c1 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
|
||||
intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
|
||||
shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32, i32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
|
||||
// CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>,vector<2xf16>,vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
|
||||
intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
|
||||
intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
|
||||
shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
|
||||
%b0 : i32, %b1 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
||||
b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 256 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
||||
b1Op = #nvvm.mma_b1op<xor_popc>,
|
||||
shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nvvm_mma_m8n8k128_b1_b1
|
||||
func @nvvm_mma_m8n8k128_b1_b1(%a0 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : i32, %c1 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {b1Op = #nvvm.mma_b1op<xor_popc>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
|
||||
b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32)>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
|
||||
func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
|
||||
intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
|
||||
shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
|
||||
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nvvm_wmma_load_tf32
|
||||
func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
|
||||
// CHECK: nvvm.wmma.load {{.*}} {eltype = #nvvm.mma_type<tf32>, frag = #nvvm.mma_frag<a>, k = 8 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32}
|
||||
%0 = nvvm.wmma.load %arg0, %arg1
|
||||
|
|
|
@ -88,17 +88,124 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
|
|||
llvm.return %3 : i32
|
||||
}
|
||||
|
||||
llvm.func @nvvm_mma(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
// CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
|
||||
llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
|
||||
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> {
|
||||
// CHECK: call { float, float, float, float, float, float, float, float } @llvm.nvvm.mma.m8n8k4.row.col.f32.f32
|
||||
%0 = nvvm.mma.sync %a0, %a1, %b0, %b1, %c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7 {alayout="row", blayout="col"} : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
|
||||
// CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f16
|
||||
%0 = nvvm.mma.sync A[ %a0, %a1, %a2, %a3 ] B[ %b0, %b1 ] C[ %c0, %c1 ]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}}
|
||||
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
// f32 return type, f16 accumulate type
|
||||
llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
|
||||
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// f16 return type, f32 accumulate type
|
||||
llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
|
||||
// CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
|
||||
}
|
||||
|
||||
// f32 return type, f32 accumulate type
|
||||
llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
%a2 : vector<2xf16>, %a3 : vector<2xf16>,
|
||||
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
|
||||
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f32
|
||||
%0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// m16n8k16 mma.sync with s8 multiplicands; the PTX multiplicand types and
// the wrapped overflow behavior are carried as explicit attributes.
llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
                                   %b0 : i32,
                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>,
     intOverflowBehavior=#nvvm.mma_int_overflow<wrapped>,
     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}
// Mixed s8/u8 multiplicands with satfinite overflow behavior; the intrinsic
// name gains the ".satfinite" infix and the ".u8" B-type suffix.
llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                   %b0 : i32,
                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>,
     intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
     shape = {m = 16 : i32, n = 8 : i32, k = 16 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}
// Single-bit (b1) multiplicands; the b1Op attribute selects the xor.popc
// variant of the intrinsic.
llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                    %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
     b1Op = #nvvm.mma_b1op<xor_popc>, shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}
// Sub-byte s4 multiplicands with satfinite overflow behavior.
llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
                                   %b0 : i32,
                                   %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k32.row.col.satfinite.s4
  %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>,
     intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>,
     shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
  llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}
// Double-precision m8n8k4 mma.sync; element types are inferred from the
// operand types, so no PTX type attributes are needed.
llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
                                   %b0 : f64,
                                   %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
  // CHECK: call { double, double } @llvm.nvvm.mma.m8n8k4.row.col.f64
  %0 = nvvm.mma.sync A[%a0] B[%b0] C[%c0, %c1]
    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
     shape = {m = 8 : i32, n = 8 : i32, k = 4 : i32}} : (f64, f64, f64) -> !llvm.struct<(f64, f64)>
  llvm.return %0 : !llvm.struct<(f64, f64)>
}
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
// in the LLVM NVPTX backend.
// CHECK-LABEL: @gpu_wmma_load_op
llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32) {
|
||||
// CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
|
||||
%0 = nvvm.wmma.load %arg0, %arg1
|
||||
|
|