[TableGen] Model variadic operands using Variadic<Type>

Previously, we were using the trait mechanism to specify that an op has variadic operands.
That led a discrepancy between how we handle ops with deterministic number of operands.
Besides, we have no way to specify the constraints and match against the variadic operands.

This CL introduced Variadic<Type> as a way to solve the above issues.

PiperOrigin-RevId: 232656104
This commit is contained in:
Lei Zhang 2019-02-06 05:06:11 -08:00 committed by jpienaar
parent 0c65cf283c
commit 1df6ca5053
7 changed files with 84 additions and 44 deletions

View file

@ -100,10 +100,20 @@ class TypeConstraint<Pred condition, string descr> {
string description = descr;
}
// A type, carries type constraints, but accepts any type by default.
// A type, carries type constraints.
class Type<Pred condition, string descr = "">
: TypeConstraint<condition, descr>;
// A variadic type. It expands to zero or more of the base type.
// This class is used for supporting variadic operands/results. An op can
// declare no more than one variadic operand/result, and that operand/result
// must be the last one in the operand/result list.
class Variadic<Type type, string descr = "">
// TODO: support variadic type conditions
: Type<CPred<"true">, descr> {
Type baseType = type;
}
// A type that can be constructed using MLIR::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
@ -352,11 +362,6 @@ class OpTrait<string prop> {
string trait = prop;
}
// Note: Ideally, we should be able to automatically deduce most of these traits
// from other bits of op definitions, especially those regarding the number of
// operands and results.
class AtLeastNOperands<int c> : OpTrait<"AtLeastNOperands<" # c # ">::Impl">;
// op supports operand broadcast behavior
def Broadcastable : OpTrait<"BroadcastableTwoOperandsOneResult">;
// X op Y == Y op X
@ -367,8 +372,6 @@ def NoSideEffect : OpTrait<"HasNoSideEffect">;
def SameValueType : OpTrait<"SameOperandsAndResultType">;
// op is a terminator
def Terminator : OpTrait<"IsTerminator">;
// op has an unknown number of operands
def VariadicOperands : OpTrait<"VariadicOperands">;
//===----------------------------------------------------------------------===//
// Ops

View file

@ -89,8 +89,8 @@ class LLVM_ZeroResultOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for LLVM terminator operations. All terminator operations have
// zero results and an optional list of successors.
class LLVM_TerminatorOp<string mnemonic, list<OpTrait> traits = []> :
LLVM_Op<mnemonic, !listconcat(traits, [Terminator, VariadicOperands])>,
Results<(outs)> {
LLVM_Op<mnemonic, !listconcat(traits, [Terminator])>,
Arguments<(ins Variadic<LLVM_Type>)>, Results<(outs)> {
let builder = [{
static void build(Builder *builder, OperationState *result,
ArrayRef<Value *> properOperands,
@ -117,7 +117,7 @@ class LLVM_ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
// Class for variadic instructions.
class LLVM_VariadicOneResultOp<string mnemonic, list<OpTrait> traits = []> :
LLVM_OneResultOp<mnemonic, !listconcat(traits, [VariadicOperands])>;
LLVM_OneResultOp<mnemonic, traits>, Arguments<(ins Variadic<LLVM_Type>)>;
// Integer binary instructions.
def LLVM_AddOp : LLVM_ArithmeticOp<"add", [Commutative]>;
@ -151,7 +151,8 @@ def LLVM_BitcastOp : LLVM_OneResultOp<"bitcast", [NoSideEffect]>,
// Call-related instructions.
def LLVM_CallOp : LLVM_VariadicOneResultOp<"call">;
def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", [VariadicOperands]>;
def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", []>,
Arguments<(ins Variadic<LLVM_Type>)>;
def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>,
Arguments<(ins LLVM_Type)>;
def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,

View file

@ -97,6 +97,9 @@ public:
Operand &getOperand(int index) { return operands[index]; }
const Operand &getOperand(int index) const { return operands[index]; }
// Returns true if this operation has a variadic operand.
bool hasVariadicOperand() const;
// Op argument (attribute or operand) accessors.
Argument getArg(int index);
StringRef getArgName(int index) const;

View file

@ -71,6 +71,13 @@ public:
explicit Type(const llvm::Record *record) : Type(*record) {}
explicit Type(const llvm::DefInit *init);
// Returns true if this is a variadic type.
bool isVariadic() const;
// Gets the base type of this variadic type.
// Precondition: This type is a variadic type.
Type getVariadicBaseType() const;
// Returns the TableGen def name for this type.
StringRef getTableGenDefName() const;
};

View file

@ -84,6 +84,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
return attributes[index];
}
bool tblgen::Operator::hasVariadicOperand() const {
return !operands.empty() && operands.back().type.isVariadic();
}
StringRef tblgen::Operator::getArgName(int index) const {
DagInit *argumentValues = def.getValueAsDag("arguments");
return argumentValues->getArgName(index)->getValue();
@ -179,6 +183,12 @@ void tblgen::Operator::populateOperandsAndAttributes() {
Attribute(cast<DefInit>(val.getValue()))});
}
}
for (int i = 0, e = operands.size() - 1; i < e; ++i) {
if (operands[i].type.isVariadic())
PrintFatalError(def.getLoc(),
"only the last operand allowed to be variadic");
}
}
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }

View file

@ -61,3 +61,10 @@ tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {}
StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); }
bool tblgen::Type::isVariadic() const { return def->isSubClassOf("Variadic"); }
tblgen::Type tblgen::Type::getVariadicBaseType() const {
assert(isVariadic() && "must be variadic type");
return Type(def->getValueAsDef("baseType"));
}

View file

@ -228,7 +228,7 @@ void OpEmitter::emitNamedOperands() {
)";
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const auto &operand = op.getOperand(i);
if (!operand.name.empty())
if (!operand.type.isVariadic() && !operand.name.empty())
os << formatv(operandMethods, operand.name, i);
}
}
@ -260,8 +260,11 @@ void OpEmitter::emitBuilder() {
os << ", Type returnType" << i;
// Emit parameters for all operands
for (int i = 0, e = op.getNumOperands(); i != e; ++i)
os << ", Value* " << getArgumentName(op, i);
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
auto &operand = op.getOperand(i);
os << (operand.type.isVariadic() ? ", ArrayRef<Value*> " : ", Value* ")
<< getArgumentName(op, i);
}
// Emit parameters for all attributes
// TODO(antiagainst): Support default initializer for attributes
@ -283,12 +286,20 @@ void OpEmitter::emitBuilder() {
}
// Push all operands to the result
if (op.getNumOperands() > 0) {
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
if (numNonVariadicOperands > 0) {
OUT(4) << "result->addOperands({" << getArgumentName(op, 0);
for (int i = 1, e = op.getNumOperands(); i != e; ++i)
for (int i = 1, e = numNonVariadicOperands; i < e; ++i) {
os << ", " << getArgumentName(op, i);
}
os << "});\n";
}
if (hasVariadicOperand) {
OUT(4) << "result->addOperands(" << getArgumentName(op, numOperands - 1)
<< ");\n";
}
// Push all attributes to the result
for (const auto &namedAttr : op.getAttributes())
@ -310,7 +321,7 @@ void OpEmitter::emitBuilder() {
<< " result->addTypes(resultTypes);\n";
// Operands
OUT(4) << "assert(args.size() == " << op.getNumOperands()
OUT(4) << "assert(args.size() == " << numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n"
<< " result->addOperands(args);\n\n";
@ -422,9 +433,12 @@ void OpEmitter::emitVerifier() {
OUT(4) << "}\n";
}
// TODO: Handle variadic.
int opIndex = 0;
for (const auto &operand : op.getOperands()) {
// TODO: Handle variadic operand verification.
if (operand.type.isVariadic())
continue;
// TODO: Commonality between matchers could be extracted to have a more
// concise code.
if (operand.hasMatcher()) {
@ -466,38 +480,33 @@ void OpEmitter::emitTraits() {
break;
}
// Track explicitly added operand size traits. Note that some ops might
// implicitly defines the number of operands via the Argument dag.
bool hasVariadicOperands = false;
bool hasAtLeastNOperands = false;
// Add variadic size trait and normal op traits.
for (StringRef trait : def.getValueAsListOfStrings("traits")) {
if (trait == "VariadicOperands") {
hasVariadicOperands = true;
} else if (trait.startswith("AtLeastNOperands")) {
hasAtLeastNOperands = true;
}
os << ", OpTrait::" << trait;
}
if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) {
PrintFatalError(def.getLoc(),
"Operands number definition is not consistent.");
}
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
// Add operand size trait.
switch (op.getNumOperands()) {
case 0:
if (!hasVariadicOperands && !hasAtLeastNOperands)
os << ", OpTrait::ZeroOperands";
break;
case 1:
os << ", OpTrait::OneOperand";
break;
default:
os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl";
break;
os << ", OpTrait::";
if (hasVariadicOperand) {
if (numOperands == 1)
os << "VariadicOperands";
else
os << "AtLeastNOperands<" << (numOperands - 1) << ">::Impl";
} else {
switch (op.getNumOperands()) {
case 0:
os << "ZeroOperands";
break;
case 1:
os << "OneOperand";
break;
default:
os << "NOperands<" << numOperands << ">::Impl";
break;
}
}
}