[mlir][ods] Add support for custom directive in attr/type formats

This patch adds support for custom directives in attribute and type formats. Custom directives dispatch calls to user-defined parser and printer functions.

For example, the assembly format "custom<Foo>($foo, ref($bar))" expects a function with the signature

```
LogicalResult parseFoo(AsmParser &parser, FailureOr<FooT> &foo, BarT bar);
void printFoo(AsmPrinter &printer, FooT foo, BarT bar);
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D120944
This commit is contained in:
Mogball 2022-03-15 07:12:37 +00:00
parent 6143ec2961
commit 4767e26775
9 changed files with 308 additions and 32 deletions

View file

@ -558,6 +558,8 @@ Attribute and type assembly formats have the following directives:
mnemonic. mnemonic.
* `struct`: generate a "struct-like" parser and printer for a list of * `struct`: generate a "struct-like" parser and printer for a list of
key-value pairs. key-value pairs.
* `custom`: dispatch a call to user-define parser and printer functions
* `ref`: in a custom directive, references a previously bound variable
#### `params` Directive #### `params` Directive
@ -649,3 +651,44 @@ assembly format of `` `<` struct(params) `>` `` will result in:
The order in which the parameters are printed is the order in which they are The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameter` list. declared in the attribute's or type's `parameter` list.
#### `custom` and `ref` directive
The `custom` directive is used to dispatch calls to user-defined printer and
parser functions. For example, suppose we had the following type:
```tablegen
let parameters = (ins "int":$foo, "int":$bar);
let assemblyFormat = "custom<Foo>($foo) custom<Bar>($bar, ref($foo))";
```
The `custom` directive `custom<Foo>($foo)` will in the parser and printer
respectively generate calls to:
```c++
LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo);
void printFoo(AsmPrinter &printer, int foo);
```
A previously bound variable can be passed as a parameter to a `custom` directive
by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by
the first directive. The second directive references it and expects the
following printer and parser signatures:
```c++
LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo);
void printBar(AsmPrinter &printer, int bar, int foo);
```
More complex C++ types can be used with the `custom` directive. The only caveat
is that the parameter for the parser must use the storage type of the parameter.
For example, `StringRefParameter` expects the parser and printer signatures as:
```c++
LogicalResult parseStringParam(AsmParser &parser,
FailureOr<std::string> &value);
void printStringParam(AsmPrinter &printer, StringRef value);
```
The custom parser is considered to have failed if it returns failure or if any
bound parameters have failure values afterwards.

View file

@ -363,4 +363,18 @@ def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
let assemblyFormat = "`<` (`(` $type^ `)`)? `>`"; let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
} }
def TestTypeCustom : Test_Type<"TestTypeCustom"> {
let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
let mnemonic = "custom_type";
let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
custom<CustomTypeB>(ref($a), $b) `>` }];
}
def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
let parameters = (ins StringRefParameter<>:$foo);
let mnemonic = "custom_type_string";
let assemblyFormat = [{ `<` custom<FooString>($foo)
custom<BarString>(ref($foo)) `>` }];
}
#endif // TEST_TYPEDEFS #endif // TEST_TYPEDEFS

View file

@ -208,6 +208,59 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
return 1; return 1;
} }
//===----------------------------------------------------------------------===//
// TestCustomType
//===----------------------------------------------------------------------===//
static LogicalResult parseCustomTypeA(AsmParser &parser,
FailureOr<int> &a_result) {
a_result.emplace();
return parser.parseInteger(*a_result);
}
static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
FailureOr<Optional<int>> &b_result) {
if (a < 0)
return success();
for (int i : llvm::seq(0, a))
if (failed(parser.parseInteger(i)))
return failure();
b_result.emplace(0);
return parser.parseInteger(**b_result);
}
static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
if (a < 0)
return;
printer << ' ';
for (int i : llvm::seq(0, a))
printer << i << ' ';
printer << *b;
}
static LogicalResult parseFooString(AsmParser &parser,
FailureOr<std::string> &foo) {
std::string result;
if (parser.parseString(&result))
return failure();
foo = std::move(result);
return success();
}
static void printFooString(AsmPrinter &printer, StringRef foo) {
printer << '"' << foo << '"';
}
static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
return parser.parseKeyword(foo);
}
static void printBarString(AsmPrinter &printer, StringRef foo) {
printer << ' ' << foo;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tablegen Generated Definitions // Tablegen Generated Definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -107,3 +107,27 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
// CHECK: optional group anchor must be a parameter or directive // CHECK: optional group anchor must be a parameter or directive
let assemblyFormat = "(`(` $a `)`^)?"; let assemblyFormat = "(`(` $a `)`^)?";
} }
def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
let parameters = (ins "int":$a);
// CHECK: `ref` is only allowed inside custom directives
let assemblyFormat = "$a ref($a)";
}
def InvalidTypeP : InvalidType<"InvalidTypeP", "invalid_p"> {
let parameters = (ins "int":$a);
// CHECK: parameter 'a' must be bound before it is referenced
let assemblyFormat = "custom<Foo>(ref($a)) $a";
}
def InvalidTypeQ : InvalidType<"InvalidTypeQ", "invalid_q"> {
let parameters = (ins "int":$a);
// CHECK: `params` can only be used at the top-level context or within a `struct` directive
let assemblyFormat = "custom<Foo>(params)";
}
def InvalidTypeR : InvalidType<"InvalidTypeR", "invalid_r"> {
let parameters = (ins "int":$a);
// CHECK: `struct` can only be used at the top-level context
let assemblyFormat = "custom<Foo>(struct(params))";
}

View file

@ -48,6 +48,10 @@ attributes {
// CHECK: !test.ap_float<> // CHECK: !test.ap_float<>
// CHECK: !test.default_valued_type<(i64)> // CHECK: !test.default_valued_type<(i64)>
// CHECK: !test.default_valued_type<> // CHECK: !test.default_valued_type<>
// CHECK: !test.custom_type<-5>
// CHECK: !test.custom_type<2 0 1 5>
// CHECK: !test.custom_type_string<"foo" foo>
// CHECK: !test.custom_type_string<"bar" bar>
func private @test_roundtrip_default_parsers_struct( func private @test_roundtrip_default_parsers_struct(
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4> !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@ -79,5 +83,9 @@ func private @test_roundtrip_default_parsers_struct(
!test.ap_float<5.0>, !test.ap_float<5.0>,
!test.ap_float<>, !test.ap_float<>,
!test.default_valued_type<(i64)>, !test.default_valued_type<(i64)>,
!test.default_valued_type<> !test.default_valued_type<>,
!test.custom_type<-5>,
!test.custom_type<2 9 9 5>,
!test.custom_type_string<"foo" foo>,
!test.custom_type_string<"bar" bar>
) )

View file

@ -499,3 +499,27 @@ def TypeI : TestType<"TestK"> {
let mnemonic = "type_k"; let mnemonic = "type_k";
let assemblyFormat = "$a"; let assemblyFormat = "$a";
} }
// TYPE: ::mlir::Type TestLType::parse
// TYPE: auto odsCustomLoc = odsParser.getCurrentLocation()
// TYPE: auto odsCustomResult = parseA(odsParser,
// TYPE-NEXT: _result_a
// TYPE: if (::mlir::failed(odsCustomResult)) return {}
// TYPE: if (::mlir::failed(_result_a))
// TYPE-NEXT: odsParser.emitError(odsCustomLoc,
// TYPE: auto odsCustomResult = parseB(odsParser,
// TYPE-NEXT: _result_b
// TYPE-NEXT: *_result_a
// TYPE: void TestLType::print
// TYPE: printA(odsPrinter
// TYPE-NEXT: getA()
// TYPE: printB(odsPrinter
// TYPE-NEXT: getB()
// TYPE-NEXT: getA()
def TypeJ : TestType<"TestL"> {
let parameters = (ins "int":$a, OptionalParameter<"Attribute">:$b);
let mnemonic = "type_j";
let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))";
}

View file

@ -199,6 +199,8 @@ private:
void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive. /// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `custom` directive.
void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for an optional group. /// Generate the parser code for an optional group.
void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os); MethodBody &os);
@ -218,6 +220,8 @@ private:
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive. /// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `custom` directive.
void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for an optional group. /// Generate the printer code for an optional group.
void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os); MethodBody &os);
@ -313,6 +317,8 @@ void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
return genParamsParser(params, ctx, os); return genParamsParser(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el)) if (auto *strct = dyn_cast<StructDirective>(el))
return genStructParser(strct, ctx, os); return genStructParser(strct, ctx, os);
if (auto *custom = dyn_cast<CustomDirective>(el))
return genCustomParser(custom, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el)) if (auto *optional = dyn_cast<OptionalElement>(el))
return genOptionalGroupParser(optional, ctx, os); return genOptionalGroupParser(optional, ctx, os);
if (isa<WhitespaceElement>(el)) if (isa<WhitespaceElement>(el))
@ -566,6 +572,47 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.unindent() << "}\n"; os.unindent() << "}\n";
} }
void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
MethodBody &os) {
os << "{\n";
os.indent();
// Bound variables are passed directly to the parser as `FailureOr<T> &`.
// Referenced variables are passed as `T`. The custom parser fails if it
// returns failure or if any of the required parameters failed.
os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
os << "(void)odsCustomLoc;\n";
os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
os << ",\n";
FormatElement *param;
if (auto *ref = dyn_cast<RefDirective>(arg)) {
os << "*";
param = ref->getArg();
} else {
param = arg;
}
os << "_result_" << cast<ParameterElement>(param)->getName();
}
os.unindent() << ");\n";
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
for (FormatElement *arg : el->getArguments()) {
if (auto *param = dyn_cast<ParameterElement>(arg)) {
if (param->isOptional())
continue;
os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
<< "\"custom parser failed to parse parameter '"
<< param->getName() << "'\");\n";
os << "return {};\n";
os.unindent() << "}\n";
}
}
os.unindent() << "}\n";
}
void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os) { MethodBody &os) {
ArrayRef<FormatElement *> elements = ArrayRef<FormatElement *> elements =
@ -634,6 +681,8 @@ void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
return genParamsPrinter(params, ctx, os); return genParamsPrinter(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el)) if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os); return genStructPrinter(strct, ctx, os);
if (auto *custom = dyn_cast<CustomDirective>(el))
return genCustomPrinter(custom, ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el)) if (auto *var = dyn_cast<ParameterElement>(el))
return genVariablePrinter(var, ctx, os); return genVariablePrinter(var, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el)) if (auto *optional = dyn_cast<OptionalElement>(el))
@ -746,6 +795,21 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
}); });
} }
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
MethodBody &os) {
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
FormatElement *param = arg;
if (auto *ref = dyn_cast<RefDirective>(arg))
param = ref->getArg();
os << ",\n"
<< getParameterAccessorName(cast<ParameterElement>(param)->getName())
<< "()";
}
os.unindent() << ");\n";
}
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) { MethodBody &os) {
FormatElement *anchor = el->getAnchor(); FormatElement *anchor = el->getAnchor();
@ -805,9 +869,7 @@ protected:
/// Verify the elements of a custom directive. /// Verify the elements of a custom directive.
LogicalResult LogicalResult
verifyCustomDirectiveArguments(SMLoc loc, verifyCustomDirectiveArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) override { ArrayRef<FormatElement *> arguments) override;
return emitError(loc, "'custom' not supported (yet)");
}
/// Verify the elements of an optional group. /// Verify the elements of an optional group.
LogicalResult LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements, verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
@ -822,11 +884,13 @@ protected:
private: private:
/// Parse a `params` directive. /// Parse a `params` directive.
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc); FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
/// Parse a `qualified` directive. /// Parse a `qualified` directive.
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx); FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive. /// Parse a `struct` directive.
FailureOr<FormatElement *> parseStructDirective(SMLoc loc); FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
/// Parse a `ref` directive.
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
/// Attribute or type tablegen def. /// Attribute or type tablegen def.
const AttrOrTypeDef &def; const AttrOrTypeDef &def;
@ -862,6 +926,12 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
return success(); return success();
} }
LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
// Arguments are fully verified by the parser context.
return success();
}
LogicalResult LogicalResult
DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
ArrayRef<FormatElement *> elements, ArrayRef<FormatElement *> elements,
@ -915,10 +985,19 @@ DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
def.getName() + " has no parameter named '" + name + "'"); def.getName() + " has no parameter named '" + name + "'");
} }
auto idx = std::distance(params.begin(), it); auto idx = std::distance(params.begin(), it);
if (ctx != RefDirectiveContext) {
// Check that the variable has not already been bound.
if (seenParams.test(idx)) if (seenParams.test(idx))
return emitError(loc, "duplicate parameter '" + name + "'"); return emitError(loc, "duplicate parameter '" + name + "'");
seenParams.set(idx); seenParams.set(idx);
// Otherwise, to be referenced, a variable must have been bound.
} else if (!seenParams.test(idx)) {
return emitError(loc, "parameter '" + name +
"' must be bound before it is referenced");
}
return create<ParameterElement>(*it); return create<ParameterElement>(*it);
} }
@ -930,14 +1009,13 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
case FormatToken::kw_qualified: case FormatToken::kw_qualified:
return parseQualifiedDirective(loc, ctx); return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_params: case FormatToken::kw_params:
return parseParamsDirective(loc); return parseParamsDirective(loc, ctx);
case FormatToken::kw_struct: case FormatToken::kw_struct:
if (ctx != TopLevelContext) { return parseStructDirective(loc, ctx);
return emitError( case FormatToken::kw_ref:
loc, return parseRefDirective(loc, ctx);
"`struct` may only be used in the top-level section of the format"); case FormatToken::kw_custom:
} return parseCustomDirective(loc, ctx);
return parseStructDirective(loc);
default: default:
return emitError(loc, "unsupported directive kind"); return emitError(loc, "unsupported directive kind");
@ -961,10 +1039,18 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
return var; return var;
} }
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) { FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
// Collect all of the attribute's or type's parameters. Context ctx) {
// It doesn't make sense to allow references to all parameters in a custom
// directive because parameters are the only things that can be bound.
if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
return emitError(loc, "`params` can only be used at the top-level context "
"or within a `struct` directive");
}
// Collect all of the attribute's or type's parameters and ensure that none of
// the parameters have already been captured.
std::vector<ParameterElement *> vars; std::vector<ParameterElement *> vars;
// Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) { for (const auto &it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) { if (seenParams.test(it.index())) {
return emitError(loc, "`params` captures duplicate parameter: " + return emitError(loc, "`params` captures duplicate parameter: " +
@ -976,7 +1062,11 @@ FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
return create<ParamsDirective>(std::move(vars)); return create<ParamsDirective>(std::move(vars));
} }
FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) { FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
Context ctx) {
if (ctx != TopLevelContext)
return emitError(loc, "`struct` can only be used at the top-level context");
if (failed(parseToken(FormatToken::l_paren, if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list"))) "expected '(' before `struct` argument list")))
return failure(); return failure();
@ -1012,6 +1102,22 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
return create<StructDirective>(std::move(vars)); return create<StructDirective>(std::move(vars));
} }
FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
Context ctx) {
if (ctx != CustomDirectiveContext)
return emitError(loc, "`ref` is only allowed inside custom directives");
// Parse the child parameter element.
FailureOr<FormatElement *> child;
if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
failed(child = parseElement(RefDirectiveContext)) ||
failed(parseToken(FormatToken::r_paren, "expeced ')'")))
return failure();
// Only parameter elements are allowed to be parsed under a `ref` directive.
return create<RefDirective>(*child);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Interface // Interface
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -338,6 +338,22 @@ private:
std::vector<FormatElement *> arguments; std::vector<FormatElement *> arguments;
}; };
/// This class represents a reference directive. This directive can be used to
/// reference but not bind a previously bound variable or format object. Its
/// current only use is to pass variables as arguments to the custom directive.
class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
public:
/// Create a reference directive with the single referenced child.
RefDirective(FormatElement *arg) : arg(arg) {}
/// Get the reference argument.
FormatElement *getArg() const { return arg; }
private:
/// The referenced argument.
FormatElement *arg;
};
/// This class represents a group of elements that are optionally emitted based /// This class represents a group of elements that are optionally emitted based
/// on an optional variable "anchor" and a group of elements that are emitted /// on an optional variable "anchor" and a group of elements that are emitted
/// when the anchor element is not present. /// when the anchor element is not present.

View file

@ -153,18 +153,6 @@ private:
FormatElement *inputs, *results; FormatElement *inputs, *results;
}; };
/// This class represents the `ref` directive.
class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
public:
RefDirective(FormatElement *arg) : arg(arg) {}
FormatElement *getArg() const { return arg; }
private:
/// The argument that is used to format the directive.
FormatElement *arg;
};
/// This class represents the `type` directive. /// This class represents the `type` directive.
class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> { class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
public: public: