[SCF] Handle lowering of Execute region to Standard CFG

Lower SCF.executeregionop to llvm by essentially inlining the region and replacing the return

Differential Revision: https://reviews.llvm.org/D105567
This commit is contained in:
William S. Moses 2021-07-07 14:27:35 -04:00
parent eaf22ba011
commit 9a11c70c18
2 changed files with 76 additions and 3 deletions

View file

@ -194,6 +194,13 @@ struct IfLowering : public OpRewritePattern<IfOp> {
PatternRewriter &rewriter) const override; PatternRewriter &rewriter) const override;
}; };
struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override;
};
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
@ -400,6 +407,38 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
return success(); return success();
} }
LogicalResult
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto *condBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
auto &region = op.region();
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<BranchOp>(loc, &region.front());
for (Block &block : region) {
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&block);
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
rewriter.eraseOp(terminator);
}
}
rewriter.inlineRegionBefore(region, remainingOpsBlock);
SmallVector<Value> vals;
for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes())) {
vals.push_back(arg);
}
rewriter.replaceOp(op, vals);
return success();
}
LogicalResult LogicalResult
ParallelLowering::matchAndRewrite(ParallelOp parallelOp, ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
PatternRewriter &rewriter) const { PatternRewriter &rewriter) const {
@ -569,8 +608,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
} }
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) { void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>( patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
patterns.getContext()); ExecuteRegionLowering>(patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2); patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
} }
@ -580,7 +619,8 @@ void SCFToStandardPass::runOnOperation() {
// Configure conversion to lower out scf.for, scf.if, scf.parallel and // Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine. // scf.while. Anything else is fine.
ConversionTarget target(getContext()); ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>(); target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed( if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns)))) applyPartialConversion(getOperation(), target, std::move(patterns))))

View file

@ -587,3 +587,36 @@ func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5
// CHECK: return // CHECK: return
return return
} }
// CHECK-LABEL: func @func_execute_region_elim_multi_yield
func @func_execute_region_elim_multi_yield() {
"test.foo"() : () -> ()
%v = scf.execute_region -> i64 {
%c = "test.cmp"() : () -> i1
cond_br %c, ^bb2, ^bb3
^bb2:
%x = "test.val1"() : () -> i64
scf.yield %x : i64
^bb3:
%y = "test.val2"() : () -> i64
scf.yield %y : i64
}
"test.bar"(%v) : (i64) -> ()
return
}
// CHECK-NOT: execute_region
// CHECK: "test.foo"
// CHECK: br ^[[rentry:.+]]
// CHECK: ^[[rentry]]
// CHECK: %[[cmp:.+]] = "test.cmp"
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
// CHECK: ^[[bb1]]:
// CHECK: %[[x:.+]] = "test.val1"
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
// CHECK: ^[[bb2]]:
// CHECK: %[[y:.+]] = "test.val2"
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return