diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 45f18b53154c..185a391105a1 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -100,10 +100,20 @@ class TypeConstraint { string description = descr; } -// A type, carries type constraints, but accepts any type by default. +// A type, carries type constraints. class Type : TypeConstraint; +// A variadic type. It expands to zero or more of the base type. +// This class is used for supporting variadic operands/results. An op can +// declare no more than one variadic operand/result, and that operand/result +// must be the last one in the operand/result list. +class Variadic + // TODO: support variadic type conditions + : Type, descr> { + Type baseType = type; +} + // A type that can be constructed using MLIR::Builder. // Note that this does not "inherit" from Type because it would require // duplicating Type subclasses for buildable and non-buildable cases to avoid @@ -352,11 +362,6 @@ class OpTrait { string trait = prop; } -// Note: Ideally, we should be able to automatically deduce most of these traits -// from other bits of op definitions, especially those regarding the number of -// operands and results. -class AtLeastNOperands : OpTrait<"AtLeastNOperands<" # c # ">::Impl">; - // op supports operand broadcast behavior def Broadcastable : OpTrait<"BroadcastableTwoOperandsOneResult">; // X op Y == Y op X @@ -367,8 +372,6 @@ def NoSideEffect : OpTrait<"HasNoSideEffect">; def SameValueType : OpTrait<"SameOperandsAndResultType">; // op is a terminator def Terminator : OpTrait<"IsTerminator">; -// op has an unknown number of operands -def VariadicOperands : OpTrait<"VariadicOperands">; //===----------------------------------------------------------------------===// // Ops diff --git a/mlir/include/mlir/LLVMIR/llvm_ops.td b/mlir/include/mlir/LLVMIR/llvm_ops.td index 378e3ea01c72..bb039f245c80 100644 --- a/mlir/include/mlir/LLVMIR/llvm_ops.td +++ b/mlir/include/mlir/LLVMIR/llvm_ops.td @@ -89,8 +89,8 @@ class LLVM_ZeroResultOp traits = []> : // Base class for LLVM terminator operations. All terminator operations have // zero results and an optional list of successors. class LLVM_TerminatorOp traits = []> : - LLVM_Op, - Results<(outs)> { + LLVM_Op, + Arguments<(ins Variadic)>, Results<(outs)> { let builder = [{ static void build(Builder *builder, OperationState *result, ArrayRef properOperands, @@ -117,7 +117,7 @@ class LLVM_ArithmeticOp traits = []> : // Class for variadic instructions. class LLVM_VariadicOneResultOp traits = []> : - LLVM_OneResultOp; + LLVM_OneResultOp, Arguments<(ins Variadic)>; // Integer binary instructions. def LLVM_AddOp : LLVM_ArithmeticOp<"add", [Commutative]>; @@ -151,7 +151,8 @@ def LLVM_BitcastOp : LLVM_OneResultOp<"bitcast", [NoSideEffect]>, // Call-related instructions. def LLVM_CallOp : LLVM_VariadicOneResultOp<"call">; -def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", [VariadicOperands]>; +def LLVM_Call0Op : LLVM_ZeroResultOp<"call0", []>, + Arguments<(ins Variadic)>; def LLVM_ExtractValueOp : LLVM_OneResultOp<"extractvalue", [NoSideEffect]>, Arguments<(ins LLVM_Type)>; def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>, diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index a75b909a9d50..bc34bca5946e 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -97,6 +97,9 @@ public: Operand &getOperand(int index) { return operands[index]; } const Operand &getOperand(int index) const { return operands[index]; } + // Returns true if this operation has a variadic operand. + bool hasVariadicOperand() const; + // Op argument (attribute or operand) accessors. Argument getArg(int index); StringRef getArgName(int index) const; diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h index 247e0fc8e4b8..9b6ba99ebd57 100644 --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -71,6 +71,13 @@ public: explicit Type(const llvm::Record *record) : Type(*record) {} explicit Type(const llvm::DefInit *init); + // Returns true if this is a variadic type. + bool isVariadic() const; + + // Gets the base type of this variadic type. + // Precondition: This type is a variadic type. + Type getVariadicBaseType() const; + // Returns the TableGen def name for this type. StringRef getTableGenDefName() const; }; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 21d855a4b182..390b4870599b 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -84,6 +84,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const { return attributes[index]; } +bool tblgen::Operator::hasVariadicOperand() const { + return !operands.empty() && operands.back().type.isVariadic(); +} + StringRef tblgen::Operator::getArgName(int index) const { DagInit *argumentValues = def.getValueAsDag("arguments"); return argumentValues->getArgName(index)->getValue(); @@ -179,6 +183,12 @@ void tblgen::Operator::populateOperandsAndAttributes() { Attribute(cast(val.getValue()))}); } } + + for (int i = 0, e = operands.size() - 1; i < e; ++i) { + if (operands[i].type.isVariadic()) + PrintFatalError(def.getLoc(), + "only the last operand allowed to be variadic"); + } } ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index 9fb7124a01a0..9067eedc9752 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -61,3 +61,10 @@ tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) { tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {} StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); } + +bool tblgen::Type::isVariadic() const { return def->isSubClassOf("Variadic"); } + +tblgen::Type tblgen::Type::getVariadicBaseType() const { + assert(isVariadic() && "must be variadic type"); + return Type(def->getValueAsDef("baseType")); +} diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 5264a9e15f53..d282dfc5058d 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -228,7 +228,7 @@ void OpEmitter::emitNamedOperands() { )"; for (int i = 0, e = op.getNumOperands(); i != e; ++i) { const auto &operand = op.getOperand(i); - if (!operand.name.empty()) + if (!operand.type.isVariadic() && !operand.name.empty()) os << formatv(operandMethods, operand.name, i); } } @@ -260,8 +260,11 @@ void OpEmitter::emitBuilder() { os << ", Type returnType" << i; // Emit parameters for all operands - for (int i = 0, e = op.getNumOperands(); i != e; ++i) - os << ", Value* " << getArgumentName(op, i); + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + auto &operand = op.getOperand(i); + os << (operand.type.isVariadic() ? ", ArrayRef " : ", Value* ") + << getArgumentName(op, i); + } // Emit parameters for all attributes // TODO(antiagainst): Support default initializer for attributes @@ -283,12 +286,20 @@ void OpEmitter::emitBuilder() { } // Push all operands to the result - if (op.getNumOperands() > 0) { + auto numOperands = op.getNumOperands(); + bool hasVariadicOperand = op.hasVariadicOperand(); + int numNonVariadicOperands = numOperands - int(hasVariadicOperand); + if (numNonVariadicOperands > 0) { OUT(4) << "result->addOperands({" << getArgumentName(op, 0); - for (int i = 1, e = op.getNumOperands(); i != e; ++i) + for (int i = 1, e = numNonVariadicOperands; i < e; ++i) { os << ", " << getArgumentName(op, i); + } os << "});\n"; } + if (hasVariadicOperand) { + OUT(4) << "result->addOperands(" << getArgumentName(op, numOperands - 1) + << ");\n"; + } // Push all attributes to the result for (const auto &namedAttr : op.getAttributes()) @@ -310,7 +321,7 @@ void OpEmitter::emitBuilder() { << " result->addTypes(resultTypes);\n"; // Operands - OUT(4) << "assert(args.size() == " << op.getNumOperands() + OUT(4) << "assert(args.size() == " << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n" << " result->addOperands(args);\n\n"; @@ -422,9 +433,12 @@ void OpEmitter::emitVerifier() { OUT(4) << "}\n"; } - // TODO: Handle variadic. int opIndex = 0; for (const auto &operand : op.getOperands()) { + // TODO: Handle variadic operand verification. + if (operand.type.isVariadic()) + continue; + // TODO: Commonality between matchers could be extracted to have a more // concise code. if (operand.hasMatcher()) { @@ -466,38 +480,33 @@ void OpEmitter::emitTraits() { break; } - // Track explicitly added operand size traits. Note that some ops might - // implicitly defines the number of operands via the Argument dag. - bool hasVariadicOperands = false; - bool hasAtLeastNOperands = false; - // Add variadic size trait and normal op traits. for (StringRef trait : def.getValueAsListOfStrings("traits")) { - if (trait == "VariadicOperands") { - hasVariadicOperands = true; - } else if (trait.startswith("AtLeastNOperands")) { - hasAtLeastNOperands = true; - } os << ", OpTrait::" << trait; } - if ((hasVariadicOperands || hasAtLeastNOperands) && op.getNumOperands() > 0) { - PrintFatalError(def.getLoc(), - "Operands number definition is not consistent."); - } + auto numOperands = op.getNumOperands(); + bool hasVariadicOperand = op.hasVariadicOperand(); // Add operand size trait. - switch (op.getNumOperands()) { - case 0: - if (!hasVariadicOperands && !hasAtLeastNOperands) - os << ", OpTrait::ZeroOperands"; - break; - case 1: - os << ", OpTrait::OneOperand"; - break; - default: - os << ", OpTrait::NOperands<" << op.getNumOperands() << ">::Impl"; - break; + os << ", OpTrait::"; + if (hasVariadicOperand) { + if (numOperands == 1) + os << "VariadicOperands"; + else + os << "AtLeastNOperands<" << (numOperands - 1) << ">::Impl"; + } else { + switch (op.getNumOperands()) { + case 0: + os << "ZeroOperands"; + break; + case 1: + os << "OneOperand"; + break; + default: + os << "NOperands<" << numOperands << ">::Impl"; + break; + } } }