[mlir] Allow to override type/attr aliases from various hooks

Use new return type for `OpAsmDialectInterface::getAlias`:

* `AliasResult::NoAlias` if an alias was not provided.
* `AliasResult::OverridableAlias` if an alias was provided, but it might be overriden by other hook.
* `AliasResult::FinalAlias` if an alias was provided and it should be used (no other hooks will be checked).

In that case `AsmPrinter` will use either the first alias with `FinalAlias` result or
the last alias with `OverridableAlias` result (it depends on dialect array order).

Used `OverridableAlias` result for `BuiltinOpAsmDialectInterface`.

Use case: provide more informative alias for built-in attributes like `AffineMapAttr`
instead of generic "map<N>".

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D107437
This commit is contained in:
Vladislav Vinogradov 2021-08-03 17:23:31 +03:00
parent 0fd03feb4b
commit 59f59d1c62
5 changed files with 64 additions and 30 deletions

View file

@ -928,18 +928,29 @@ using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
/// Holds the result of `getAlias` hook call.
enum class AliasResult {
/// The object (type or attribute) is not supported by the hook
/// and an alias was not provided.
NoAlias,
/// An alias was provided, but it might be overriden by other hook.
OverridableAlias,
/// An alias was provided and it should be used
/// (no other hooks will be checked).
FinalAlias
};
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
/// Hooks for getting an alias identifier alias for a given symbol, that is
/// not necessarily a part of this dialect. The identifier is used in place of
/// the symbol when printing textual IR. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+). Returns success if an alias was
/// provided, failure otherwise.
virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
return failure();
/// end with a numeric digit([0-9]+).
virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
return AliasResult::NoAlias;
}
virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
return failure();
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
return AliasResult::NoAlias;
}
/// Get a special name to use when printing the given operation. See

View file

@ -652,21 +652,28 @@ void AliasInitializer::visit(Type type) {
template <typename T>
LogicalResult AliasInitializer::generateAlias(
T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
SmallString<16> tempBuffer;
SmallString<32> nameBuffer;
for (const auto &interface : interfaces) {
if (failed(interface.getAlias(symbol, aliasOS)))
OpAsmDialectInterface::AliasResult result =
interface.getAlias(symbol, aliasOS);
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
continue;
StringRef name = aliasOS.str();
assert(!name.empty() && "expected valid alias name");
name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
/*allowTrailingDigit=*/false);
name = name.copy(aliasAllocator);
aliasToSymbol[name].push_back(symbol);
aliasBuffer.clear();
return success();
nameBuffer = std::move(aliasBuffer);
assert(!nameBuffer.empty() && "expected valid alias name");
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
break;
}
return failure();
if (nameBuffer.empty())
return failure();
SmallString<16> tempBuffer;
StringRef name =
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
/*allowTrailingDigit=*/false);
name = name.copy(aliasAllocator);
aliasToSymbol[name].push_back(symbol);
return success();
}
//===----------------------------------------------------------------------===//

View file

@ -33,30 +33,30 @@ namespace {
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<AffineMapAttr>()) {
os << "map";
return success();
return AliasResult::OverridableAlias;
}
if (attr.isa<IntegerSetAttr>()) {
os << "set";
return success();
return AliasResult::OverridableAlias;
}
if (attr.isa<LocationAttr>()) {
os << "loc";
return success();
return AliasResult::OverridableAlias;
}
return failure();
return AliasResult::NoAlias;
}
LogicalResult getAlias(Type type, raw_ostream &os) const final {
AliasResult getAlias(Type type, raw_ostream &os) const final {
if (auto tupleType = type.dyn_cast<TupleType>()) {
if (tupleType.size() > 16) {
os << "tuple";
return success();
return AliasResult::OverridableAlias;
}
}
return failure();
return AliasResult::NoAlias;
}
};
} // end anonymous namespace.

View file

@ -18,6 +18,9 @@
// CHECK-DAG: !tuple = type tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
// CHECK-DAG: !test_tuple = type tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">

View file

@ -51,10 +51,10 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
StringAttr strAttr = attr.dyn_cast<StringAttr>();
if (!strAttr)
return failure();
return AliasResult::NoAlias;
// Check the contents of the string attribute to see what the test alias
// should be named.
@ -70,10 +70,23 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(llvm::None);
if (!aliasName)
return failure();
return AliasResult::NoAlias;
os << *aliasName;
return success();
return AliasResult::FinalAlias;
}
AliasResult getAlias(Type type, raw_ostream &os) const final {
if (auto tupleType = type.dyn_cast<TupleType>()) {
if (tupleType.size() > 0 &&
llvm::all_of(tupleType.getTypes(), [](Type elemType) {
return elemType.isa<SimpleAType>();
})) {
os << "test_tuple";
return AliasResult::FinalAlias;
}
}
return AliasResult::NoAlias;
}
void getAsmResultNames(Operation *op,