[mlir] Remove locking for dialect/operation registration.

Moving forward dialects should only be registered in a thread safe context. This matches the existing usage we have today, but it allows for removing quite a bit of expensive locking from the context.

This led to ~.5 a second compile time improvement when running one conversion pass on a very large .mlir file(hundreds of thousands of operations).

Differential Revision: https://reviews.llvm.org/D82595
This commit is contained in:
River Riddle 2020-06-30 15:43:03 -07:00
parent 2e2cdd0a52
commit 5d699d18b3
2 changed files with 17 additions and 48 deletions

View file

@ -258,10 +258,12 @@ private:
};
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
/// Note: This method is not thread-safe.
void registerAllDialects(MLIRContext *context);
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
Dialect::registerDialectAllocator(TypeID::get<ConcreteDialect>(),
[](MLIRContext *ctx) {

View file

@ -270,10 +270,6 @@ public:
// Other
//===--------------------------------------------------------------------===//
/// A general purpose mutex to lock access to parts of the context that do not
/// have a more specific mutex, e.g. registry operations.
llvm::sys::SmartRWMutex<true> contextMutex;
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects.
std::vector<std::unique_ptr<Dialect>> dialects;
@ -425,8 +421,6 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
/// Return information about all registered IR dialects.
std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
// Lock access to the context registry.
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
std::vector<Dialect *> result;
result.reserve(impl->dialects.size());
for (auto &dialect : impl->dialects)
@ -437,9 +431,6 @@ std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
// Lock access to the context registry.
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
// Dialects are sorted by name, so we can use binary search for lookup.
auto it = llvm::lower_bound(
impl->dialects, name,
@ -455,9 +446,6 @@ void Dialect::registerDialect(MLIRContext *context) {
auto &impl = context->getImpl();
std::unique_ptr<Dialect> dialect(this);
// Lock access to the context registry.
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
// Get the correct insertion position sorted by namespace.
auto insertPt = llvm::lower_bound(
impl.dialects, dialect, [](const auto &lhs, const auto &rhs) {
@ -524,35 +512,26 @@ void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
/// efficient, typically you should ask the operations about their properties
/// directly.
std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
std::vector<std::pair<StringRef, AbstractOperation *>> opsToSort;
{ // Lock access to the context registry.
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
llvm::StringMap<AbstractOperation> &registeredOps =
impl->registeredOperations;
opsToSort.reserve(registeredOps.size());
for (auto &elt : registeredOps)
opsToSort.push_back({elt.first(), &elt.second});
}
llvm::array_pod_sort(opsToSort.begin(), opsToSort.end());
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
llvm::StringMap<AbstractOperation> &registeredOps =
impl->registeredOperations;
std::vector<AbstractOperation *> result;
result.reserve(opsToSort.size());
for (auto &elt : opsToSort)
result.push_back(elt.second);
result.reserve(registeredOps.size());
for (auto &elt : registeredOps)
result.push_back(&elt.second);
llvm::array_pod_sort(
result.begin(), result.end(),
[](AbstractOperation *const *lhs, AbstractOperation *const *rhs) {
return (*lhs)->name.compare((*rhs)->name);
});
return result;
}
bool MLIRContext::isOperationRegistered(StringRef name) {
// Lock access to the context registry.
ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
return impl->registeredOperations.count(name);
}
@ -561,12 +540,9 @@ void Dialect::addOperation(AbstractOperation opInfo) {
"op name doesn't start with dialect namespace");
assert(&opInfo.dialect == this && "Dialect object mismatch");
auto &impl = context->getImpl();
// Lock access to the context registry.
StringRef opName = opInfo.name;
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
if (!impl.registeredOperations.insert({opName, std::move(opInfo)}).second) {
llvm::errs() << "error: operation named '" << opName
llvm::errs() << "error: operation named '" << opInfo.name
<< "' is already registered.\n";
abort();
}
@ -574,9 +550,6 @@ void Dialect::addOperation(AbstractOperation opInfo) {
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
AbstractType(std::move(typeInfo));
@ -586,9 +559,6 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
AbstractAttribute(std::move(attrInfo));
@ -612,9 +582,6 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
MLIRContext *context) {
auto &impl = context->getImpl();
// Lock access to the context registry.
ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
auto it = impl.registeredOperations.find(opName);
if (it != impl.registeredOperations.end())
return &it->second;