[mlir][SCF] Canonicalize dim(x) where x is an iter_arg

* Add `DimOfIterArgFolder`.
* Move existing cross-dialect canonicalization patterns to `LoopCanonicalization.cpp`.
* Rename `SCFAffineOpCanonicalization` pass to `SCFForLoopCanonicalization`.
* Expand documentaton of scf.for: The type of loop-carried variables may not change with iterations. (Not even the dynamic type.)

Differential Revision: https://reviews.llvm.org/D108806
This commit is contained in:
Matthias Springer 2021-08-30 01:12:14 +00:00
parent 71b170ccf3
commit d18ffd61d4
14 changed files with 179 additions and 82 deletions

View file

@ -30,7 +30,7 @@ std::unique_ptr<Pass> createForLoopPeelingPass();
/// Creates a pass that canonicalizes affine.min and affine.max operations
/// inside of scf.for loops with known lower and upper bounds.
std::unique_ptr<Pass> createSCFAffineOpCanonicalizationPass();
std::unique_ptr<Pass> createSCFForLoopCanonicalizationPass();
/// Creates a loop fusion pass which fuses parallel loops.
std::unique_ptr<Pass> createParallelLoopFusionPass();

View file

@ -17,14 +17,14 @@ def SCFBufferize : FunctionPass<"scf-bufferize"> {
let dependentDialects = ["memref::MemRefDialect"];
}
// Note: Making this a canonicalization pattern would require a dependency
// of the SCF dialect on the Affine dialect or vice versa.
def SCFAffineOpCanonicalization
: FunctionPass<"canonicalize-scf-affine-op"> {
let summary = "Canonicalize affine.min and affine.max ops in the context of "
"SCF loops with known bounds";
let constructor = "mlir::createSCFAffineOpCanonicalizationPass()";
let dependentDialects = ["AffineDialect"];
// Note: Making these canonicalization patterns would require a dependency
// of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
def SCFForLoopCanonicalization
: FunctionPass<"for-loop-canonicalization"> {
let summary = "Canonicalize operations within scf.for loop bodies";
let constructor = "mlir::createSCFForLoopCanonicalizationPass()";
let dependentDialects = ["AffineDialect", "tensor::TensorDialect",
"memref::MemRefDialect"];
}
def SCFForLoopPeeling

View file

@ -122,7 +122,7 @@ def ForOp : SCF_Op<"for",
let summary = "for operation";
let description = [{
The "scf.for" operation represents a loop taking 3 SSA value as operands
that represent the lower bound, upper bound and step respectively. The
that represent the lower bound, upper bound and step respectively. The
operation defines an SSA value for its induction variable. It has one
region capturing the loop body. The induction variable is represented as an
argument of this region. This SSA value always has type index, which is the
@ -146,14 +146,18 @@ def ForOp : SCF_Op<"for",
values after loop termination. The initial values of the variables are
passed as additional SSA operands to the "scf.for" following the 3 loop
control SSA values mentioned above (lower bound, upper bound and step). The
operation region has equivalent arguments for each variable representing
the value of the variable at the current iteration.
operation region has an argument for the induction variable, followed by
one argument for each loop-carried variable, representing he value of the
variable at the current iteration.
The region must terminate with a "scf.yield" that passes all the current
iteration variables to the next iteration, or to the "scf.for" result, if
at the last iteration. Note, that when the loop-carried variables are
present, calling ForOp::build will not insert the terminator implicitly.
The caller must insert "scf.yield" in that case.
The region must terminate with a "scf.yield" that passes the current
values of loop-carried variables to the next iteration, or to the "scf.for"
result, if at the last iteration. The type (static or dynamic) of a
loop-carried variable may not change with iterations. E.g., it is illegal
to pass a tensor of larger size to the next iteration; even if the tensor's
dimensions are dynamic (i.e., same static type). Note, that when the
loop-carried variables are present, calling ForOp::build will not insert the
terminator implicitly. The caller must insert "scf.yield" in that case.
"scf.for" results hold the final values after the last iteration.
For example, to sum-reduce a memref:

View file

@ -179,7 +179,7 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
/// Populate patterns for canonicalizing operations inside SCF loop bodies.
/// At the moment, only affine.min/max computations with iteration variables,
/// loop bounds and loop steps are canonicalized.
void populateSCFLoopBodyCanonicalizationPatterns(RewritePatternSet &patterns);
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
} // namespace scf
} // namespace mlir

View file

@ -48,7 +48,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
RewritePatternSet stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
auto stage3Transforms = [&](Operation *op) {
// Some of these may be too aggressive as a stage 3 that is applied on each

View file

@ -537,7 +537,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
insertTilingPatterns(patterns, options);
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));

View file

@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRSCFTransforms
Bufferize.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
@ -22,6 +23,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
MLIRSCF
MLIRStandard
MLIRSupport
MLIRTensor
MLIRTransforms
MLIRTransformUtils
)

View file

@ -0,0 +1,127 @@
//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::scf;
namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
/// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
/// ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ...
/// }
/// ```
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
if (!blockArg)
return failure();
auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
if (!forOp)
return failure();
Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
rewriter.updateRootInPlace(
dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });
return success();
};
};
/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
lb = forOp.lowerBound();
ub = forOp.upperBound();
step = forOp.step();
return success();
}
if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
if (parOp.getInductionVars()[idx] == iv) {
lb = parOp.lowerBound()[idx];
ub = parOp.upperBound()[idx];
step = parOp.step()[idx];
return success();
}
}
return failure();
}
return failure();
};
return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
op.operands(), IsMin, loopMatcher);
}
};
struct SCFForLoopCanonicalization
: public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
void runOnFunction() override {
FuncOp funcOp = getFunction();
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns
.insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
DimOfIterArgFolder<tensor::DimOp>,
DimOfIterArgFolder<memref::DimOp>>(ctx);
}
std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
return std::make_unique<SCFForLoopCanonicalization>();
}

View file

@ -516,40 +516,6 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
/// the direct parent.
bool skipPartial;
};
/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
lb = forOp.lowerBound();
ub = forOp.upperBound();
step = forOp.step();
return success();
}
if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
if (parOp.getInductionVars()[idx] == iv) {
lb = parOp.lowerBound()[idx];
ub = parOp.upperBound()[idx];
step = parOp.step()[idx];
return success();
}
}
return failure();
}
return failure();
};
return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
op.operands(), IsMin, loopMatcher);
}
};
} // namespace
namespace {
@ -583,24 +549,8 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
});
}
};
struct SCFAffineOpCanonicalization
: public SCFAffineOpCanonicalizationBase<SCFAffineOpCanonicalization> {
void runOnFunction() override {
FuncOp funcOp = getFunction();
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<Pass> mlir::createSCFAffineOpCanonicalizationPass() {
return std::make_unique<SCFAffineOpCanonicalization>();
}
std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
return std::make_unique<ParallelLoopSpecialization>();
}
@ -612,12 +562,3 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
return std::make_unique<ForLoopPeeling>();
}
void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns(
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
patterns
.insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>>(
ctx);
}

View file

@ -22,6 +22,10 @@ namespace memref {
class MemRefDialect;
} // end namespace memref
namespace tensor {
class TensorDialect;
} // end namespace tensor
#define GEN_PASS_CLASSES
#include "mlir/Dialect/SCF/Passes.h.inc"

View file

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -canonicalize-scf-affine-op -split-input-file | FileCheck %s
// RUN: mlir-opt %s -for-loop-canonicalization -split-input-file | FileCheck %s
// CHECK-LABEL: func @scf_for_canonicalize_min
// CHECK: %[[C2:.*]] = constant 2 : i64
@ -224,3 +224,21 @@ func @scf_parallel_canonicalize_min_2(%A : memref<i64>) {
}
return
}
// -----
// CHECK-LABEL: func @tensor_dim_of_iter_arg(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK: scf.for
// CHECK: tensor.dim %[[t]]
func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c10 = constant 10 : index
%0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
-> (tensor<?x?xf32>, index) {
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
scf.yield %arg0, %dim : tensor<?x?xf32>, index
}
return %1 : index
}

View file

@ -71,7 +71,7 @@ void TestConvVectorization::runOnOperation() {
RewritePatternSet stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
auto stage3Transforms = [](Operation *op) {
PassManager pm(op->getContext());

View file

@ -237,7 +237,7 @@ struct TestLinalgGreedyFusion
RewritePatternSet patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
do {
(void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);

View file

@ -1491,6 +1491,7 @@ cc_library(
":SCFPassIncGen",
":StandardOps",
":Support",
":TensorDialect",
":Transforms",
"//llvm:Support",
],