[mlir][Linalg] Allow all build methods of Structured ops to specify additional attributes.

Differential Revision: https://reviews.llvm.org/D108338
This commit is contained in:
MaheshRavishankar 2021-08-23 10:15:35 -07:00
parent 19dc02e99f
commit 4aeeb91a92
5 changed files with 37 additions and 17 deletions

View file

@ -620,18 +620,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps, "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc, "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
"StringRef":$libraryCall, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>, CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes, "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
"StringRef":$doc, "StringRef":$libraryCall, "StringRef":$doc, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>, CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps, "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<StringRef>":$iteratorTypes, "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>, CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers, OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes, "ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)> CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
]; ];
let extraClassDeclaration = structuredOpsBaseDecls # [{ let extraClassDeclaration = structuredOpsBaseDecls # [{

View file

@ -502,13 +502,15 @@ void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultTensorTypes, inputs, outputs, build(builder, result, resultTensorTypes, inputs, outputs,
builder.getAffineMapArrayAttr(indexingMaps), builder.getAffineMapArrayAttr(indexingMaps),
builder.getStrArrayAttr(iteratorTypes), builder.getStrArrayAttr(iteratorTypes),
doc.empty() ? StringAttr() : builder.getStringAttr(doc), doc.empty() ? StringAttr() : builder.getStringAttr(doc),
libraryCall.empty() ? StringAttr() libraryCall.empty() ? StringAttr()
: builder.getStringAttr(libraryCall)); : builder.getStringAttr(libraryCall));
result.addAttributes(attributes);
if (!bodyBuild) if (!bodyBuild)
return; return;
@ -527,30 +529,33 @@ void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs, OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps, build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall, bodyBuild); iteratorTypes, doc, libraryCall, bodyBuild, attributes);
} }
void GenericOp::build( void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs, OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes, build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
/*doc=*/"", /*doc=*/"",
/*libraryCall=*/"", bodyBuild); /*libraryCall=*/"", bodyBuild, attributes);
} }
void GenericOp::build( void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps, ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps, build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes, iteratorTypes,
/*doc=*/"", /*doc=*/"",
/*libraryCall=*/"", bodyBuild); /*libraryCall=*/"", bodyBuild, attributes);
} }
static void print(OpAsmPrinter &p, GenericOp op) { static void print(OpAsmPrinter &p, GenericOp op) {

View file

@ -169,7 +169,8 @@ It has one output.
// ODS-LABEL: def Test7Op // ODS-LABEL: def Test7Op
// ODS: OpBuilder< // ODS: OpBuilder<
// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, // ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b) // ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b,
// ODS: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)
// ODS: $_state.addAttribute("attr_a", attr_a); // ODS: $_state.addAttribute("attr_a", attr_a);
// ODS: $_state.addAttribute("attr_b", attr_b); // ODS: $_state.addAttribute("attr_b", attr_b);
// //

View file

@ -1910,7 +1910,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilder< OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs), (ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{ [{{
$_state.addOperands(inputs); $_state.addOperands(inputs);
$_state.addOperands(outputs); $_state.addOperands(outputs);
@ -1919,6 +1920,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
$_state.addAttributes(attributes);
createAndFillStructuredOpRegion<{0}>( createAndFillStructuredOpRegion<{0}>(
$_builder, $_builder,
$_state, $_state,
@ -1927,7 +1929,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
}]>, }]>,
OpBuilder< OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs), "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{ [{{
$_state.addOperands(inputs); $_state.addOperands(inputs);
$_state.addOperands(outputs); $_state.addOperands(outputs);
@ -1937,6 +1940,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
$_state.addAttributes(attributes);
createAndFillStructuredOpRegion<{0}>( createAndFillStructuredOpRegion<{0}>(
$_builder, $_builder,
$_state, $_state,
@ -2020,7 +2024,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
const char *builderFmt = R"FMT( const char *builderFmt = R"FMT(
, OpBuilder< , OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, {1}), "ValueRange":$outputs, {1},
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{ [{{
$_state.addOperands(inputs); $_state.addOperands(inputs);
$_state.addOperands(outputs); $_state.addOperands(outputs);
@ -2030,6 +2035,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
$_state.addAttributes(attributes);
createAndFillStructuredOpRegion<{0}>( createAndFillStructuredOpRegion<{0}>(
$_builder, $_builder,
$_state, $_state,

View file

@ -457,7 +457,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
let skipDefaultBuilders = 1; let skipDefaultBuilders = 1;
let builders = [ let builders = [
OpBuilder< OpBuilder<
(ins "ValueRange":$inputs, "ValueRange":$outputs), (ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{ [{{
$_state.addOperands(inputs); $_state.addOperands(inputs);
$_state.addOperands(outputs); $_state.addOperands(outputs);
@ -471,6 +472,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
$_builder.getI32VectorAttr({{ $_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
$_state.addAttributes(attributes);
createAndFillStructuredOpRegion<{0}>( createAndFillStructuredOpRegion<{0}>(
$_builder, $_builder,
$_state, $_state,
@ -539,7 +541,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
static const char structuredOpBuilderFormat[] = R"FMT( static const char structuredOpBuilderFormat[] = R"FMT(
, OpBuilder< , OpBuilder<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, {1}), "ValueRange":$outputs, {1},
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
[{{ [{{
$_state.addOperands(inputs); $_state.addOperands(inputs);
$_state.addOperands(outputs); $_state.addOperands(outputs);
@ -555,6 +558,7 @@ static const char structuredOpBuilderFormat[] = R"FMT(
TypeRange(inputs), TypeRange(inputs),
TypeRange(outputs)); TypeRange(outputs));
{2} {2}
$_state.addAttributes(attributes);
}]> }]>
)FMT"; )FMT";