[SCF] Handle lowering of Execute region to Standard CFG
Lower SCF.executeregionop to llvm by essentially inlining the region and replacing the return Differential Revision: https://reviews.llvm.org/D105567
This commit is contained in:
parent
eaf22ba011
commit
9a11c70c18
|
@ -194,6 +194,13 @@ struct IfLowering : public OpRewritePattern<IfOp> {
|
||||||
PatternRewriter &rewriter) const override;
|
PatternRewriter &rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
|
||||||
|
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ExecuteRegionOp op,
|
||||||
|
PatternRewriter &rewriter) const override;
|
||||||
|
};
|
||||||
|
|
||||||
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||||
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
|
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
@ -400,6 +407,38 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
||||||
|
PatternRewriter &rewriter) const {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
auto *condBlock = rewriter.getInsertionBlock();
|
||||||
|
auto opPosition = rewriter.getInsertionPoint();
|
||||||
|
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
|
||||||
|
|
||||||
|
auto ®ion = op.region();
|
||||||
|
rewriter.setInsertionPointToEnd(condBlock);
|
||||||
|
rewriter.create<BranchOp>(loc, ®ion.front());
|
||||||
|
|
||||||
|
for (Block &block : region) {
|
||||||
|
if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
|
||||||
|
ValueRange terminatorOperands = terminator->getOperands();
|
||||||
|
rewriter.setInsertionPointToEnd(&block);
|
||||||
|
rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
|
||||||
|
rewriter.eraseOp(terminator);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.inlineRegionBefore(region, remainingOpsBlock);
|
||||||
|
|
||||||
|
SmallVector<Value> vals;
|
||||||
|
for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes())) {
|
||||||
|
vals.push_back(arg);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, vals);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||||
PatternRewriter &rewriter) const {
|
PatternRewriter &rewriter) const {
|
||||||
|
@ -569,8 +608,8 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
|
void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
|
||||||
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
|
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
|
||||||
patterns.getContext());
|
ExecuteRegionLowering>(patterns.getContext());
|
||||||
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -580,7 +619,8 @@ void SCFToStandardPass::runOnOperation() {
|
||||||
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
|
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
|
||||||
// scf.while. Anything else is fine.
|
// scf.while. Anything else is fine.
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp>();
|
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
|
||||||
|
scf::ExecuteRegionOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||||
if (failed(
|
if (failed(
|
||||||
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
||||||
|
|
|
@ -587,3 +587,36 @@ func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @func_execute_region_elim_multi_yield
|
||||||
|
func @func_execute_region_elim_multi_yield() {
|
||||||
|
"test.foo"() : () -> ()
|
||||||
|
%v = scf.execute_region -> i64 {
|
||||||
|
%c = "test.cmp"() : () -> i1
|
||||||
|
cond_br %c, ^bb2, ^bb3
|
||||||
|
^bb2:
|
||||||
|
%x = "test.val1"() : () -> i64
|
||||||
|
scf.yield %x : i64
|
||||||
|
^bb3:
|
||||||
|
%y = "test.val2"() : () -> i64
|
||||||
|
scf.yield %y : i64
|
||||||
|
}
|
||||||
|
"test.bar"(%v) : (i64) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-NOT: execute_region
|
||||||
|
// CHECK: "test.foo"
|
||||||
|
// CHECK: br ^[[rentry:.+]]
|
||||||
|
// CHECK: ^[[rentry]]
|
||||||
|
// CHECK: %[[cmp:.+]] = "test.cmp"
|
||||||
|
// CHECK: cond_br %[[cmp]], ^[[bb1:.+]], ^[[bb2:.+]]
|
||||||
|
// CHECK: ^[[bb1]]:
|
||||||
|
// CHECK: %[[x:.+]] = "test.val1"
|
||||||
|
// CHECK: br ^[[bb3:.+]](%[[x]] : i64)
|
||||||
|
// CHECK: ^[[bb2]]:
|
||||||
|
// CHECK: %[[y:.+]] = "test.val2"
|
||||||
|
// CHECK: br ^[[bb3]](%[[y:.+]] : i64)
|
||||||
|
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
|
||||||
|
// CHECK: "test.bar"(%[[z]])
|
||||||
|
// CHECK: return
|
||||||
|
|
Loading…
Reference in a new issue