[mlir][Linalg] Allow all build methods of Structured ops to specify additional attributes.
Differential Revision: https://reviews.llvm.org/D108338
This commit is contained in:
parent
19dc02e99f
commit
4aeeb91a92
|
@ -620,18 +620,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
|
||||||
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
|
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
|
||||||
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
|
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
|
||||||
"StringRef":$libraryCall,
|
"StringRef":$libraryCall,
|
||||||
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
|
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
|
||||||
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
|
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
|
||||||
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
|
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
|
||||||
"StringRef":$doc, "StringRef":$libraryCall,
|
"StringRef":$doc, "StringRef":$libraryCall,
|
||||||
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
|
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
|
||||||
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
|
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
|
||||||
"ArrayRef<StringRef>":$iteratorTypes,
|
"ArrayRef<StringRef>":$iteratorTypes,
|
||||||
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>,
|
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
|
||||||
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
|
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
|
||||||
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
|
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
|
||||||
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">)>
|
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
|
||||||
];
|
];
|
||||||
|
|
||||||
let extraClassDeclaration = structuredOpsBaseDecls # [{
|
let extraClassDeclaration = structuredOpsBaseDecls # [{
|
||||||
|
|
|
@ -502,13 +502,15 @@ void GenericOp::build(
|
||||||
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
||||||
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
||||||
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
||||||
|
ArrayRef<NamedAttribute> attributes) {
|
||||||
build(builder, result, resultTensorTypes, inputs, outputs,
|
build(builder, result, resultTensorTypes, inputs, outputs,
|
||||||
builder.getAffineMapArrayAttr(indexingMaps),
|
builder.getAffineMapArrayAttr(indexingMaps),
|
||||||
builder.getStrArrayAttr(iteratorTypes),
|
builder.getStrArrayAttr(iteratorTypes),
|
||||||
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
|
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
|
||||||
libraryCall.empty() ? StringAttr()
|
libraryCall.empty() ? StringAttr()
|
||||||
: builder.getStringAttr(libraryCall));
|
: builder.getStringAttr(libraryCall));
|
||||||
|
result.addAttributes(attributes);
|
||||||
if (!bodyBuild)
|
if (!bodyBuild)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
@ -527,30 +529,33 @@ void GenericOp::build(
|
||||||
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
||||||
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
||||||
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
||||||
|
ArrayRef<NamedAttribute> attributes) {
|
||||||
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
|
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
|
||||||
iteratorTypes, doc, libraryCall, bodyBuild);
|
iteratorTypes, doc, libraryCall, bodyBuild, attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenericOp::build(
|
void GenericOp::build(
|
||||||
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
||||||
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
||||||
ArrayRef<StringRef> iteratorTypes,
|
ArrayRef<StringRef> iteratorTypes,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
||||||
|
ArrayRef<NamedAttribute> attributes) {
|
||||||
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
|
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
|
||||||
/*doc=*/"",
|
/*doc=*/"",
|
||||||
/*libraryCall=*/"", bodyBuild);
|
/*libraryCall=*/"", bodyBuild, attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenericOp::build(
|
void GenericOp::build(
|
||||||
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
||||||
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
||||||
ArrayRef<StringRef> iteratorTypes,
|
ArrayRef<StringRef> iteratorTypes,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
||||||
|
ArrayRef<NamedAttribute> attributes) {
|
||||||
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
||||||
iteratorTypes,
|
iteratorTypes,
|
||||||
/*doc=*/"",
|
/*doc=*/"",
|
||||||
/*libraryCall=*/"", bodyBuild);
|
/*libraryCall=*/"", bodyBuild, attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, GenericOp op) {
|
static void print(OpAsmPrinter &p, GenericOp op) {
|
||||||
|
|
|
@ -169,7 +169,8 @@ It has one output.
|
||||||
// ODS-LABEL: def Test7Op
|
// ODS-LABEL: def Test7Op
|
||||||
// ODS: OpBuilder<
|
// ODS: OpBuilder<
|
||||||
// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b)
|
// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b,
|
||||||
|
// ODS: CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)
|
||||||
// ODS: $_state.addAttribute("attr_a", attr_a);
|
// ODS: $_state.addAttribute("attr_a", attr_a);
|
||||||
// ODS: $_state.addAttribute("attr_b", attr_b);
|
// ODS: $_state.addAttribute("attr_b", attr_b);
|
||||||
//
|
//
|
||||||
|
|
|
@ -1910,7 +1910,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<
|
OpBuilder<
|
||||||
(ins "ValueRange":$inputs, "ValueRange":$outputs),
|
(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
||||||
[{{
|
[{{
|
||||||
$_state.addOperands(inputs);
|
$_state.addOperands(inputs);
|
||||||
$_state.addOperands(outputs);
|
$_state.addOperands(outputs);
|
||||||
|
@ -1919,6 +1920,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
|
$_state.addAttributes(attributes);
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
createAndFillStructuredOpRegion<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
|
@ -1927,7 +1929,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<
|
OpBuilder<
|
||||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
"ValueRange":$outputs),
|
"ValueRange":$outputs,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
||||||
[{{
|
[{{
|
||||||
$_state.addOperands(inputs);
|
$_state.addOperands(inputs);
|
||||||
$_state.addOperands(outputs);
|
$_state.addOperands(outputs);
|
||||||
|
@ -1937,6 +1940,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
|
$_state.addAttributes(attributes);
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
createAndFillStructuredOpRegion<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
|
@ -2020,7 +2024,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
const char *builderFmt = R"FMT(
|
const char *builderFmt = R"FMT(
|
||||||
, OpBuilder<
|
, OpBuilder<
|
||||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
"ValueRange":$outputs, {1}),
|
"ValueRange":$outputs, {1},
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
||||||
[{{
|
[{{
|
||||||
$_state.addOperands(inputs);
|
$_state.addOperands(inputs);
|
||||||
$_state.addOperands(outputs);
|
$_state.addOperands(outputs);
|
||||||
|
@ -2030,6 +2035,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
|
$_state.addAttributes(attributes);
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
createAndFillStructuredOpRegion<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
|
|
|
@ -457,7 +457,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
||||||
let skipDefaultBuilders = 1;
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<
|
OpBuilder<
|
||||||
(ins "ValueRange":$inputs, "ValueRange":$outputs),
|
(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
||||||
[{{
|
[{{
|
||||||
$_state.addOperands(inputs);
|
$_state.addOperands(inputs);
|
||||||
$_state.addOperands(outputs);
|
$_state.addOperands(outputs);
|
||||||
|
@ -471,6 +472,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
||||||
$_builder.getI32VectorAttr({{
|
$_builder.getI32VectorAttr({{
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
|
$_state.addAttributes(attributes);
|
||||||
createAndFillStructuredOpRegion<{0}>(
|
createAndFillStructuredOpRegion<{0}>(
|
||||||
$_builder,
|
$_builder,
|
||||||
$_state,
|
$_state,
|
||||||
|
@ -539,7 +541,8 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
|
||||||
static const char structuredOpBuilderFormat[] = R"FMT(
|
static const char structuredOpBuilderFormat[] = R"FMT(
|
||||||
, OpBuilder<
|
, OpBuilder<
|
||||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||||
"ValueRange":$outputs, {1}),
|
"ValueRange":$outputs, {1},
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
||||||
[{{
|
[{{
|
||||||
$_state.addOperands(inputs);
|
$_state.addOperands(inputs);
|
||||||
$_state.addOperands(outputs);
|
$_state.addOperands(outputs);
|
||||||
|
@ -555,6 +558,7 @@ static const char structuredOpBuilderFormat[] = R"FMT(
|
||||||
TypeRange(inputs),
|
TypeRange(inputs),
|
||||||
TypeRange(outputs));
|
TypeRange(outputs));
|
||||||
{2}
|
{2}
|
||||||
|
$_state.addAttributes(attributes);
|
||||||
}]>
|
}]>
|
||||||
)FMT";
|
)FMT";
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue