[mlir] Optimize OperationName construction and usage
When constructing an OperationName, the overwhelming majority of cases are from registered operations. This revision adds a non-locked lookup into the currently registered operations, which prevents locking in the common case. This revision also optimizes several uses of RegisteredOperationName that expect the operation to be registered, e.g. such as in OpBuilder. These changes provides a reasonable speedup (5-10%) in some compilations, especially on platforms where locking is expensive. Differential Revision: https://reviews.llvm.org/D117187
This commit is contained in:
parent
a97e20a3a8
commit
11067d711b
|
@ -406,22 +406,27 @@ public:
|
|||
|
||||
private:
|
||||
/// Helper for sanity checking preconditions for create* methods below.
|
||||
void checkHasRegisteredInfo(const OperationName &name) {
|
||||
if (LLVM_UNLIKELY(!name.isRegistered()))
|
||||
template <typename OpT>
|
||||
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
|
||||
Optional<RegisteredOperationName> opName =
|
||||
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
|
||||
if (LLVM_UNLIKELY(!opName)) {
|
||||
llvm::report_fatal_error(
|
||||
"Building op `" + name.getStringRef() +
|
||||
"Building op `" + OpT::getOperationName() +
|
||||
"` but it isn't registered in this MLIRContext: the dialect may not "
|
||||
"be loaded or this operation isn't registered by the dialect. See "
|
||||
"also https://mlir.llvm.org/getting_started/Faq/"
|
||||
"#registered-loaded-dependent-whats-up-with-dialects-management");
|
||||
}
|
||||
return *opName;
|
||||
}
|
||||
|
||||
public:
|
||||
/// Create an operation of specific op type at the current insertion point.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpTy create(Location location, Args &&...args) {
|
||||
OperationState state(location, OpTy::getOperationName());
|
||||
checkHasRegisteredInfo(state.name);
|
||||
OperationState state(location,
|
||||
getCheckRegisteredInfo<OpTy>(location.getContext()));
|
||||
OpTy::build(*this, state, std::forward<Args>(args)...);
|
||||
auto *op = createOperation(state);
|
||||
auto result = dyn_cast<OpTy>(op);
|
||||
|
@ -437,8 +442,8 @@ public:
|
|||
Args &&...args) {
|
||||
// Create the operation without using 'createOperation' as we don't want to
|
||||
// insert it yet.
|
||||
OperationState state(location, OpTy::getOperationName());
|
||||
checkHasRegisteredInfo(state.name);
|
||||
OperationState state(location,
|
||||
getCheckRegisteredInfo<OpTy>(location.getContext()));
|
||||
OpTy::build(*this, state, std::forward<Args>(args)...);
|
||||
Operation *op = Operation::create(state);
|
||||
|
||||
|
|
|
@ -231,9 +231,7 @@ public:
|
|||
/// Lookup the registered operation information for the given operation.
|
||||
/// Returns None if the operation isn't registered.
|
||||
static Optional<RegisteredOperationName> lookup(StringRef name,
|
||||
MLIRContext *ctx) {
|
||||
return OperationName(name, ctx).getRegisteredInfo();
|
||||
}
|
||||
MLIRContext *ctx);
|
||||
|
||||
/// Register a new operation in a Dialect object.
|
||||
/// This constructor is used by Dialect objects when they register the list of
|
||||
|
@ -582,9 +580,12 @@ struct OperationState {
|
|||
|
||||
public:
|
||||
OperationState(Location location, StringRef name);
|
||||
|
||||
OperationState(Location location, OperationName name);
|
||||
|
||||
OperationState(Location location, OperationName name, ValueRange operands,
|
||||
TypeRange types, ArrayRef<NamedAttribute> attributes,
|
||||
BlockRange successors = {},
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions = {});
|
||||
OperationState(Location location, StringRef name, ValueRange operands,
|
||||
TypeRange types, ArrayRef<NamedAttribute> attributes,
|
||||
BlockRange successors = {},
|
||||
|
|
|
@ -1406,9 +1406,9 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
|
|||
// name that works both in scalar mode and vector mode.
|
||||
// TODO: Is it worth considering an Operation.clone operation which
|
||||
// changes the type so we can promote an Operation with less boilerplate?
|
||||
OperationState vecOpState(op->getLoc(), op->getName().getStringRef(),
|
||||
vectorOperands, vectorTypes, op->getAttrs(),
|
||||
/*successors=*/{}, /*regions=*/{});
|
||||
OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
|
||||
vectorTypes, op->getAttrs(), /*successors=*/{},
|
||||
/*regions=*/{});
|
||||
Operation *vecOp = state.builder.createOperation(vecOpState);
|
||||
state.registerOpVectorReplacement(op, vecOp);
|
||||
return vecOp;
|
||||
|
|
|
@ -70,8 +70,7 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
|
|||
Operation *op,
|
||||
ArrayRef<Value> operands,
|
||||
ArrayRef<Type> resultTypes) {
|
||||
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
|
||||
op->getAttrs());
|
||||
OperationState res(loc, op->getName(), operands, resultTypes, op->getAttrs());
|
||||
return builder.createOperation(res);
|
||||
}
|
||||
|
||||
|
|
|
@ -182,7 +182,7 @@ public:
|
|||
llvm::StringMap<OperationName::Impl> operations;
|
||||
|
||||
/// A vector of operation info specifically for registered operations.
|
||||
SmallVector<RegisteredOperationName> registeredOperations;
|
||||
llvm::StringMap<RegisteredOperationName> registeredOperations;
|
||||
|
||||
/// A mutex used when accessing operation information.
|
||||
llvm::sys::SmartRWMutex<true> operationInfoMutex;
|
||||
|
@ -576,8 +576,9 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
|
|||
// We just have the operations in a non-deterministic hash table order. Dump
|
||||
// into a temporary array, then sort it by operation name to get a stable
|
||||
// ordering.
|
||||
std::vector<RegisteredOperationName> result(
|
||||
impl->registeredOperations.begin(), impl->registeredOperations.end());
|
||||
auto unwrappedNames = llvm::make_second_range(impl->registeredOperations);
|
||||
std::vector<RegisteredOperationName> result(unwrappedNames.begin(),
|
||||
unwrappedNames.end());
|
||||
llvm::array_pod_sort(result.begin(), result.end(),
|
||||
[](const RegisteredOperationName *lhs,
|
||||
const RegisteredOperationName *rhs) {
|
||||
|
@ -589,7 +590,7 @@ std::vector<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
|
|||
}
|
||||
|
||||
bool MLIRContext::isOperationRegistered(StringRef name) {
|
||||
return OperationName(name, this).isRegistered();
|
||||
return RegisteredOperationName::lookup(name, this).hasValue();
|
||||
}
|
||||
|
||||
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
|
||||
|
@ -649,6 +650,15 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
|
|||
// Check for an existing name in read-only mode.
|
||||
bool isMultithreadingEnabled = context->isMultithreadingEnabled();
|
||||
if (isMultithreadingEnabled) {
|
||||
// Check the registered info map first. In the overwhelmingly common case,
|
||||
// the entry will be in here and it also removes the need to acquire any
|
||||
// locks.
|
||||
auto registeredIt = ctxImpl.registeredOperations.find(name);
|
||||
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
|
||||
impl = registeredIt->second.impl;
|
||||
return;
|
||||
}
|
||||
|
||||
llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
|
||||
auto it = ctxImpl.operations.find(name);
|
||||
if (it != ctxImpl.operations.end()) {
|
||||
|
@ -676,6 +686,15 @@ StringRef OperationName::getDialectNamespace() const {
|
|||
// RegisteredOperationName
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<RegisteredOperationName>
|
||||
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
|
||||
auto &impl = ctx->getImpl();
|
||||
auto it = impl.registeredOperations.find(name);
|
||||
if (it != impl.registeredOperations.end())
|
||||
return it->getValue();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
ParseResult
|
||||
RegisteredOperationName::parseAssembly(OpAsmParser &parser,
|
||||
OperationState &result) const {
|
||||
|
@ -717,7 +736,8 @@ void RegisteredOperationName::insert(
|
|||
<< "' is already registered.\n";
|
||||
abort();
|
||||
}
|
||||
ctxImpl.registeredOperations.push_back(RegisteredOperationName(&impl));
|
||||
ctxImpl.registeredOperations.try_emplace(name,
|
||||
RegisteredOperationName(&impl));
|
||||
|
||||
// Update the registered info for this operation.
|
||||
impl.dialect = &dialect;
|
||||
|
|
|
@ -170,12 +170,12 @@ OperationState::OperationState(Location location, StringRef name)
|
|||
OperationState::OperationState(Location location, OperationName name)
|
||||
: location(location), name(name) {}
|
||||
|
||||
OperationState::OperationState(Location location, StringRef name,
|
||||
OperationState::OperationState(Location location, OperationName name,
|
||||
ValueRange operands, TypeRange types,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
BlockRange successors,
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions)
|
||||
: location(location), name(name, location->getContext()),
|
||||
: location(location), name(name),
|
||||
operands(operands.begin(), operands.end()),
|
||||
types(types.begin(), types.end()),
|
||||
attributes(attributes.begin(), attributes.end()),
|
||||
|
@ -183,6 +183,13 @@ OperationState::OperationState(Location location, StringRef name,
|
|||
for (std::unique_ptr<Region> &r : regions)
|
||||
this->regions.push_back(std::move(r));
|
||||
}
|
||||
OperationState::OperationState(Location location, StringRef name,
|
||||
ValueRange operands, TypeRange types,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
BlockRange successors,
|
||||
MutableArrayRef<std::unique_ptr<Region>> regions)
|
||||
: OperationState(location, OperationName(name, location.getContext()),
|
||||
operands, types, attributes, successors, regions) {}
|
||||
|
||||
void OperationState::addOperands(ValueRange newOperands) {
|
||||
operands.append(newOperands.begin(), newOperands.end());
|
||||
|
|
Loading…
Reference in a new issue