[mlir][Linalg] Pattern to fuse pad operation with elementwise operations.
Most convolution operations need explicit padding of the input to ensure all accesses are inbounds. In such cases, having a pad operation can be a significant overhead. One way to reduce that overhead is to try to fuse the pad operation with the producer of its source. A sequence ``` linalg.generic -> linalg.pad_tensor ``` can be replaced with ``` linalg.fill -> tensor.extract_slice -> linalg.generic -> tensor.insert_slice. ``` if the `linalg.generic` has all parallel iterator types. Differential Revision: https://reviews.llvm.org/D116418
This commit is contained in:
parent
4372e629a9
commit
e7cb716ef9
|
@ -87,6 +87,12 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
|
|||
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
|
||||
/// source, if the producer is a `linalg` operation with all parallel iterator
|
||||
/// types.
|
||||
void populateFusePadTensorWithProducerLinalgOpPatterns(
|
||||
RewritePatternSet &patterns);
|
||||
|
||||
/// Patterns to convert from one named op to another. These can be seen as
|
||||
/// canonicalizations of named ops into another named op.
|
||||
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
|
||||
|
|
|
@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
Loops.cpp
|
||||
LinalgStrategyPasses.cpp
|
||||
NamedOpConversions.cpp
|
||||
PadOpInterchange.cpp
|
||||
Promotion.cpp
|
||||
Tiling.cpp
|
||||
Transforms.cpp
|
||||
|
|
122
mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
Normal file
122
mlir/lib/Dialect/Linalg/Transforms/PadOpInterchange.cpp
Normal file
|
@ -0,0 +1,122 @@
|
|||
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements patterns that intechanges a generic op -> pad_tensor
|
||||
// pattern into extract_slice -> generic_op.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace {
|
||||
|
||||
/// A sequence of operations
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = linalg. ...
|
||||
/// %1 = linalg.pad_tensor %0 ...
|
||||
/// ```
|
||||
///
|
||||
/// can be replaced with
|
||||
///
|
||||
/// ```mlir
|
||||
/// %0 = linalg.fill
|
||||
/// %1 = tensor.extract_slice %0 ...
|
||||
/// %2 = linalg. .... outs(..., %1, ....) ....
|
||||
/// %3 = tensor.insert_slice %2 into %1 ...
|
||||
/// ```
|
||||
///
|
||||
/// if the `linalg.generic` has all parallel iterator types.
|
||||
struct FusePadTensorOp : OpRewritePattern<PadTensorOp> {
|
||||
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PadTensorOp padOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only works on padding op that sets the padded value to a constant.
|
||||
Value padValue = padOp.getConstantPaddingValue();
|
||||
if (!padValue)
|
||||
return rewriter.notifyMatchFailure(padOp, "non constant padding");
|
||||
|
||||
// This pattern could work for any Linalg op. For now restrict it to generic
|
||||
// ops.
|
||||
Value source = padOp.source();
|
||||
auto linalgOp = source.getDefiningOp<GenericOp>();
|
||||
if (!linalgOp) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
padOp, "expected source to be linalg.generic op");
|
||||
}
|
||||
// All iterator types need to be parallel.
|
||||
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
padOp, "only supported for ops with all parallel iterator types");
|
||||
}
|
||||
ReifiedRankedShapedTypeDims resultShape;
|
||||
if (failed(padOp.reifyResultShapes(rewriter, resultShape)) ||
|
||||
resultShape.size() != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
padOp, "failed to get shape of pad op result");
|
||||
}
|
||||
|
||||
Location loc = padOp.getLoc();
|
||||
|
||||
// Create the tensor of same size as output of the pad op.
|
||||
RankedTensorType padResultType = padOp.getResultType();
|
||||
auto resultSizes = getAsOpFoldResult(resultShape[0]);
|
||||
auto initTensor = rewriter.create<InitTensorOp>(
|
||||
loc, resultSizes, padResultType.getElementType());
|
||||
|
||||
// Fill the tensor with the pad value.
|
||||
// TODO: There is an option to fill only the boundaries. For now just
|
||||
// filling the whole tensor.
|
||||
auto fillTensor =
|
||||
rewriter.create<FillOp>(loc, padValue, initTensor.getResult());
|
||||
|
||||
// Construct a slice of the fill result that is to be replaced with the
|
||||
// result of the generic op. The low pad values are the offsets, the size of
|
||||
// the source is the size of the slice.
|
||||
// TODO: This insert/extract could be potentially made a utility method.
|
||||
unsigned resultNumber = source.cast<OpResult>().getResultNumber();
|
||||
SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(offsets.size());
|
||||
for (auto shape : llvm::enumerate(
|
||||
source.getType().cast<RankedTensorType>().getShape())) {
|
||||
if (ShapedType::isDynamic(shape.value())) {
|
||||
sizes.push_back(
|
||||
rewriter.create<tensor::DimOp>(loc, source, shape.index())
|
||||
.getResult());
|
||||
} else {
|
||||
sizes.push_back(rewriter.getIndexAttr(shape.value()));
|
||||
}
|
||||
}
|
||||
SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
|
||||
auto slice = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, fillTensor.getResult(0), offsets, sizes, strides);
|
||||
|
||||
// Clone the generic op.
|
||||
auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
|
||||
clonedOp.setOutputOperand(resultNumber, slice.getResult());
|
||||
|
||||
// Insert it back into the result of the fill.
|
||||
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
|
||||
padOp, clonedOp.getResult(resultNumber), fillTensor.getResult(0),
|
||||
offsets, sizes, strides);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<FusePadTensorOp>(patterns.getContext());
|
||||
}
|
93
mlir/test/Dialect/Linalg/pad_fusion.mlir
Normal file
93
mlir/test/Dialect/Linalg/pad_fusion.mlir
Normal file
|
@ -0,0 +1,93 @@
|
|||
// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
|
||||
|
||||
func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
|
||||
%arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
|
||||
^bb0(%arg6 : f32, %arg7 : f32):
|
||||
%1 = arith.mulf %arg6, %arg6 : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
%1 = linalg.pad_tensor %0 low [%arg1, %arg2] high [%arg3, %arg4] {
|
||||
^bb0(%arg6: index, %arg7 : index):
|
||||
linalg.yield %arg5 : f32
|
||||
} : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
|
||||
// CHECK: func @dynamic_pad_fusion
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
|
||||
// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
|
||||
// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
|
||||
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
|
||||
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]]
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
|
||||
// CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
|
||||
// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
|
||||
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
|
||||
// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
|
||||
// CHECK: %[[SOURCE:.+]] = linalg.generic
|
||||
// CHECK-SAME: outs(%[[SLICE]] : tensor<?x?xf32>)
|
||||
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
|
||||
// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index,
|
||||
%arg3 : f32) -> tensor<49x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x42xf32>
|
||||
%init = linalg.init_tensor [42, %d0] : tensor<42x?xf32>
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0 : tensor<?x42xf32>) outs(%init : tensor<42x?xf32>) {
|
||||
^bb0(%arg4 : f32, %arg5 : f32):
|
||||
%1 = arith.mulf %arg4, %arg4 : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<42x?xf32>
|
||||
%1 = linalg.pad_tensor %0 low [3, %arg1] high [4, %arg2] {
|
||||
^bb0(%arg4: index, %arg5 : index):
|
||||
linalg.yield %arg3 : f32
|
||||
} : tensor<42x?xf32> to tensor<49x?xf32>
|
||||
return %1 : tensor<49x?xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
|
||||
// CHECK: func @mixed_pad_fusion
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x42xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
|
||||
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: f32
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
|
||||
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
|
||||
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
|
||||
// CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]]
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
|
||||
// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
|
||||
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
|
||||
// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
|
||||
// CHECK: %[[SOURCE:.+]] = linalg.generic
|
||||
// CHECK-SAME: outs(%[[SLICE]] : tensor<42x?xf32>)
|
||||
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
|
||||
// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
|
||||
// CHECK: return %[[RESULT]]
|
|
@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
|
|||
TestLinalgFusionTransforms.cpp
|
||||
TestLinalgHoisting.cpp
|
||||
TestLinalgTransforms.cpp
|
||||
TestPadFusion.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
|
48
mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
Normal file
48
mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp
Normal file
|
@ -0,0 +1,48 @@
|
|||
//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass for testing fusion of pad ops with its producer
|
||||
// Linalg op.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
struct TestPadFusionPass : public PassWrapper<TestPadFusionPass, FunctionPass> {
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
|
||||
}
|
||||
|
||||
StringRef getArgument() const final { return "test-linalg-pad-fusion"; }
|
||||
StringRef getDescription() const final { return "Test PadOp fusion"; }
|
||||
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &getContext();
|
||||
FuncOp funcOp = getFunction();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace test {
|
||||
void registerTestPadFusion() { PassRegistration<TestPadFusionPass>(); }
|
||||
} // namespace test
|
||||
|
||||
} // namespace mlir
|
|
@ -103,6 +103,7 @@ void registerTestMemRefStrideCalculation();
|
|||
void registerTestNumberOfBlockExecutionsPass();
|
||||
void registerTestNumberOfOperationExecutionsPass();
|
||||
void registerTestOpaqueLoc();
|
||||
void registerTestPadFusion();
|
||||
void registerTestPDLByteCodePass();
|
||||
void registerTestPreparationPassWithAllowedMemrefResults();
|
||||
void registerTestRecursiveTypesPass();
|
||||
|
@ -195,6 +196,7 @@ void registerTestPasses() {
|
|||
mlir::test::registerTestNumberOfBlockExecutionsPass();
|
||||
mlir::test::registerTestNumberOfOperationExecutionsPass();
|
||||
mlir::test::registerTestOpaqueLoc();
|
||||
mlir::test::registerTestPadFusion();
|
||||
mlir::test::registerTestPDLByteCodePass();
|
||||
mlir::test::registerTestRecursiveTypesPass();
|
||||
mlir::test::registerTestSCFUtilsPass();
|
||||
|
|
Loading…
Reference in a new issue