[mlir][interfaces] Add helpers for detecting recursive regions

Add helper functions to check if an op may be executed multiple times based on RegionBranchOpInterface.

Differential Revision: https://reviews.llvm.org/D123789
This commit is contained in:
Matthias Springer 2022-04-19 16:12:40 +09:00
parent c5cac48549
commit 0f4ba02db3
4 changed files with 140 additions and 1 deletions

View file

@ -216,6 +216,16 @@ private:
/// RegionBranchOpInterface.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
/// Return the first enclosing region of the given op that may be executed
/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
/// exists.
Region *getEnclosingRepetitiveRegion(Operation *op);
/// Return the first enclosing region of the given Value that may be executed
/// repetitively as per RegionBranchOpInterface or `nullptr` if no such region
/// exists.
Region *getEnclosingRepetitiveRegion(Value value);
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//

View file

@ -211,6 +211,11 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
SmallVector<Attribute, 2> nullAttrs(getOperation()->getNumOperands());
getSuccessorRegions(index, nullAttrs, regions);
}
/// Return `true` if control flow originating from the given region may
/// eventually branch back to the same region. (Maybe after passing through
/// other regions.)
bool isRepetitiveRegion(unsigned index);
}];
}

View file

@ -309,6 +309,57 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
SmallVector<bool> visited(getOperation()->getNumRegions(), false);
visited[index] = true;
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<unsigned> worklist;
auto enqueueAllSuccessors = [&](unsigned index) {
SmallVector<RegionSuccessor> successors;
this->getSuccessorRegions(index, successors);
for (RegionSuccessor successor : successors)
if (!successor.isParent())
worklist.push_back(successor.getSuccessor()->getRegionNumber());
};
enqueueAllSuccessors(index);
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
unsigned nextRegion = worklist.pop_back_val();
if (nextRegion == index)
return true;
if (visited[nextRegion])
continue;
visited[nextRegion] = true;
enqueueAllSuccessors(nextRegion);
}
return false;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
while (Region *region = op->getParentRegion()) {
op = region->getParentOp();
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
}
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
Region *region = value.getParentRegion();
while (region) {
Operation *op = region->getParentOp();
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
return region;
region = op->getParentRegion();
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//

View file

@ -42,6 +42,29 @@ struct MutuallyExclusiveRegionsOp
SmallVectorImpl<RegionSuccessor> &regions) {}
};
/// All regions of this op call each other in a large circle.
struct LoopRegionsOp
: public Op<LoopRegionsOp, RegionBranchOpInterface::Trait> {
using Op::Op;
static const unsigned kNumRegions = 3;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() { return "cftest.loop_regions_op"; }
void getSuccessorRegions(Optional<unsigned> index,
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {
if (index) {
if (*index == 1)
// This region also branches back to the parent.
regions.push_back(RegionSuccessor());
regions.push_back(
RegionSuccessor(&getOperation()->getRegion(*index % kNumRegions)));
}
}
};
/// Regions are executed sequentially.
struct SequentialRegionsOp
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
@ -65,7 +88,8 @@ struct SequentialRegionsOp
struct CFTestDialect : Dialect {
explicit CFTestDialect(MLIRContext *ctx)
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
addOperations<DummyOp, MutuallyExclusiveRegionsOp, LoopRegionsOp,
SequentialRegionsOp>();
}
static StringRef getDialectNamespace() { return "cftest"; }
};
@ -142,3 +166,52 @@ TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
}
TEST(RegionBranchOpInterface, RecursiveRegions) {
const char *ir = R"MLIR(
"cftest.loop_regions_op"() (
{"cftest.dummy_op"() : () -> ()}, // op1
{"cftest.dummy_op"() : () -> ()}, // op2
{"cftest.dummy_op"() : () -> ()} // op3
) : () -> ()
)MLIR";
DialectRegistry registry;
registry.insert<CFTestDialect>();
MLIRContext ctx(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation *testOp = &module->getBody()->getOperations().front();
auto regionOp = cast<RegionBranchOpInterface>(testOp);
Operation *op1 = &testOp->getRegion(0).front().front();
Operation *op2 = &testOp->getRegion(1).front().front();
Operation *op3 = &testOp->getRegion(2).front().front();
EXPECT_TRUE(regionOp.isRepetitiveRegion(0));
EXPECT_TRUE(regionOp.isRepetitiveRegion(1));
EXPECT_TRUE(regionOp.isRepetitiveRegion(2));
EXPECT_NE(getEnclosingRepetitiveRegion(op1), nullptr);
EXPECT_NE(getEnclosingRepetitiveRegion(op2), nullptr);
EXPECT_NE(getEnclosingRepetitiveRegion(op3), nullptr);
}
TEST(RegionBranchOpInterface, NotRecursiveRegions) {
const char *ir = R"MLIR(
"cftest.sequential_regions_op"() (
{"cftest.dummy_op"() : () -> ()}, // op1
{"cftest.dummy_op"() : () -> ()} // op2
) : () -> ()
)MLIR";
DialectRegistry registry;
registry.insert<CFTestDialect>();
MLIRContext ctx(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation *testOp = &module->getBody()->getOperations().front();
Operation *op1 = &testOp->getRegion(0).front().front();
Operation *op2 = &testOp->getRegion(1).front().front();
EXPECT_EQ(getEnclosingRepetitiveRegion(op1), nullptr);
EXPECT_EQ(getEnclosingRepetitiveRegion(op2), nullptr);
}