[mlir] Support for mutable types
Introduce support for mutable storage in the StorageUniquer infrastructure. This makes MLIR have key-value storage instead of just uniqued key storage. A storage instance now contains a unique immutable key and a mutable value, both stored in the arena allocator that belongs to the context. This is a preconditio for supporting recursive types that require delayed initialization, in particular LLVM structure types. The functionality is exercised in the test pass with trivial self-recursive type. So far, recursive types can only be printed in parsed in a closed type system. Removing this restriction is left for future work. Differential Revision: https://reviews.llvm.org/D84171
This commit is contained in:
parent
1956cf1042
commit
a51829913d
|
@ -47,7 +47,8 @@ namespace MyTypes {
|
|||
enum Kinds {
|
||||
// These kinds will be used in the examples below.
|
||||
Simple = Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
|
||||
Complex
|
||||
Complex,
|
||||
Recursive
|
||||
};
|
||||
}
|
||||
```
|
||||
|
@ -58,13 +59,17 @@ As described above, `Type` objects in MLIR are value-typed and rely on having an
|
|||
implicitly internal storage object that holds the actual data for the type. When
|
||||
defining a new `Type` it isn't always necessary to define a new storage class.
|
||||
So before defining the derived `Type`, it's important to know which of the two
|
||||
classes of `Type` we are defining. Some types are `primitives` meaning they do
|
||||
classes of `Type` we are defining. Some types are _primitives_ meaning they do
|
||||
not have any parameters and are singletons uniqued by kind, like the
|
||||
[`index` type](LangRef.md#index-type). Parametric types on the other hand, have
|
||||
additional information that differentiates different instances of the same
|
||||
`Type` kind. For example the [`integer` type](LangRef.md#integer-type) has a
|
||||
bitwidth, making `i8` and `i16` be different instances of
|
||||
[`integer` type](LangRef.md#integer-type).
|
||||
[`integer` type](LangRef.md#integer-type). Types can also have a mutable
|
||||
component, which can be used, for example, to construct self-referring recursive
|
||||
types. The mutable component _cannot_ be used to differentiate types within the
|
||||
same kind, so usually such types are also parametric where the parameters serve
|
||||
to identify them.
|
||||
|
||||
#### Simple non-parametric types
|
||||
|
||||
|
@ -240,6 +245,126 @@ public:
|
|||
};
|
||||
```
|
||||
|
||||
#### Types with a mutable component
|
||||
|
||||
Types with a mutable component require defining a type storage class regardless
|
||||
of being parametric. The storage contains both the parameters and the mutable
|
||||
component and is accessed in a thread-safe way by the type support
|
||||
infrastructure.
|
||||
|
||||
##### Defining a type storage
|
||||
|
||||
In addition to the requirements for the type storage class for parametric types,
|
||||
the storage class for types with a mutable component must additionally obey the
|
||||
following.
|
||||
|
||||
* The mutable component must not participate in the storage key.
|
||||
* Provide a mutation method that is used to modify an existing instance of the
|
||||
storage. This method modifies the mutable component based on arguments,
|
||||
using `allocator` for any new dynamically-allocated storage, and indicates
|
||||
whether the modification was successful.
|
||||
- `LogicalResult mutate(StorageAllocator &allocator, Args ...&& args)`
|
||||
|
||||
Let's define a simple storage for recursive types, where a type is identified by
|
||||
its name and can contain another type including itself.
|
||||
|
||||
```c++
|
||||
/// Here we define a storage class for a RecursiveType that is identified by its
|
||||
/// name and contains another type.
|
||||
struct RecursiveTypeStorage : public TypeStorage {
|
||||
/// The type is uniquely identified by its name. Note that the contained type
|
||||
/// is _not_ a part of the key.
|
||||
using KeyTy = StringRef;
|
||||
|
||||
/// Construct the storage from the type name. Explicitly initialize the
|
||||
/// containedType to nullptr, which is used as marker for the mutable
|
||||
/// component being not yet initialized.
|
||||
RecursiveTypeStorage(StringRef name) : name(name), containedType(nullptr) {}
|
||||
|
||||
/// Define the comparison function.
|
||||
bool operator==(const KeyTy &key) const { return key == name; }
|
||||
|
||||
/// Define a construction method for creating a new instance of the storage.
|
||||
static RecursiveTypeStorage *construct(StorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
// Note that the key string is copied into the allocator to ensure it
|
||||
// remains live as long as the storage itself.
|
||||
return new (allocator.allocate<RecursiveTypeStorage>())
|
||||
RecursiveTypeStorage(allocator.copyInto(key));
|
||||
}
|
||||
|
||||
/// Define a mutation method for changing the type after it is created. In
|
||||
/// many cases, we only want to set the mutable component once and reject
|
||||
/// any further modification, which can be achieved by returning failure from
|
||||
/// this function.
|
||||
LogicalResult mutate(StorageAllocator &, Type body) {
|
||||
// If the contained type has been initialized already, and the call tries
|
||||
// to change it, reject the change.
|
||||
if (containedType && containedType != body)
|
||||
return failure();
|
||||
|
||||
// Change the body successfully.
|
||||
containedType = body;
|
||||
return success();
|
||||
}
|
||||
|
||||
StringRef name;
|
||||
Type containedType;
|
||||
};
|
||||
```
|
||||
|
||||
##### Type class definition
|
||||
|
||||
Having defined the storage class, we can define the type class itself. This is
|
||||
similar to parametric types. `Type::TypeBase` provides a `mutate` method that
|
||||
forwards its arguments to the `mutate` method of the storage and ensures the
|
||||
modification happens under lock.
|
||||
|
||||
```c++
|
||||
class RecursiveType : public Type::TypeBase<RecursiveType, Type,
|
||||
RecursiveTypeStorage> {
|
||||
public:
|
||||
/// Inherit parent constructors.
|
||||
using Base::Base;
|
||||
|
||||
/// This static method is used to support type inquiry through isa, cast,
|
||||
/// and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; }
|
||||
|
||||
/// Creates an instance of the Recursive type. This only takes the type name
|
||||
/// and returns the type with uninitialized body.
|
||||
static RecursiveType get(MLIRContext *ctx, StringRef name) {
|
||||
// Call into the base to get a uniqued instance of this type. The parameter
|
||||
// (name) is passed after the kind.
|
||||
return Base::get(ctx, MyTypes::Recursive, name);
|
||||
}
|
||||
|
||||
/// Now we can change the mutable component of the type. This is an instance
|
||||
/// method callable on an already existing RecursiveType.
|
||||
void setBody(Type body) {
|
||||
// Call into the base to mutate the type.
|
||||
LogicalResult result = Base::mutate(body);
|
||||
// Most types expect mutation to always succeed, but types can implement
|
||||
// custom logic for handling mutation failures.
|
||||
assert(succeeded(result) &&
|
||||
"attempting to change the body of an already-initialized type");
|
||||
// Avoid unused-variable warning when building without assertions.
|
||||
(void) result;
|
||||
}
|
||||
|
||||
/// Returns the contained type, which may be null if it has not been
|
||||
/// initialized yet.
|
||||
Type getBody() {
|
||||
return getImpl()->containedType;
|
||||
}
|
||||
|
||||
/// Returns the name.
|
||||
StringRef getName() {
|
||||
return getImpl()->name;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Registering types with a Dialect
|
||||
|
||||
Once the dialect types have been defined, they must then be registered with a
|
||||
|
|
|
@ -139,6 +139,13 @@ public:
|
|||
kind, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename ImplType, typename... Args>
|
||||
static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
|
||||
Args &&...args) {
|
||||
assert(impl && "cannot mutate null attribute");
|
||||
return ctx->getAttributeUniquer().mutate(impl, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Initialize the given attribute storage instance.
|
||||
static void initializeAttributeStorage(AttributeStorage *storage,
|
||||
|
|
|
@ -48,10 +48,10 @@ struct SparseElementsAttributeStorage;
|
|||
|
||||
/// Attributes are known-constant values of operations and functions.
|
||||
///
|
||||
/// Instances of the Attribute class are references to immutable, uniqued,
|
||||
/// and immortal values owned by MLIRContext. As such, an Attribute is a thin
|
||||
/// wrapper around an underlying storage pointer. Attributes are usually passed
|
||||
/// by value.
|
||||
/// Instances of the Attribute class are references to immortal key-value pairs
|
||||
/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a
|
||||
/// thin wrapper around an underlying storage pointer. Attributes are usually
|
||||
/// passed by value.
|
||||
class Attribute {
|
||||
public:
|
||||
/// Integer identifier for all the concrete attribute kinds.
|
||||
|
|
|
@ -105,6 +105,14 @@ protected:
|
|||
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
|
||||
}
|
||||
|
||||
/// Mutate the current storage instance. This will not change the unique key.
|
||||
/// The arguments are forwarded to 'ConcreteT::mutate'.
|
||||
template <typename... Args>
|
||||
LogicalResult mutate(Args &&...args) {
|
||||
return UniquerT::mutate(this->getContext(), getImpl(),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Default implementation that just returns success.
|
||||
template <typename... Args>
|
||||
static LogicalResult verifyConstructionInvariants(Args... args) {
|
||||
|
|
|
@ -132,6 +132,15 @@ struct TypeUniquer {
|
|||
},
|
||||
kind, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Change the mutable component of the given type instance in the provided
|
||||
/// context.
|
||||
template <typename ImplType, typename... Args>
|
||||
static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
|
||||
Args &&...args) {
|
||||
assert(impl && "cannot mutate null type");
|
||||
return ctx->getTypeUniquer().mutate(impl, std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
|
|
@ -27,15 +27,17 @@ struct FunctionTypeStorage;
|
|||
struct OpaqueTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// Instances of the Type class are immutable and uniqued. They wrap a pointer
|
||||
/// to the storage object owned by MLIRContext. Therefore, instances of Type
|
||||
/// are passed around by value.
|
||||
/// Instances of the Type class are uniqued, have an immutable identifier and an
|
||||
/// optional mutable component. They wrap a pointer to the storage object owned
|
||||
/// by MLIRContext. Therefore, instances of Type are passed around by value.
|
||||
///
|
||||
/// Some types are "primitives" meaning they do not have any parameters, for
|
||||
/// example the Index type. Parametric types have additional information that
|
||||
/// differentiates the types of the same kind between them, for example the
|
||||
/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
|
||||
/// different instances of the IntegerType.
|
||||
/// different instances of the IntegerType. Type parameters are part of the
|
||||
/// unique immutable key. The mutable component of the type can be modified
|
||||
/// after the type is created, but cannot affect the identity of the type.
|
||||
///
|
||||
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
|
||||
///
|
||||
|
@ -62,6 +64,7 @@ struct OpaqueTypeStorage;
|
|||
/// - The type kind (for LLVM-style RTTI).
|
||||
/// - The dialect that defined the type.
|
||||
/// - Any parameters of the type.
|
||||
/// - An optional mutable component.
|
||||
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
|
||||
/// Parametric storage types must derive TypeStorage and respect the following:
|
||||
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
|
||||
|
@ -75,11 +78,14 @@ struct OpaqueTypeStorage;
|
|||
/// - Provide a method, 'bool operator==(const KeyTy &) const', to
|
||||
/// compare the storage instance against an instance of the key type.
|
||||
///
|
||||
/// - Provide a construction method:
|
||||
/// - Provide a static construction method:
|
||||
/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
|
||||
/// that builds a unique instance of the derived storage. The arguments to
|
||||
/// this function are an allocator to store any uniqued data within the
|
||||
/// context and the key type for this storage.
|
||||
///
|
||||
/// - If they have a mutable component, this component must not be a part of
|
||||
// the key.
|
||||
class Type {
|
||||
public:
|
||||
/// Integer identifier for all the concrete type kinds.
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define MLIR_SUPPORT_STORAGEUNIQUER_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
|
||||
|
@ -60,6 +61,20 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
|
|||
/// that is called when erasing a storage instance. This should cleanup any
|
||||
/// fields of the storage as necessary and not attempt to free the memory
|
||||
/// of the storage itself.
|
||||
///
|
||||
/// Storage classes may have an optional mutable component, which must not take
|
||||
/// part in the unique immutable key. In this case, storage classes may be
|
||||
/// mutated with `mutate` and must additionally respect the following:
|
||||
/// - Provide a mutation method:
|
||||
/// 'LogicalResult mutate(StorageAllocator &, <...>)'
|
||||
/// that is called when mutating a storage instance. The first argument is
|
||||
/// an allocator to store any mutable data, and the remaining arguments are
|
||||
/// forwarded from the call site. The storage can be mutated at any time
|
||||
/// after creation. Care must be taken to avoid excessive mutation since
|
||||
/// the allocated storage can keep containing previous states. The return
|
||||
/// value of the function is used to indicate whether the mutation was
|
||||
/// successful, e.g., to limit the number of mutations or enable deferred
|
||||
/// one-time assignment of the mutable component.
|
||||
class StorageUniquer {
|
||||
public:
|
||||
StorageUniquer();
|
||||
|
@ -166,6 +181,17 @@ public:
|
|||
return static_cast<Storage *>(getImpl(kind, ctorFn));
|
||||
}
|
||||
|
||||
/// Changes the mutable component of 'storage' by forwarding the trailing
|
||||
/// arguments to the 'mutate' function of the derived class.
|
||||
template <typename Storage, typename... Args>
|
||||
LogicalResult mutate(Storage *storage, Args &&...args) {
|
||||
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
|
||||
return static_cast<Storage &>(*storage).mutate(
|
||||
allocator, std::forward<Args>(args)...);
|
||||
};
|
||||
return mutateImpl(mutationFn);
|
||||
}
|
||||
|
||||
/// Erases a uniqued instance of 'Storage'. This function is used for derived
|
||||
/// types that have complex storage or uniquing constraints.
|
||||
template <typename Storage, typename Arg, typename... Args>
|
||||
|
@ -206,6 +232,10 @@ private:
|
|||
function_ref<bool(const BaseStorage *)> isEqual,
|
||||
function_ref<void(BaseStorage *)> cleanupFn);
|
||||
|
||||
/// Implementation for mutating an instance of a derived storage.
|
||||
LogicalResult
|
||||
mutateImpl(function_ref<LogicalResult(StorageAllocator &)> mutationFn);
|
||||
|
||||
/// The internal implementation class.
|
||||
std::unique_ptr<detail::StorageUniquerImpl> impl;
|
||||
|
||||
|
|
|
@ -124,6 +124,16 @@ struct StorageUniquerImpl {
|
|||
storageTypes.erase(existing);
|
||||
}
|
||||
|
||||
/// Mutates an instance of a derived storage in a thread-safe way.
|
||||
LogicalResult
|
||||
mutate(function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||
if (!threadingIsEnabled)
|
||||
return mutationFn(allocator);
|
||||
|
||||
llvm::sys::SmartScopedWriter<true> lock(mutex);
|
||||
return mutationFn(allocator);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Instance Storage
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -214,3 +224,9 @@ void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue,
|
|||
function_ref<void(BaseStorage *)> cleanupFn) {
|
||||
impl->erase(kind, hashValue, isEqual, cleanupFn);
|
||||
}
|
||||
|
||||
/// Implementation for mutating an instance of a derived storage.
|
||||
LogicalResult StorageUniquer::mutateImpl(
|
||||
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
|
||||
return impl->mutate(mutationFn);
|
||||
}
|
||||
|
|
16
mlir/test/IR/recursive-type.mlir
Normal file
16
mlir/test/IR/recursive-type.mlir
Normal file
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @roundtrip
|
||||
func @roundtrip() {
|
||||
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
|
||||
// CHECK: !test.test_rec<c, test_rec<c>>
|
||||
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @create
|
||||
func @create() {
|
||||
// CHECK: !test.test_rec<some_long_and_unique_name, test_rec<some_long_and_unique_name>>
|
||||
return
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -137,19 +138,73 @@ TestDialect::TestDialect(MLIRContext *context)
|
|||
>();
|
||||
addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
|
||||
TestInlinerInterface>();
|
||||
addTypes<TestType>();
|
||||
addTypes<TestType, TestRecursiveType>();
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
Type TestDialect::parseType(DialectAsmParser &parser) const {
|
||||
if (failed(parser.parseKeyword("test_type")))
|
||||
static Type parseTestType(DialectAsmParser &parser,
|
||||
llvm::SetVector<Type> &stack) {
|
||||
StringRef typeTag;
|
||||
if (failed(parser.parseKeyword(&typeTag)))
|
||||
return Type();
|
||||
return TestType::get(getContext());
|
||||
|
||||
if (typeTag == "test_type")
|
||||
return TestType::get(parser.getBuilder().getContext());
|
||||
|
||||
if (typeTag != "test_rec")
|
||||
return Type();
|
||||
|
||||
StringRef name;
|
||||
if (parser.parseLess() || parser.parseKeyword(&name))
|
||||
return Type();
|
||||
auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name);
|
||||
|
||||
// If this type already has been parsed above in the stack, expect just the
|
||||
// name.
|
||||
if (stack.contains(rec)) {
|
||||
if (failed(parser.parseGreater()))
|
||||
return Type();
|
||||
return rec;
|
||||
}
|
||||
|
||||
// Otherwise, parse the body and update the type.
|
||||
if (failed(parser.parseComma()))
|
||||
return Type();
|
||||
stack.insert(rec);
|
||||
Type subtype = parseTestType(parser, stack);
|
||||
stack.pop_back();
|
||||
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
|
||||
return Type();
|
||||
|
||||
return rec;
|
||||
}
|
||||
|
||||
Type TestDialect::parseType(DialectAsmParser &parser) const {
|
||||
llvm::SetVector<Type> stack;
|
||||
return parseTestType(parser, stack);
|
||||
}
|
||||
|
||||
static void printTestType(Type type, DialectAsmPrinter &printer,
|
||||
llvm::SetVector<Type> &stack) {
|
||||
if (type.isa<TestType>()) {
|
||||
printer << "test_type";
|
||||
return;
|
||||
}
|
||||
|
||||
auto rec = type.cast<TestRecursiveType>();
|
||||
printer << "test_rec<" << rec.getName();
|
||||
if (!stack.contains(rec)) {
|
||||
printer << ", ";
|
||||
stack.insert(rec);
|
||||
printTestType(rec.getBody(), printer, stack);
|
||||
stack.pop_back();
|
||||
}
|
||||
printer << ">";
|
||||
}
|
||||
|
||||
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
||||
assert(type.isa<TestType>() && "unexpected type");
|
||||
printer << "test_type";
|
||||
llvm::SetVector<Type> stack;
|
||||
printTestType(type, printer, stack);
|
||||
}
|
||||
|
||||
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
|
||||
|
|
|
@ -39,6 +39,60 @@ struct TestType : public Type::TypeBase<TestType, Type, TypeStorage,
|
|||
emitRemark(loc) << *this << " - TestC";
|
||||
}
|
||||
};
|
||||
|
||||
/// Storage for simple named recursive types, where the type is identified by
|
||||
/// its name and can "contain" another type, including itself.
|
||||
struct TestRecursiveTypeStorage : public TypeStorage {
|
||||
using KeyTy = StringRef;
|
||||
|
||||
explicit TestRecursiveTypeStorage(StringRef key) : name(key), body(Type()) {}
|
||||
|
||||
bool operator==(const KeyTy &other) const { return name == other; }
|
||||
|
||||
static TestRecursiveTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
return new (allocator.allocate<TestRecursiveTypeStorage>())
|
||||
TestRecursiveTypeStorage(allocator.copyInto(key));
|
||||
}
|
||||
|
||||
LogicalResult mutate(TypeStorageAllocator &allocator, Type newBody) {
|
||||
// Cannot set a different body than before.
|
||||
if (body && body != newBody)
|
||||
return failure();
|
||||
|
||||
body = newBody;
|
||||
return success();
|
||||
}
|
||||
|
||||
StringRef name;
|
||||
Type body;
|
||||
};
|
||||
|
||||
/// Simple recursive type identified by its name and pointing to another named
|
||||
/// type, potentially itself. This requires the body to be mutated separately
|
||||
/// from type creation.
|
||||
class TestRecursiveType
|
||||
: public Type::TypeBase<TestRecursiveType, Type, TestRecursiveTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1;
|
||||
}
|
||||
|
||||
static TestRecursiveType create(MLIRContext *ctx, StringRef name) {
|
||||
return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1,
|
||||
name);
|
||||
}
|
||||
|
||||
/// Body getter and setter.
|
||||
LogicalResult setBody(Type body) { return Base::mutate(body); }
|
||||
Type getBody() { return getImpl()->body; }
|
||||
|
||||
/// Name/key getter.
|
||||
StringRef getName() { return getImpl()->name; }
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TESTTYPES_H
|
||||
|
|
|
@ -5,6 +5,7 @@ add_mlir_library(MLIRTestIR
|
|||
TestMatchers.cpp
|
||||
TestSideEffects.cpp
|
||||
TestSymbolUses.cpp
|
||||
TestTypes.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
|
|
78
mlir/test/lib/IR/TestTypes.cpp
Normal file
78
mlir/test/lib/IR/TestTypes.cpp
Normal file
|
@ -0,0 +1,78 @@
|
|||
//===- TestTypes.cpp - Test passes for MLIR types -------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestTypes.h"
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestRecursiveTypesPass
|
||||
: public PassWrapper<TestRecursiveTypesPass, FunctionPass> {
|
||||
LogicalResult createIRWithTypes();
|
||||
|
||||
void runOnFunction() override {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
// Just make sure recurisve types are printed and parsed.
|
||||
if (func.getName() == "roundtrip")
|
||||
return;
|
||||
|
||||
// Create a recursive type and print it as a part of a dummy op.
|
||||
if (func.getName() == "create") {
|
||||
if (failed(createIRWithTypes()))
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Unknown key.
|
||||
func.emitOpError() << "unexpected function name";
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
FuncOp func = getFunction();
|
||||
auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
||||
if (failed(type.setBody(type)))
|
||||
return func.emitError("expected to be able to set the type body");
|
||||
|
||||
// Setting the same body is fine.
|
||||
if (failed(type.setBody(type)))
|
||||
return func.emitError(
|
||||
"expected to be able to set the type body to the same value");
|
||||
|
||||
// Setting a different body is not.
|
||||
if (succeeded(type.setBody(IndexType::get(ctx))))
|
||||
return func.emitError(
|
||||
"not expected to be able to change function body more than once");
|
||||
|
||||
// Expecting to get the same type for the same name.
|
||||
auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name");
|
||||
if (type != other)
|
||||
return func.emitError("expected type name to be the uniquing key");
|
||||
|
||||
// Create the op to check how the type is printed.
|
||||
OperationState state(func.getLoc(), "test.dummy_type_test_op");
|
||||
state.addTypes(type);
|
||||
func.getBody().front().push_front(Operation::create(state));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
void registerTestRecursiveTypesPass() {
|
||||
PassRegistration<TestRecursiveTypesPass> reg(
|
||||
"test-recursive-types", "Test support for recursive types");
|
||||
}
|
||||
|
||||
} // end namespace mlir
|
|
@ -63,6 +63,7 @@ void registerTestMemRefDependenceCheck();
|
|||
void registerTestMemRefStrideCalculation();
|
||||
void registerTestOpaqueLoc();
|
||||
void registerTestPreparationPassWithAllowedMemrefResults();
|
||||
void registerTestRecursiveTypesPass();
|
||||
void registerTestReducer();
|
||||
void registerTestGpuParallelLoopMappingPass();
|
||||
void registerTestSCFUtilsPass();
|
||||
|
@ -138,6 +139,7 @@ void registerTestPasses() {
|
|||
registerTestMemRefStrideCalculation();
|
||||
registerTestOpaqueLoc();
|
||||
registerTestPreparationPassWithAllowedMemrefResults();
|
||||
registerTestRecursiveTypesPass();
|
||||
registerTestReducer();
|
||||
registerTestGpuParallelLoopMappingPass();
|
||||
registerTestSCFUtilsPass();
|
||||
|
|
Loading…
Reference in a new issue