[mlir] Optimize OperationName construction and usage

When constructing an OperationName, the overwhelming majority of
cases are from registered operations. This revision adds a non-locked
lookup into the currently registered operations, which prevents locking
in the common case. This revision also optimizes several uses of
RegisteredOperationName that expect the operation to be registered,
e.g. such as in OpBuilder.

These changes provides a reasonable speedup (5-10%) in some
compilations, especially on platforms where locking is expensive.

Differential Revision: https://reviews.llvm.org/D117187
This commit is contained in:
River Riddle 2022-01-12 23:46:59 -08:00
parent a97e20a3a8
commit 11067d711b
6 changed files with 55 additions and 23 deletions

View file

@ -406,22 +406,27 @@ public:
private:
/// Helper for sanity checking preconditions for create* methods below.
void checkHasRegisteredInfo(const OperationName &name) {
if (LLVM_UNLIKELY(!name.isRegistered()))
template <typename OpT>
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
Optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
if (LLVM_UNLIKELY(!opName)) {
llvm::report_fatal_error(
"Building op `" + name.getStringRef() +
"Building op `" + OpT::getOperationName() +
"` but it isn't registered in this MLIRContext: the dialect may not "
"be loaded or this operation isn't registered by the dialect. See "
"also https://mlir.llvm.org/getting_started/Faq/"
"#registered-loaded-dependent-whats-up-with-dialects-management");
}
return *opName;
}
public:
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&...args) {
OperationState state(location, OpTy::getOperationName());
checkHasRegisteredInfo(state.name);
OperationState state(location,
getCheckRegisteredInfo<OpTy>(location.getContext()));
OpTy::build(*this, state, std::forward<Args>(args)...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
@ -437,8 +442,8 @@ public:
Args &&...args) {
// Create the operation without using 'createOperation' as we don't want to
// insert it yet.
OperationState state(location, OpTy::getOperationName());
checkHasRegisteredInfo(state.name);
OperationState state(location,
getCheckRegisteredInfo<OpTy>(location.getContext()));
OpTy::build(*this, state, std::forward<Args>(args)...);
Operation *op = Operation::create(state);

View file

@ -231,9 +231,7 @@ public:
/// Lookup the registered operation information for the given operation.
/// Returns None if the operation isn't registered.
static Optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx) {
return OperationName(name, ctx).getRegisteredInfo();
}
MLIRContext *ctx);
/// Register a new operation in a Dialect object.
/// This constructor is used by Dialect objects when they register the list of
@ -582,9 +580,12 @@ struct OperationState {
public:
OperationState(Location location, StringRef name);
OperationState(Location location, OperationName name);
OperationState(Location location, OperationName name, ValueRange operands,
TypeRange types, ArrayRef<NamedAttribute> attributes,
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
OperationState(Location location, StringRef name, ValueRange operands,
TypeRange types, ArrayRef<NamedAttribute> attributes,
BlockRange successors = {},

View file

@ -1406,9 +1406,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
// name that works both in scalar mode and vector mode.
// TODO: Is it worth considering an Operation.clone operation which
// changes the type so we can promote an Operation with less boilerplate?
OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
vectorOperands, vectorTypes, op->getAttrs(),
/*successors=*/{}, /*regions=*/{});
OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
vectorTypes, op->getAttrs(), /*successors=*/{},
/*regions=*/{});
Operation *vecOp = state.builder.createOperation(vecOpState);
state.registerOpVectorReplacement(op, vecOp);
return vecOp;

View file

@ -70,8 +70,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
Operation *op,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
op->getAttrs());
OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
return builder.createOperation(res);
}

View file

@ -182,7 +182,7 @@ public:
llvm::StringMap<OperationName::Impl> operations;
/// A vector of operation info specifically for registered operations.
SmallVector<RegisteredOperationName> registeredOperations;
llvm::StringMap<RegisteredOperationName> registeredOperations;
/// A mutex used when accessing operation information.
llvm::sys::SmartRWMutex<true> operationInfoMutex;
@ -576,8 +576,9 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
// We just have the operations in a non-deterministic hash table order. Dump
// into a temporary array, then sort it by operation name to get a stable
// ordering.
std::vector<RegisteredOperationName> result(
impl->registeredOperations.begin(), impl->registeredOperations.end());
auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
unwrappedNames.end());
llvm::array_pod_sort(result.begin(), result.end(),
[](const RegisteredOperationName *lhs,
const RegisteredOperationName *rhs) {
@ -589,7 +590,7 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
}
bool MLIRContext::isOperationRegistered(StringRef name) {
return OperationName(name, this).isRegistered();
return RegisteredOperationName::lookup(name, this).hasValue();
}
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
@ -649,6 +650,15 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
// Check for an existing name in read-only mode.
bool isMultithreadingEnabled = context->isMultithreadingEnabled();
if (isMultithreadingEnabled) {
// Check the registered info map first. In the overwhelmingly common case,
// the entry will be in here and it also removes the need to acquire any
// locks.
auto registeredIt = ctxImpl.registeredOperations.find(name);
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
impl = registeredIt->second.impl;
return;
}
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
auto it = ctxImpl.operations.find(name);
if (it != ctxImpl.operations.end()) {
@ -676,6 +686,15 @@ StringRef OperationName::getDialectNamespace() const {
// RegisteredOperationName
//===----------------------------------------------------------------------===//
Optional<RegisteredOperationName>
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
auto &impl = ctx->getImpl();
auto it = impl.registeredOperations.find(name);
if (it != impl.registeredOperations.end())
return it->getValue();
return llvm::None;
}
ParseResult
RegisteredOperationName::parseAssembly(OpAsmParser &parser,
OperationState &result) const {
@ -717,7 +736,8 @@ void RegisteredOperationName::insert(
<< "' is already registered.\n";
abort();
}
ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
ctxImpl.registeredOperations.try_emplace(name,
RegisteredOperationName(&impl));
// Update the registered info for this operation.
impl.dialect = &dialect;

View file

@ -170,12 +170,12 @@ OperationState::OperationState(Location location, StringRef name)
OperationState::OperationState(Location location, OperationName name)
: location(location), name(name) {}
OperationState::OperationState(Location location, StringRef name,
OperationState::OperationState(Location location, OperationName name,
ValueRange operands, TypeRange types,
ArrayRef<NamedAttribute> attributes,
BlockRange successors,
MutableArrayRef<std::unique_ptr<Region>> regions)
: location(location), name(name, location->getContext()),
: location(location), name(name),
operands(operands.begin(), operands.end()),
types(types.begin(), types.end()),
attributes(attributes.begin(), attributes.end()),
@ -183,6 +183,13 @@ OperationState::OperationState(Location location, StringRef name,
for (std::unique_ptr<Region> &r : regions)
this->regions.push_back(std::move(r));
}
OperationState::OperationState(Location location, StringRef name,
ValueRange operands, TypeRange types,
ArrayRef<NamedAttribute> attributes,
BlockRange successors,
MutableArrayRef<std::unique_ptr<Region>> regions)
: OperationState(location, OperationName(name, location.getContext()),
operands, types, attributes, successors, regions) {}
void OperationState::addOperands(ValueRange newOperands) {
operands.append(newOperands.begin(), newOperands.end());