[mlir][Linalg] Pattern to fuse pad operation with elementwise operations.

Most convolution operations need explicit padding of the input to
ensure all accesses are in bounds. In such cases, a separate pad
operation can be a significant overhead. One way to reduce that
overhead is to fuse the pad operation with the producer of its
source.

A sequence

```
linalg.generic -> linalg.pad_tensor
```

can be replaced with

```
linalg.fill -> tensor.extract_slice -> linalg.generic -> tensor.insert_slice
```

if the `linalg.generic` has all parallel iterator types.
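
Concretely, for a dynamic 2-D case similar to the tests added below (a
sketch; `%l0`/`%l1`, `%h0`/`%h1`, `%d0`/`%d1`, and `%D0`/`%D1` are
illustrative names for the low pads, high pads, source sizes, and padded
sizes):

```mlir
// Before: the result of the elementwise generic op is padded with %cst.
%0 = linalg.generic { ... iterator_types = ["parallel", "parallel"] }
       ins(%in : tensor<?x?xf32>) outs(%out : tensor<?x?xf32>) { ... } -> tensor<?x?xf32>
%1 = linalg.pad_tensor %0 low [%l0, %l1] high [%h0, %h1] {
  ^bb0(%i : index, %j : index):
    linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>

// After: fill a tensor of the padded size, compute the generic op directly
// into a slice of it, and insert that slice back into the filled tensor.
%init  = linalg.init_tensor [%D0, %D1] : tensor<?x?xf32>
%fill  = linalg.fill(%cst, %init) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
%slice = tensor.extract_slice %fill[%l0, %l1] [%d0, %d1] [1, 1]
    : tensor<?x?xf32> to tensor<?x?xf32>
%2 = linalg.generic { ... }
       ins(%in : tensor<?x?xf32>) outs(%slice : tensor<?x?xf32>) { ... } -> tensor<?x?xf32>
%1 = tensor.insert_slice %2 into %fill[%l0, %l1] [%d0, %d1] [1, 1]
    : tensor<?x?xf32> into tensor<?x?xf32>
```

The `linalg.fill` provides the padded border; the interior is overwritten by
the fused `linalg.generic`, so no separate pad op remains.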

Differential Revision: https://reviews.llvm.org/D116418
Author: MaheshRavishankar
Date:   2022-01-11 13:35:37 -08:00
Commit: e7cb716ef9 (parent 4372e629a9)

7 changed files with 273 additions and 0 deletions

@@ -87,6 +87,12 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns);

/// Pattern to fuse a `linalg.pad_tensor` operation with the producer of its
/// source, if the producer is a `linalg` operation with all parallel iterator
/// types.
void populateFusePadTensorWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns);
/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
  Loops.cpp
  LinalgStrategyPasses.cpp
  NamedOpConversions.cpp
  PadOpInterchange.cpp
  Promotion.cpp
  Tiling.cpp
  Transforms.cpp

@@ -0,0 +1,122 @@
//===- PadOpInterchange.cpp - Interchange pad operation with Generic ops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that interchange a generic op -> pad_tensor
// sequence into a fill -> extract_slice -> generic op -> insert_slice
// sequence.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg. ...
/// %1 = linalg.pad_tensor %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = linalg.fill
/// %1 = tensor.extract_slice %0 ...
/// %2 = linalg. .... outs(..., %1, ....) ....
/// %3 = tensor.insert_slice %2 into %0 ...
/// ```
///
/// if the `linalg.generic` has all parallel iterator types.
struct FusePadTensorOp : OpRewritePattern<PadTensorOp> {
  using OpRewritePattern<PadTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadTensorOp padOp,
                                PatternRewriter &rewriter) const override {
    // Only works on pad ops that set the padded value to a constant.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return rewriter.notifyMatchFailure(padOp, "non constant padding");

    // This pattern could work for any Linalg op. For now, restrict it to
    // generic ops.
    Value source = padOp.source();
    auto linalgOp = source.getDefiningOp<GenericOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(
          padOp, "expected source to be linalg.generic op");
    }

    // All iterator types need to be parallel.
    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops()) {
      return rewriter.notifyMatchFailure(
          padOp, "only supported for ops with all parallel iterator types");
    }

    ReifiedRankedShapedTypeDims resultShape;
    if (failed(padOp.reifyResultShapes(rewriter, resultShape)) ||
        resultShape.size() != 1) {
      return rewriter.notifyMatchFailure(
          padOp, "failed to get shape of pad op result");
    }

    Location loc = padOp.getLoc();

    // Create a tensor of the same size as the result of the pad op.
    RankedTensorType padResultType = padOp.getResultType();
    auto resultSizes = getAsOpFoldResult(resultShape[0]);
    auto initTensor = rewriter.create<InitTensorOp>(
        loc, resultSizes, padResultType.getElementType());

    // Fill the tensor with the pad value.
    // TODO: There is an option to fill only the boundaries. For now, just
    // fill the whole tensor.
    auto fillTensor =
        rewriter.create<FillOp>(loc, padValue, initTensor.getResult());

    // Construct a slice of the fill result that is to be replaced with the
    // result of the generic op. The low pad values are the offsets, and the
    // size of the source is the size of the slice.
    // TODO: This insert/extract could potentially be made a utility method.
    unsigned resultNumber = source.cast<OpResult>().getResultNumber();
    SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(offsets.size());
    for (auto shape : llvm::enumerate(
             source.getType().cast<RankedTensorType>().getShape())) {
      if (ShapedType::isDynamic(shape.value())) {
        sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, source, shape.index())
                .getResult());
      } else {
        sizes.push_back(rewriter.getIndexAttr(shape.value()));
      }
    }
    SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
    auto slice = rewriter.create<tensor::ExtractSliceOp>(
        loc, fillTensor.getResult(0), offsets, sizes, strides);

    // Clone the generic op and make it write into the slice.
    auto clonedOp = cast<GenericOp>(rewriter.clone(*linalgOp.getOperation()));
    clonedOp.setOutputOperand(resultNumber, slice.getResult());

    // Insert the cloned result back into the result of the fill.
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        padOp, clonedOp.getResult(resultNumber), fillTensor.getResult(0),
        offsets, sizes, strides);
    return success();
  }
};
} // namespace
void mlir::linalg::populateFusePadTensorWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FusePadTensorOp>(patterns.getContext());
}

@@ -0,0 +1,93 @@
// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
func @dynamic_pad_fusion(%arg0 : tensor<?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3 : index, %arg4 : index, %arg5 : f32) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
    ^bb0(%arg6 : f32, %arg7 : f32):
      %1 = arith.mulf %arg6, %arg6 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %1 = linalg.pad_tensor %0 low [%arg1, %arg2] high [%arg3, %arg4] {
    ^bb0(%arg6: index, %arg7 : index):
      linalg.yield %arg5 : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
// CHECK: func @dynamic_pad_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: f32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
// CHECK-DAG: %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
// CHECK-DAG: %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[TARGET_D0]], %[[TARGET_D1]]]
// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG5]], %[[INIT]])
// CHECK-DAG: %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
// CHECK: %[[SOURCE:.+]] = linalg.generic
// CHECK-SAME: outs(%[[SLICE]] : tensor<?x?xf32>)
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
// CHECK-SAME: [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
// CHECK: return %[[RESULT]]
// -----
func @mixed_pad_fusion(%arg0 : tensor<?x42xf32>, %arg1 : index, %arg2 : index,
    %arg3 : f32) -> tensor<49x?xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x42xf32>
  %init = linalg.init_tensor [42, %d0] : tensor<42x?xf32>
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x42xf32>) outs(%init : tensor<42x?xf32>) {
    ^bb0(%arg4 : f32, %arg5 : f32):
      %1 = arith.mulf %arg4, %arg4 : f32
      linalg.yield %1 : f32
  } -> tensor<42x?xf32>
  %1 = linalg.pad_tensor %0 low [3, %arg1] high [4, %arg2] {
    ^bb0(%arg4: index, %arg5 : index):
      linalg.yield %arg3 : f32
  } : tensor<42x?xf32> to tensor<49x?xf32>
  return %1 : tensor<49x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s2 + s0 + s1)>
// CHECK: func @mixed_pad_fusion
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x42xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: f32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[SOURCE:.+]] = linalg.generic
// CHECK-DAG: %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
// CHECK-DAG: %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [49, %[[TARGET_D1]]]
// CHECK: %[[FILL:.+]] = linalg.fill(%[[ARG3]], %[[INIT]])
// CHECK-DAG: %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
// CHECK: %[[SOURCE:.+]] = linalg.generic
// CHECK-SAME: outs(%[[SLICE]] : tensor<42x?xf32>)
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
// CHECK-SAME: [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
// CHECK: return %[[RESULT]]

@@ -8,6 +8,7 @@ add_mlir_library(MLIRLinalgTestPasses
  TestLinalgFusionTransforms.cpp
  TestLinalgHoisting.cpp
  TestLinalgTransforms.cpp
  TestPadFusion.cpp

  EXCLUDE_FROM_LIBMLIR

@@ -0,0 +1,48 @@
//===- TestPadFusion.cpp - Test fusion of pad op with Linalg ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass for testing fusion of pad ops with their
// producer Linalg ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace {
struct TestPadFusionPass : public PassWrapper<TestPadFusionPass, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
  }

  StringRef getArgument() const final { return "test-linalg-pad-fusion"; }
  StringRef getDescription() const final { return "Test PadOp fusion"; }

  void runOnFunction() override {
    MLIRContext *context = &getContext();
    FuncOp funcOp = getFunction();
    RewritePatternSet patterns(context);
    linalg::populateFusePadTensorWithProducerLinalgOpPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                            std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace
namespace test {
void registerTestPadFusion() { PassRegistration<TestPadFusionPass>(); }
} // namespace test
} // namespace mlir

@@ -103,6 +103,7 @@ void registerTestMemRefStrideCalculation();
void registerTestNumberOfBlockExecutionsPass();
void registerTestNumberOfOperationExecutionsPass();
void registerTestOpaqueLoc();
void registerTestPadFusion();
void registerTestPDLByteCodePass();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
@@ -195,6 +196,7 @@ void registerTestPasses() {
mlir::test::registerTestNumberOfBlockExecutionsPass();
mlir::test::registerTestNumberOfOperationExecutionsPass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestPadFusion();
mlir::test::registerTestPDLByteCodePass();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();