[MLIR][Linalg] Retire C++ DotOp in favor of a linalg-ods-gen'd op
- replace DotOp, now that DRR rules have been dropped. - Capture arguments mismatch in the parser. The number of parsed arguments must equal the number of expected arguments. Reviewed By: ftynse, nicolasvasilache Differential Revision: https://reviews.llvm.org/D82952
This commit is contained in:
parent
7e8d5a90f2
commit
946be75b9e
|
@ -8,6 +8,11 @@ def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
|
|||
x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
|
||||
}
|
||||
|
||||
ods_def<DotOp>:
|
||||
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
|
||||
C() = std_addf<m>(std_mulf(A(m), B(m)));
|
||||
}
|
||||
|
||||
ods_def<BatchMatmulOp>:
|
||||
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
|
||||
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
|
||||
|
|
|
@ -51,9 +51,9 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
|
|||
/// 1. linalg.fill(%A, %f) : memref<f32>, f32
|
||||
/// name mangles into `linalg_fill_viewf32_f32_impl`
|
||||
///
|
||||
/// 2. linalg.dot(%A, %B, %C) :
|
||||
/// memref<?xf32, stride_specification>,
|
||||
/// memref<?xf32, stride_specification>, memref<f32>
|
||||
/// 2. linalg.dot %A, %B, %C :
|
||||
/// (memref<?xf32, stride_specification>,
|
||||
/// memref<?xf32, stride_specification>, memref<f32>)
|
||||
/// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
|
||||
///
|
||||
/// 3. linalg.matmul(...) :
|
||||
|
|
|
@ -180,31 +180,6 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
|
||||
|
||||
let arguments = (ins AnyStridedMemRefOfRank<1>,
|
||||
AnyStridedMemRefOfRank<1>,
|
||||
AnyStridedMemRefOfRank<0>);
|
||||
|
||||
let extraClassDeclaration = libraryCallName # [{
|
||||
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
|
||||
return SmallVector<StringRef, 8>{getReductionIteratorTypeName()};
|
||||
}
|
||||
|
||||
// A(r_i) * B(r_i) -> C()
|
||||
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
|
||||
MLIRContext *context = getContext();
|
||||
auto r_i = getAffineDimExpr(0, context);
|
||||
return SmallVector<AffineMap, 8>{
|
||||
AffineMap::get(1, 0, {r_i}, context),
|
||||
AffineMap::get(1, 0, {r_i}, context),
|
||||
AffineMap::get(1, 0, {}, context)};
|
||||
}
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
/// A base class for pooling operation such as conv. The arguments must contain
|
||||
/// optional arguments `strides`, `dilations` and `padding` with following type:
|
||||
/// OptionalAttr<I64ArrayAttr>:$strides
|
||||
|
|
|
@ -235,13 +235,13 @@ void mlir::populateLinalgToStandardConversionPatterns(
|
|||
LinalgOpConversion<PoolingMaxOp>,
|
||||
LinalgOpConversion<PoolingMinOp>,
|
||||
LinalgOpConversion<PoolingSumOp>,
|
||||
LinalgOpConversion<CopyOp>,
|
||||
LinalgOpConversion<DotOp>,
|
||||
LinalgOpConversion<CopyOp>,
|
||||
LinalgOpConversion<FillOp>,
|
||||
LinalgOpConversion<GenericOp>,
|
||||
LinalgOpConversion<IndexedGenericOp>>(ctx);
|
||||
// TODO: collect all auto-generated named ops with a tblgen directive.
|
||||
patterns.insert<
|
||||
LinalgOpConversion<DotOp>,
|
||||
LinalgOpConversion<BatchMatmulOp>,
|
||||
LinalgOpConversion<MatvecOp>,
|
||||
LinalgOpConversion<MatmulOp>>(ctx);
|
||||
|
|
|
@ -1173,10 +1173,6 @@ LogicalResult CopyOp::fold(ArrayRef<Attribute>,
|
|||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
LogicalResult DotOp::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
LogicalResult FillOp::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
|
@ -1280,6 +1276,17 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|||
if (!tensorResultTypes.empty())
|
||||
result.addTypes(tensorResultTypes);
|
||||
|
||||
// The number of parsed arguments must equal
|
||||
// the number of expected arguments for the current operation.
|
||||
auto parsedArgs = operandsInfo.size();
|
||||
auto expectedArgs = NamedStructuredOpType::getNumInputs() +
|
||||
NamedStructuredOpType::getNumOutputs();
|
||||
if (parsedArgs != expectedArgs)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expects " + std::to_string(expectedArgs) +
|
||||
" operands, but found " +
|
||||
std::to_string(parsedArgs));
|
||||
|
||||
buildNamedStructuredOpRegionAndAttributes<NamedStructuredOpType>(
|
||||
parser.getBuilder(), result, operandTypes, tensorResultTypes);
|
||||
|
||||
|
@ -1299,6 +1306,10 @@ LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
|
|||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
LogicalResult DotOp::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
}
|
||||
LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
|
||||
SmallVectorImpl<OpFoldResult> &) {
|
||||
return foldMemRefCast(*this);
|
||||
|
|
|
@ -295,18 +295,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
|
|||
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
|
||||
}
|
||||
|
||||
template <typename IndexedValueType>
|
||||
void emitScalarImplementation(ArrayRef<Value> allIvs, DotOp dotOp) {
|
||||
assert(dotOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(allIvs.size() == 1);
|
||||
Value r_i(allIvs[0]);
|
||||
IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)),
|
||||
C(dotOp.getOutputBuffer(0));
|
||||
// Emit scalar form.
|
||||
C() = C() + A(r_i) * B(r_i);
|
||||
}
|
||||
|
||||
template <typename IndexedValueType>
|
||||
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
|
||||
MutableArrayRef<Value> imIdx) {
|
||||
|
@ -673,8 +661,6 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
|
|||
return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
|
||||
if (isa<FillOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
|
||||
if (isa<DotOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
|
||||
if (isa<ConvOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
|
||||
if (isa<PoolingMaxOp>(op))
|
||||
|
@ -693,6 +679,8 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
|
|||
return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
|
||||
if (isa<MatvecOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
|
||||
if (isa<DotOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
|
||||
if (isa<BatchMatmulOp>(op))
|
||||
return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
|
||||
llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
|
||||
|
|
|
@ -422,8 +422,8 @@ func @generic(%arg0: memref<?x?xi4>) {
|
|||
// -----
|
||||
|
||||
func @generic_result_0_element_type(%arg0: memref<?xf32>) {
|
||||
// expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}}
|
||||
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
|
||||
// expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}}
|
||||
linalg.dot %arg0, %arg0 : (memref<?xf32>, memref<?xf32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -123,7 +123,7 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
|
|||
%1 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
|
||||
%2 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32>
|
||||
%3 = view %arg0[%c0][] : memref<?xi8> to memref<f32>
|
||||
linalg.dot(%1, %2, %3) : memref<?xf32>, memref<?xf32>, memref<f32>
|
||||
linalg.dot %1, %2, %3 : (memref<?xf32>, memref<?xf32>, memref<f32>)
|
||||
return
|
||||
}
|
||||
// CHECKLOOP-LABEL: func @dot(%{{.*}}: memref<?xi8>,
|
||||
|
@ -154,7 +154,9 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
|
|||
|
||||
|
||||
func @dot_view(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
|
||||
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
|
||||
linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>)
|
||||
return
|
||||
}
|
||||
// CHECKLOOP-LABEL: func @dot_view(
|
||||
|
|
|
@ -88,10 +88,10 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
memref<?x?xf32, offset: ?, strides: [?, 1]>)
|
||||
linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>)
|
||||
linalg.dot(%arg1, %arg2, %arg3) : memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
memref<?xf32, offset: ?, strides: [1]>)
|
||||
linalg.dot %arg1, %arg2, %arg3 : (memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @ops(%
|
||||
|
@ -103,10 +103,10 @@ func @ops(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
|
|||
// CHECK-SAME: (memref<?x?xf32, #[[$strided2D]]>,
|
||||
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
|
||||
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>)
|
||||
// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) :
|
||||
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
|
||||
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
|
||||
// CHECK-SAME: memref<f32>
|
||||
// CHECK-NEXT: linalg.dot %{{.*}}, %{{.*}}, %{{.*}} :
|
||||
// CHECK-SAME: (memref<?xf32, #[[$strided1D]]>,
|
||||
// CHECK-SAME: memref<?xf32, #[[$strided1D]]>,
|
||||
// CHECK-SAME: memref<f32>)
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -13,9 +13,9 @@
|
|||
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%arg1: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%arg2: memref<f32>) {
|
||||
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot(
|
||||
|
|
|
@ -271,7 +271,9 @@ func @matvec(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?
|
|||
// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref<?x?xf32, #[[$strided2D]]>, memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>)
|
||||
|
||||
func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, offset: ?, strides: [1]>, %arg2: memref<f32>) {
|
||||
linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
|
||||
linalg.dot %arg0, %arg1, %arg2 : (memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>)
|
||||
return
|
||||
}
|
||||
// TILE-2-LABEL: func @dot(
|
||||
|
@ -285,7 +287,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
|
|||
// TILE-2: %[[localM:.*]] = dim %{{.*}}, %c0
|
||||
// TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]]
|
||||
// TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
|
||||
// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>
|
||||
// TILE-2: linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
|
||||
|
||||
// TILE-02-LABEL: func @dot(
|
||||
// TILE-02-NOT: scf.for
|
||||
|
@ -304,7 +306,7 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
|
|||
// TILE-234: %[[localM:.*]] = dim %{{.*}}, %c0
|
||||
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
|
||||
// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref<?xf32, #[[$strided1D]]> to memref<?xf32, #[[$strided1D]]>
|
||||
// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>
|
||||
// TILE-234: linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref<?xf32, #[[$strided1D]]>, memref<?xf32, #[[$strided1D]]>, memref<f32>)
|
||||
|
||||
func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
|
||||
linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32
|
||||
|
|
|
@ -36,7 +36,7 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
|
|||
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
|
||||
// VECTOR-CONTRACTION: vector.contract
|
||||
// VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
|
||||
linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref<f32>
|
||||
linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref<f32>)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -14,10 +14,10 @@
|
|||
func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%y: memref<?xf32, offset: ?, strides: [1]>,
|
||||
%v: memref<f32>) {
|
||||
linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } :
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>
|
||||
linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } :
|
||||
(memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<?xf32, offset: ?, strides: [1]>,
|
||||
memref<f32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @dot
|
||||
|
@ -28,8 +28,8 @@ func @dot(%x: memref<?xf32, offset: ?, strides: [1]>,
|
|||
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] {
|
||||
// CHECK: load
|
||||
// CHECK: load
|
||||
// CHECK: mulf
|
||||
// CHECK: load
|
||||
// CHECK: mulf
|
||||
// CHECK: addf
|
||||
// CHECK: store
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ func @dot() -> f32 {
|
|||
%B = view %bB[%c0][%c16] : memref<?xi8> to memref<?xf32>
|
||||
%C = view %bC[%c0][] : memref<?xi8> to memref<f32>
|
||||
|
||||
linalg.dot(%A, %B, %C) : memref<?xf32>, memref<?xf32>, memref<f32>
|
||||
linalg.dot %A, %B, %C : (memref<?xf32>, memref<?xf32>, memref<f32>)
|
||||
%res = load %C[] : memref<f32>
|
||||
|
||||
dealloc %bC : memref<?xi8>
|
||||
|
|
Loading…
Reference in a new issue