[mlir] Allow to override type/attr aliases from various hooks
Use new return type for `OpAsmDialectInterface::getAlias`: * `AliasResult::NoAlias` if an alias was not provided. * `AliasResult::OverridableAlias` if an alias was provided, but it might be overriden by other hook. * `AliasResult::FinalAlias` if an alias was provided and it should be used (no other hooks will be checked). In that case `AsmPrinter` will use either the first alias with `FinalAlias` result or the last alias with `OverridableAlias` result (it depends on dialect array order). Used `OverridableAlias` result for `BuiltinOpAsmDialectInterface`. Use case: provide more informative alias for built-in attributes like `AffineMapAttr` instead of generic "map<N>". Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D107437
This commit is contained in:
parent
0fd03feb4b
commit
59f59d1c62
|
@ -928,18 +928,29 @@ using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
|
|||
class OpAsmDialectInterface
|
||||
: public DialectInterface::Base<OpAsmDialectInterface> {
|
||||
public:
|
||||
/// Holds the result of `getAlias` hook call.
|
||||
enum class AliasResult {
|
||||
/// The object (type or attribute) is not supported by the hook
|
||||
/// and an alias was not provided.
|
||||
NoAlias,
|
||||
/// An alias was provided, but it might be overriden by other hook.
|
||||
OverridableAlias,
|
||||
/// An alias was provided and it should be used
|
||||
/// (no other hooks will be checked).
|
||||
FinalAlias
|
||||
};
|
||||
|
||||
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
/// Hooks for getting an alias identifier alias for a given symbol, that is
|
||||
/// not necessarily a part of this dialect. The identifier is used in place of
|
||||
/// the symbol when printing textual IR. These aliases must not contain `.` or
|
||||
/// end with a numeric digit([0-9]+). Returns success if an alias was
|
||||
/// provided, failure otherwise.
|
||||
virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
|
||||
return failure();
|
||||
/// end with a numeric digit([0-9]+).
|
||||
virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
|
||||
return failure();
|
||||
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
/// Get a special name to use when printing the given operation. See
|
||||
|
|
|
@ -652,21 +652,28 @@ void AliasInitializer::visit(Type type) {
|
|||
template <typename T>
|
||||
LogicalResult AliasInitializer::generateAlias(
|
||||
T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
|
||||
SmallString<16> tempBuffer;
|
||||
SmallString<32> nameBuffer;
|
||||
for (const auto &interface : interfaces) {
|
||||
if (failed(interface.getAlias(symbol, aliasOS)))
|
||||
OpAsmDialectInterface::AliasResult result =
|
||||
interface.getAlias(symbol, aliasOS);
|
||||
if (result == OpAsmDialectInterface::AliasResult::NoAlias)
|
||||
continue;
|
||||
StringRef name = aliasOS.str();
|
||||
assert(!name.empty() && "expected valid alias name");
|
||||
name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
|
||||
/*allowTrailingDigit=*/false);
|
||||
name = name.copy(aliasAllocator);
|
||||
|
||||
aliasToSymbol[name].push_back(symbol);
|
||||
aliasBuffer.clear();
|
||||
return success();
|
||||
nameBuffer = std::move(aliasBuffer);
|
||||
assert(!nameBuffer.empty() && "expected valid alias name");
|
||||
if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
|
||||
break;
|
||||
}
|
||||
return failure();
|
||||
|
||||
if (nameBuffer.empty())
|
||||
return failure();
|
||||
|
||||
SmallString<16> tempBuffer;
|
||||
StringRef name =
|
||||
sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
|
||||
/*allowTrailingDigit=*/false);
|
||||
name = name.copy(aliasAllocator);
|
||||
aliasToSymbol[name].push_back(symbol);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -33,30 +33,30 @@ namespace {
|
|||
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
LogicalResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (attr.isa<AffineMapAttr>()) {
|
||||
os << "map";
|
||||
return success();
|
||||
return AliasResult::OverridableAlias;
|
||||
}
|
||||
if (attr.isa<IntegerSetAttr>()) {
|
||||
os << "set";
|
||||
return success();
|
||||
return AliasResult::OverridableAlias;
|
||||
}
|
||||
if (attr.isa<LocationAttr>()) {
|
||||
os << "loc";
|
||||
return success();
|
||||
return AliasResult::OverridableAlias;
|
||||
}
|
||||
return failure();
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
LogicalResult getAlias(Type type, raw_ostream &os) const final {
|
||||
AliasResult getAlias(Type type, raw_ostream &os) const final {
|
||||
if (auto tupleType = type.dyn_cast<TupleType>()) {
|
||||
if (tupleType.size() > 16) {
|
||||
os << "tuple";
|
||||
return success();
|
||||
return AliasResult::OverridableAlias;
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
|
|
@ -18,6 +18,9 @@
|
|||
// CHECK-DAG: !tuple = type tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
|
||||
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
|
||||
|
||||
// CHECK-DAG: !test_tuple = type tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
|
||||
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
|
||||
|
||||
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
|
||||
// CHECK-DAG: tensor<32xf32, #test_encoding>
|
||||
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
|
||||
|
|
|
@ -51,10 +51,10 @@ static_assert(OpTrait::hasSingleBlockImplicitTerminator<
|
|||
struct TestOpAsmInterface : public OpAsmDialectInterface {
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
|
||||
StringAttr strAttr = attr.dyn_cast<StringAttr>();
|
||||
if (!strAttr)
|
||||
return failure();
|
||||
return AliasResult::NoAlias;
|
||||
|
||||
// Check the contents of the string attribute to see what the test alias
|
||||
// should be named.
|
||||
|
@ -70,10 +70,23 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
|
|||
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
|
||||
.Default(llvm::None);
|
||||
if (!aliasName)
|
||||
return failure();
|
||||
return AliasResult::NoAlias;
|
||||
|
||||
os << *aliasName;
|
||||
return success();
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
|
||||
AliasResult getAlias(Type type, raw_ostream &os) const final {
|
||||
if (auto tupleType = type.dyn_cast<TupleType>()) {
|
||||
if (tupleType.size() > 0 &&
|
||||
llvm::all_of(tupleType.getTypes(), [](Type elemType) {
|
||||
return elemType.isa<SimpleAType>();
|
||||
})) {
|
||||
os << "test_tuple";
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
}
|
||||
return AliasResult::NoAlias;
|
||||
}
|
||||
|
||||
void getAsmResultNames(Operation *op,
|
||||
|
|
Loading…
Reference in a new issue