[mlir] support recursive types in type conversion infra

MLIR supports recursive types but they could not be handled by the conversion
infrastructure directly as it would result in infinite recursion in
`convertType` for elemental types. Support this case by keeping the "call
stack" of nested type conversions in the TypeConverter class and by passing it
as an optional argument to the individual conversion callback. The callback can
then check if a specific type is present on the stack more than once to detect
and handle the recursive case.

This approach is preferred to the alternative approach of having a separate
callback dedicated to handling only the recursive case as the latter was
observed to introduce ~3% time overhead on a 50MB IR file even if it did not
contain recursive types.

This approach is also preferred to keeping a local stack in type converters
that need to handle recursive types as that would compose poorly in case of
out-of-tree or cross-project extensions.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D113579
This commit is contained in:
Alex Zinenko 2021-11-22 13:19:36 +01:00
parent 774f7832fb
commit 9c5982ef8e
5 changed files with 108 additions and 19 deletions

View file

@ -307,6 +307,14 @@ class TypeConverter {
/// existing value are expected to be removed during conversion. If /// existing value are expected to be removed during conversion. If
/// `llvm::None` is returned, the converter is allowed to try another /// `llvm::None` is returned, the converter is allowed to try another
/// conversion function to perform the conversion. /// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
/// Note: When attempting to convert a type, e.g. via 'convertType', the /// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first. /// mostly recently added conversions will be invoked first.
template <typename FnT, template <typename FnT,

View file

@ -101,6 +101,14 @@ public:
/// existing value are expected to be removed during conversion. If /// existing value are expected to be removed during conversion. If
/// `llvm::None` is returned, the converter is allowed to try another /// `llvm::None` is returned, the converter is allowed to try another
/// conversion function to perform the conversion. /// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
/// Note: When attempting to convert a type, e.g. via 'convertType', the /// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first. /// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits< template <typename FnT, typename T = typename llvm::function_traits<
@ -221,8 +229,8 @@ private:
/// The signature of the callback used to convert a type. If the new set of /// The signature of the callback used to convert a type. If the new set of
/// types is empty, the type is removed and any usages of the existing value /// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion. /// are expected to be removed during conversion.
using ConversionCallbackFn = using ConversionCallbackFn = std::function<Optional<LogicalResult>(
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>; Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
/// The signature of the callback used to materialize a conversion. /// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn = using MaterializationCallbackFn =
@ -240,28 +248,44 @@ private:
template <typename T, typename FnT> template <typename T, typename FnT>
std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn> std::enable_if_t<llvm::is_invocable<FnT, T>::value, ConversionCallbackFn>
wrapCallback(FnT &&callback) { wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)]( return wrapCallback<T>(
T type, SmallVectorImpl<Type> &results) { [callback = std::forward<FnT>(callback)](
if (Optional<Type> resultOpt = callback(type)) { T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
bool wasSuccess = static_cast<bool>(resultOpt.getValue()); if (Optional<Type> resultOpt = callback(type)) {
if (wasSuccess) bool wasSuccess = static_cast<bool>(resultOpt.getValue());
results.push_back(resultOpt.getValue()); if (wasSuccess)
return Optional<LogicalResult>(success(wasSuccess)); results.push_back(resultOpt.getValue());
} return Optional<LogicalResult>(success(wasSuccess));
return Optional<LogicalResult>(); }
}); return Optional<LogicalResult>();
});
} }
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<> &)` /// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// &)`
template <typename T, typename FnT> template <typename T, typename FnT>
std::enable_if_t<!llvm::is_invocable<FnT, T>::value, ConversionCallbackFn> std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &>::value,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
return wrapCallback<T>(
[callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
return callback(type, results);
});
}
/// With callback of form: `Optional<LogicalResult>(T, SmallVectorImpl<Type>
/// &, ArrayRef<Type>)`.
template <typename T, typename FnT>
std::enable_if_t<llvm::is_invocable<FnT, T, SmallVectorImpl<Type> &,
ArrayRef<Type>>::value,
ConversionCallbackFn>
wrapCallback(FnT &&callback) { wrapCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)]( return [callback = std::forward<FnT>(callback)](
Type type, Type type, SmallVectorImpl<Type> &results,
SmallVectorImpl<Type> &results) -> Optional<LogicalResult> { ArrayRef<Type> callStack) -> Optional<LogicalResult> {
T derivedType = type.dyn_cast<T>(); T derivedType = type.dyn_cast<T>();
if (!derivedType) if (!derivedType)
return llvm::None; return llvm::None;
return callback(derivedType, results); return callback(derivedType, results, callStack);
}; };
} }
@ -300,6 +324,10 @@ private:
DenseMap<Type, Type> cachedDirectConversions; DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1. /// This cache stores the successful 1->N conversions, where N != 1.
DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions; DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// Stores the types that are being converted in the case when convertType
/// is being called recursively to convert nested types.
SmallVector<Type, 2> conversionCallStack;
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -14,6 +14,7 @@
#include "mlir/IR/FunctionSupport.h" #include "mlir/IR/FunctionSupport.h"
#include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/Utils.h" #include "mlir/Transforms/Utils.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
@ -2931,8 +2932,12 @@ LogicalResult TypeConverter::convertType(Type t,
// Walk the added converters in reverse order to apply the most recently // Walk the added converters in reverse order to apply the most recently
// registered first. // registered first.
size_t currentCount = results.size(); size_t currentCount = results.size();
conversionCallStack.push_back(t);
auto popConversionCallStack =
llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
for (ConversionCallbackFn &converter : llvm::reverse(conversions)) { for (ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (Optional<LogicalResult> result = converter(t, results)) { if (Optional<LogicalResult> result =
converter(t, results, conversionCallStack)) {
if (!succeeded(*result)) { if (!succeeded(*result)) {
cachedDirectConversions.try_emplace(t, nullptr); cachedDirectConversions.try_emplace(t, nullptr);
return failure(); return failure();

View file

@ -112,3 +112,12 @@ func @test_signature_conversion_no_converter() {
}) : () -> () }) : () -> ()
return return
} }
// -----
// CHECK-LABEL: @recursive_type_conversion
func @recursive_type_conversion() {
// CHECK: !test.test_rec<outer_converted_type, smpla>
"test.type_producer"() : () -> !test.test_rec<something, test_rec<something>>
return
}

View file

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "TestDialect.h" #include "TestDialect.h"
#include "TestTypes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
@ -924,10 +925,16 @@ struct TestTypeConversionProducer
matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
Type resultType = op.getType(); Type resultType = op.getType();
Type convertedType = getTypeConverter()
? getTypeConverter()->convertType(resultType)
: resultType;
if (resultType.isa<FloatType>()) if (resultType.isa<FloatType>())
resultType = rewriter.getF64Type(); resultType = rewriter.getF64Type();
else if (resultType.isInteger(16)) else if (resultType.isInteger(16))
resultType = rewriter.getIntegerType(64); resultType = rewriter.getIntegerType(64);
else if (resultType.isa<test::TestRecursiveType>() &&
convertedType != resultType)
resultType = convertedType;
else else
return failure(); return failure();
@ -1035,6 +1042,35 @@ struct TestTypeConversionDriver
// Drop all integer types. // Drop all integer types.
return success(); return success();
}); });
converter.addConversion(
// Convert a recursive self-referring type into a non-self-referring
// type named "outer_converted_type" that contains a SimpleAType.
[&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
ArrayRef<Type> callStack) -> Optional<LogicalResult> {
// If the type is already converted, return it to indicate that it is
// legal.
if (type.getName() == "outer_converted_type") {
results.push_back(type);
return success();
}
// If the type is on the call stack more than once (it is there at
// least once because of the _current_ call, which is always the last
// element on the stack), we've hit the recursive case. Just return
// SimpleAType here to create a non-recursive type as a result.
if (llvm::is_contained(callStack.drop_back(), type)) {
results.push_back(test::SimpleAType::get(type.getContext()));
return success();
}
// Convert the body recursively.
auto result = test::TestRecursiveType::get(type.getContext(),
"outer_converted_type");
if (failed(result.setBody(converter.convertType(type.getBody()))))
return failure();
results.push_back(result);
return success();
});
/// Add the legal set of type materializations. /// Add the legal set of type materializations.
converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, converter.addSourceMaterialization([](OpBuilder &builder, Type resultType,
@ -1059,7 +1095,10 @@ struct TestTypeConversionDriver
// Initialize the conversion target. // Initialize the conversion target.
mlir::ConversionTarget target(getContext()); mlir::ConversionTarget target(getContext());
target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) {
return op.getType().isF64() || op.getType().isInteger(64); auto recursiveType = op.getType().dyn_cast<test::TestRecursiveType>();
return op.getType().isF64() || op.getType().isInteger(64) ||
(recursiveType &&
recursiveType.getName() == "outer_converted_type");
}); });
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return converter.isSignatureLegal(op.getType()) && return converter.isSignatureLegal(op.getType()) &&