[SCF] Handle lowering of Execute region to Standard CFG
Lower SCF.executeregionop to llvm by essentially inlining the region and replacing the return Differential Revision: https://reviews.llvm.org/D105567
This commit is contained in:
parent
eaf22ba011
commit
9a11c70c18
|
@ -194,6 +194,13 @@ struct IfLowering : public OpRewritePattern<IfOp> {
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
|
||||
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
|
||||
|
||||
|
@ -400,6 +407,38 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
auto *condBlock = rewriter.getInsertionBlock();
|
||||
auto opPosition = rewriter.getInsertionPoint();
|
||||
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
|
||||
|
||||
auto ®ion = op.region();
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
rewriter.create<BranchOp>(loc, ®ion.front());
|
||||
|
||||
for (Block &block : region) {
|
||||
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
||||
ValueRange terminatorOperands = terminator->getOperands();
|
||||
rewriter.setInsertionPointToEnd(&block);
|
||||
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
||||
rewriter.eraseOp(terminator);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.inlineRegionBefore(region, remainingOpsBlock);
|
||||
|
||||
SmallVector<Value> vals;
|
||||
for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes())) {
|
||||
vals.push_back(arg);
|
||||
}
|
||||
rewriter.replaceOp(op, vals);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
|
@ -569,8 +608,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
|||
}
|
||||
|
||||
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
|
||||
patterns.getContext());
|
||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
||||
ExecuteRegionLowering>(patterns.getContext());
|
||||
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
||||
}
|
||||
|
||||
|
@ -580,7 +619,8 @@ void SCFToStandardPass::runOnOperation() {
|
|||
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
|
||||
// scf.while. Anything else is fine.
|
||||
ConversionTarget target(getContext());
|
||||
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
|
||||
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
||||
scf::ExecuteRegionOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
if (failed(
|
||||
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
||||
|
|
|
@ -587,3 +587,36 @@ func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5
|
|||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @func_execute_region_elim_multi_yield
|
||||
func @func_execute_region_elim_multi_yield() {
|
||||
"test.foo"() : () -> ()
|
||||
%v = scf.execute_region -> i64 {
|
||||
%c = "test.cmp"() : () -> i1
|
||||
cond_br %c, ^bb2, ^bb3
|
||||
^bb2:
|
||||
%x = "test.val1"() : () -> i64
|
||||
scf.yield %x : i64
|
||||
^bb3:
|
||||
%y = "test.val2"() : () -> i64
|
||||
scf.yield %y : i64
|
||||
}
|
||||
"test.bar"(%v) : (i64) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-NOT: execute_region
|
||||
// CHECK: "test.foo"
|
||||
// CHECK: br ^[[rentry:.+]]
|
||||
// CHECK: ^[[rentry]]
|
||||
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||
// CHECK: ^[[bb1]]:
|
||||
// CHECK: %[[x:.+]] = "test.val1"
|
||||
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
||||
// CHECK: ^[[bb2]]:
|
||||
// CHECK: %[[y:.+]] = "test.val2"
|
||||
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
||||
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||
// CHECK: "test.bar"(%[[z]])
|
||||
// CHECK: return
|
||||
|
|
Loading…
Reference in a new issue