[mlir][ods] Add support for custom directive in attr/type formats

This patch adds support for custom directives in attribute and type formats. Custom directives dispatch calls to user-defined parser and printer functions.

For example, the assembly format "custom<Foo>($foo, ref($bar))" expects a function with the signature

```
LogicalResult parseFoo(AsmParser &parser, FailureOr<FooT> &foo, BarT bar);
void printFoo(AsmPrinter &printer, FooT foo, BarT bar);
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D120944
This commit is contained in:
Mogball 2022-03-15 07:12:37 +00:00
parent 6143ec2961
commit 4767e26775
9 changed files with 308 additions and 32 deletions

View file

@ -558,6 +558,8 @@ Attribute and type assembly formats have the following directives:
mnemonic.
* `struct`: generate a "struct-like" parser and printer for a list of
key-value pairs.
* `custom`: dispatch a call to user-define parser and printer functions
* `ref`: in a custom directive, references a previously bound variable
#### `params` Directive
@ -649,3 +651,44 @@ assembly format of `` `<` struct(params) `>` `` will result in:
The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameter` list.
#### `custom` and `ref` directive
The `custom` directive is used to dispatch calls to user-defined printer and
parser functions. For example, suppose we had the following type:
```tablegen
let parameters = (ins "int":$foo, "int":$bar);
let assemblyFormat = "custom<Foo>($foo) custom<Bar>($bar, ref($foo))";
```
The `custom` directive `custom<Foo>($foo)` will in the parser and printer
respectively generate calls to:
```c++
LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo);
void printFoo(AsmPrinter &printer, int foo);
```
A previously bound variable can be passed as a parameter to a `custom` directive
by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by
the first directive. The second directive references it and expects the
following printer and parser signatures:
```c++
LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo);
void printBar(AsmPrinter &printer, int bar, int foo);
```
More complex C++ types can be used with the `custom` directive. The only caveat
is that the parameter for the parser must use the storage type of the parameter.
For example, `StringRefParameter` expects the parser and printer signatures as:
```c++
LogicalResult parseStringParam(AsmParser &parser,
FailureOr<std::string> &value);
void printStringParam(AsmPrinter &printer, StringRef value);
```
The custom parser is considered to have failed if it returns failure or if any
bound parameters have failure values afterwards.

View file

@ -363,4 +363,18 @@ def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
}
def TestTypeCustom : Test_Type<"TestTypeCustom"> {
let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
let mnemonic = "custom_type";
let assemblyFormat = [{ `<` custom<CustomTypeA>($a)
custom<CustomTypeB>(ref($a), $b) `>` }];
}
def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
let parameters = (ins StringRefParameter<>:$foo);
let mnemonic = "custom_type_string";
let assemblyFormat = [{ `<` custom<FooString>($foo)
custom<BarString>(ref($foo)) `>` }];
}
#endif // TEST_TYPEDEFS

View file

@ -208,6 +208,59 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
return 1;
}
//===----------------------------------------------------------------------===//
// TestCustomType
//===----------------------------------------------------------------------===//
static LogicalResult parseCustomTypeA(AsmParser &parser,
FailureOr<int> &a_result) {
a_result.emplace();
return parser.parseInteger(*a_result);
}
static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
FailureOr<Optional<int>> &b_result) {
if (a < 0)
return success();
for (int i : llvm::seq(0, a))
if (failed(parser.parseInteger(i)))
return failure();
b_result.emplace(0);
return parser.parseInteger(**b_result);
}
static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
if (a < 0)
return;
printer << ' ';
for (int i : llvm::seq(0, a))
printer << i << ' ';
printer << *b;
}
static LogicalResult parseFooString(AsmParser &parser,
FailureOr<std::string> &foo) {
std::string result;
if (parser.parseString(&result))
return failure();
foo = std::move(result);
return success();
}
static void printFooString(AsmPrinter &printer, StringRef foo) {
printer << '"' << foo << '"';
}
static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
return parser.parseKeyword(foo);
}
static void printBarString(AsmPrinter &printer, StringRef foo) {
printer << ' ' << foo;
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//

View file

@ -107,3 +107,27 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
// CHECK: optional group anchor must be a parameter or directive
let assemblyFormat = "(`(` $a `)`^)?";
}
def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
let parameters = (ins "int":$a);
// CHECK: `ref` is only allowed inside custom directives
let assemblyFormat = "$a ref($a)";
}
def InvalidTypeP : InvalidType<"InvalidTypeP", "invalid_p"> {
let parameters = (ins "int":$a);
// CHECK: parameter 'a' must be bound before it is referenced
let assemblyFormat = "custom<Foo>(ref($a)) $a";
}
def InvalidTypeQ : InvalidType<"InvalidTypeQ", "invalid_q"> {
let parameters = (ins "int":$a);
// CHECK: `params` can only be used at the top-level context or within a `struct` directive
let assemblyFormat = "custom<Foo>(params)";
}
def InvalidTypeR : InvalidType<"InvalidTypeR", "invalid_r"> {
let parameters = (ins "int":$a);
// CHECK: `struct` can only be used at the top-level context
let assemblyFormat = "custom<Foo>(struct(params))";
}

View file

@ -48,6 +48,10 @@ attributes {
// CHECK: !test.ap_float<>
// CHECK: !test.default_valued_type<(i64)>
// CHECK: !test.default_valued_type<>
// CHECK: !test.custom_type<-5>
// CHECK: !test.custom_type<2 0 1 5>
// CHECK: !test.custom_type_string<"foo" foo>
// CHECK: !test.custom_type_string<"bar" bar>
func private @test_roundtrip_default_parsers_struct(
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
@ -79,5 +83,9 @@ func private @test_roundtrip_default_parsers_struct(
!test.ap_float<5.0>,
!test.ap_float<>,
!test.default_valued_type<(i64)>,
!test.default_valued_type<>
!test.default_valued_type<>,
!test.custom_type<-5>,
!test.custom_type<2 9 9 5>,
!test.custom_type_string<"foo" foo>,
!test.custom_type_string<"bar" bar>
)

View file

@ -499,3 +499,27 @@ def TypeI : TestType<"TestK"> {
let mnemonic = "type_k";
let assemblyFormat = "$a";
}
// TYPE: ::mlir::Type TestLType::parse
// TYPE: auto odsCustomLoc = odsParser.getCurrentLocation()
// TYPE: auto odsCustomResult = parseA(odsParser,
// TYPE-NEXT: _result_a
// TYPE: if (::mlir::failed(odsCustomResult)) return {}
// TYPE: if (::mlir::failed(_result_a))
// TYPE-NEXT: odsParser.emitError(odsCustomLoc,
// TYPE: auto odsCustomResult = parseB(odsParser,
// TYPE-NEXT: _result_b
// TYPE-NEXT: *_result_a
// TYPE: void TestLType::print
// TYPE: printA(odsPrinter
// TYPE-NEXT: getA()
// TYPE: printB(odsPrinter
// TYPE-NEXT: getB()
// TYPE-NEXT: getA()
def TypeJ : TestType<"TestL"> {
let parameters = (ins "int":$a, OptionalParameter<"Attribute">:$b);
let mnemonic = "type_j";
let assemblyFormat = "custom<A>($a) custom<B>($b, ref($a))";
}

View file

@ -199,6 +199,8 @@ private:
void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `custom` directive.
void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for an optional group.
void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
@ -218,6 +220,8 @@ private:
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `custom` directive.
void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for an optional group.
void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
@ -313,6 +317,8 @@ void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
return genParamsParser(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructParser(strct, ctx, os);
if (auto *custom = dyn_cast<CustomDirective>(el))
return genCustomParser(custom, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el))
return genOptionalGroupParser(optional, ctx, os);
if (isa<WhitespaceElement>(el))
@ -566,6 +572,47 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.unindent() << "}\n";
}
void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
MethodBody &os) {
os << "{\n";
os.indent();
// Bound variables are passed directly to the parser as `FailureOr<T> &`.
// Referenced variables are passed as `T`. The custom parser fails if it
// returns failure or if any of the required parameters failed.
os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
os << "(void)odsCustomLoc;\n";
os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
os << ",\n";
FormatElement *param;
if (auto *ref = dyn_cast<RefDirective>(arg)) {
os << "*";
param = ref->getArg();
} else {
param = arg;
}
os << "_result_" << cast<ParameterElement>(param)->getName();
}
os.unindent() << ");\n";
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
for (FormatElement *arg : el->getArguments()) {
if (auto *param = dyn_cast<ParameterElement>(arg)) {
if (param->isOptional())
continue;
os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
<< "\"custom parser failed to parse parameter '"
<< param->getName() << "'\");\n";
os << "return {};\n";
os.unindent() << "}\n";
}
}
os.unindent() << "}\n";
}
void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
ArrayRef<FormatElement *> elements =
@ -634,6 +681,8 @@ void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
return genParamsPrinter(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
if (auto *custom = dyn_cast<CustomDirective>(el))
return genCustomPrinter(custom, ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el))
return genVariablePrinter(var, ctx, os);
if (auto *optional = dyn_cast<OptionalElement>(el))
@ -746,6 +795,21 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
});
}
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
MethodBody &os) {
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
FormatElement *param = arg;
if (auto *ref = dyn_cast<RefDirective>(arg))
param = ref->getArg();
os << ",\n"
<< getParameterAccessorName(cast<ParameterElement>(param)->getName())
<< "()";
}
os.unindent() << ");\n";
}
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
FormatElement *anchor = el->getAnchor();
@ -805,9 +869,7 @@ protected:
/// Verify the elements of a custom directive.
LogicalResult
verifyCustomDirectiveArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) override {
return emitError(loc, "'custom' not supported (yet)");
}
ArrayRef<FormatElement *> arguments) override;
/// Verify the elements of an optional group.
LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
@ -822,11 +884,13 @@ protected:
private:
/// Parse a `params` directive.
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc);
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
/// Parse a `qualified` directive.
FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
FailureOr<FormatElement *> parseStructDirective(SMLoc loc);
FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
/// Parse a `ref` directive.
FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
@ -862,6 +926,12 @@ LogicalResult DefFormatParser::verify(SMLoc loc,
return success();
}
LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
// Arguments are fully verified by the parser context.
return success();
}
LogicalResult
DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
ArrayRef<FormatElement *> elements,
@ -915,10 +985,19 @@ DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
def.getName() + " has no parameter named '" + name + "'");
}
auto idx = std::distance(params.begin(), it);
if (ctx != RefDirectiveContext) {
// Check that the variable has not already been bound.
if (seenParams.test(idx))
return emitError(loc, "duplicate parameter '" + name + "'");
seenParams.set(idx);
// Otherwise, to be referenced, a variable must have been bound.
} else if (!seenParams.test(idx)) {
return emitError(loc, "parameter '" + name +
"' must be bound before it is referenced");
}
return create<ParameterElement>(*it);
}
@ -930,14 +1009,13 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
case FormatToken::kw_qualified:
return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_params:
return parseParamsDirective(loc);
return parseParamsDirective(loc, ctx);
case FormatToken::kw_struct:
if (ctx != TopLevelContext) {
return emitError(
loc,
"`struct` may only be used in the top-level section of the format");
}
return parseStructDirective(loc);
return parseStructDirective(loc, ctx);
case FormatToken::kw_ref:
return parseRefDirective(loc, ctx);
case FormatToken::kw_custom:
return parseCustomDirective(loc, ctx);
default:
return emitError(loc, "unsupported directive kind");
@ -961,10 +1039,18 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
return var;
}
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
// Collect all of the attribute's or type's parameters.
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
Context ctx) {
// It doesn't make sense to allow references to all parameters in a custom
// directive because parameters are the only things that can be bound.
if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
return emitError(loc, "`params` can only be used at the top-level context "
"or within a `struct` directive");
}
// Collect all of the attribute's or type's parameters and ensure that none of
// the parameters have already been captured.
std::vector<ParameterElement *> vars;
// Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) {
return emitError(loc, "`params` captures duplicate parameter: " +
@ -976,7 +1062,11 @@ FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
return create<ParamsDirective>(std::move(vars));
}
FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
Context ctx) {
if (ctx != TopLevelContext)
return emitError(loc, "`struct` can only be used at the top-level context");
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list")))
return failure();
@ -1012,6 +1102,22 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
return create<StructDirective>(std::move(vars));
}
FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
Context ctx) {
if (ctx != CustomDirectiveContext)
return emitError(loc, "`ref` is only allowed inside custom directives");
// Parse the child parameter element.
FailureOr<FormatElement *> child;
if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
failed(child = parseElement(RefDirectiveContext)) ||
failed(parseToken(FormatToken::r_paren, "expeced ')'")))
return failure();
// Only parameter elements are allowed to be parsed under a `ref` directive.
return create<RefDirective>(*child);
}
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//

View file

@ -338,6 +338,22 @@ private:
std::vector<FormatElement *> arguments;
};
/// This class represents a reference directive. This directive can be used to
/// reference but not bind a previously bound variable or format object. Its
/// current only use is to pass variables as arguments to the custom directive.
class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
public:
/// Create a reference directive with the single referenced child.
RefDirective(FormatElement *arg) : arg(arg) {}
/// Get the reference argument.
FormatElement *getArg() const { return arg; }
private:
/// The referenced argument.
FormatElement *arg;
};
/// This class represents a group of elements that are optionally emitted based
/// on an optional variable "anchor" and a group of elements that are emitted
/// when the anchor element is not present.

View file

@ -153,18 +153,6 @@ private:
FormatElement *inputs, *results;
};
/// This class represents the `ref` directive.
class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
public:
RefDirective(FormatElement *arg) : arg(arg) {}
FormatElement *getArg() const { return arg; }
private:
/// The argument that is used to format the directive.
FormatElement *arg;
};
/// This class represents the `type` directive.
class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
public: