[ODS] Support numRegions in Op definition

--

PiperOrigin-RevId: 250282024
This commit is contained in:
Lei Zhang 2019-05-28 08:03:46 -07:00 committed by Mehdi Amini
parent c2d105811a
commit d4c8c8de42
6 changed files with 76 additions and 3 deletions

View file

@ -916,6 +916,10 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// The list of results of the op. Default to 0 results.
dag results = (outs);
// How many regions this op has.
// TODO(b/133479568): Enhance to support advanced region usage cases
int numRegions = 0;
// Attribute getters can be added to the op by adding an Attr member
// with the name and type of the attribute. E.g., adding int attribute
// with name "value" and type "i32":

View file

@ -129,6 +129,9 @@ public:
// requiring the raw MLIR trait here.
bool hasTrait(llvm::StringRef trait) const;
// Returns the number of regions.
int getNumRegions() const;
// Trait.
using const_trait_iterator = const OpTrait *;
const_trait_iterator trait_begin() const;
@ -174,6 +177,9 @@ private:
// The traits of the op.
SmallVector<OpTrait, 4> traits;
// The number of regions of this op.
int numRegions = 0;
// The number of native attributes stored in the leading positions of
// `attributes`.
int numNativeAttributes;

View file

@ -146,6 +146,8 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
return false;
}
int tblgen::Operator::getNumRegions() const { return numRegions; }
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
return traits.begin();
}
@ -265,6 +267,11 @@ void tblgen::Operator::populateOpStructure() {
traits.reserve(traitListInit->size());
for (auto traitInit : *traitListInit)
traits.push_back(OpTrait::create(traitInit));
// Handle regions
numRegions = def.getValueAsInt("numRegions");
if (numRegions < 0)
PrintFatalError(def.getLoc(), "numRegions cannot be negative");
}
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }

32
mlir/test/IR/region.mlir Normal file
View file

@ -0,0 +1,32 @@
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
func @correct_number_of_regions() {
// CHECK: test.two_region_op
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()}
) : () -> ()
return
}
// -----
func @missingk_regions() {
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
"test.two_region_op"()(
{"work"() : () -> ()}
) : () -> ()
return
}
// -----
func @extra_regions() {
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}}
"test.two_region_op"()(
{"work"() : () -> ()},
{"work"() : () -> ()},
{"work"() : () -> ()}
) : () -> ()
return
}

View file

@ -113,4 +113,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
//===----------------------------------------------------------------------===//
// Test op regions
//===----------------------------------------------------------------------===//
def TwoRegionOp : TEST_Op<"two_region_op", []> {
let numRegions = 2;
}
#endif // TEST_OPS

View file

@ -742,6 +742,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
}
}
}
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << "->addRegion();\n";
}
}
void OpEmitter::genBuilder() {
@ -820,6 +826,12 @@ void OpEmitter::genBuilder() {
<< " " << builderOpState
<< "->addAttribute(pair.first, pair.second);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << "->addRegion();\n";
}
// 3. Deduced result types
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
@ -883,9 +895,6 @@ void OpEmitter::genVerifier() {
auto valueInit = def.getValueInit("verifier");
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0 &&
op.getNumPredOpTraits() == 0)
return;
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
@ -972,6 +981,13 @@ void OpEmitter::genVerifier() {
}
}
// Verify this op has the correct number of regions
body << formatv(
" if (this->getOperation()->getNumRegions() != {0}) \n return "
"emitOpError(\"has incorrect number of regions: expected {0} but found "
"\") << this->getOperation()->getNumRegions();\n",
op.getNumRegions());
if (hasCustomVerify)
body << codeInit->getValue() << "\n";
else