[mlir] Resolve TODO and use the pass argument instead of the TypeID for registration

This simplifies various pieces of code that interact with the pass registry, e.g. this removes the need to register passes to get accurate pass pipelines descriptions when generating crash reproducers.

Differential Revision: https://reviews.llvm.org/D101880
This commit is contained in:
River Riddle 2021-06-02 12:06:32 -07:00
parent 8beaca8c14
commit fa51c5af5d
4 changed files with 34 additions and 30 deletions

View file

@ -56,13 +56,12 @@ public:
TypeID getTypeID() const { return passID; }
/// Returns the pass info for the specified pass class or null if unknown.
static const PassInfo *lookupPassInfo(TypeID passID);
template <typename PassT> static const PassInfo *lookupPassInfo() {
return lookupPassInfo(TypeID::get<PassT>());
}
static const PassInfo *lookupPassInfo(StringRef passArg);
/// Returns the pass info for this pass.
const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
/// Returns the pass info for this pass, or null if unknown.
const PassInfo *lookupPassInfo() const {
return lookupPassInfo(getArgument());
}
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
@ -76,11 +75,7 @@ public:
/// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const {
if (const PassInfo *passInfo = lookupPassInfo())
return passInfo->getPassArgument();
return "";
}
virtual StringRef getArgument() const { return ""; }
/// Returns the name of the operation that this pass operates on, or None if
/// this is a generic OperationPass.

View file

@ -108,7 +108,7 @@ class PassInfo : public PassRegistryEntry {
public:
/// PassInfo constructor should not be invoked directly, instead use
/// PassRegistration or registerPass.
PassInfo(StringRef arg, StringRef description, TypeID passID,
PassInfo(StringRef arg, StringRef description,
const PassAllocatorFunction &allocator);
};

View file

@ -19,7 +19,11 @@ using namespace mlir;
using namespace detail;
/// Static mapping of all of the registered passes.
static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
/// A mapping of the above pass registry entries to the corresponding TypeID
/// of the pass that they generate.
static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
/// Static mapping of all of the registered pass pipelines.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@ -94,7 +98,7 @@ void mlir::registerPassPipeline(
// PassInfo
//===----------------------------------------------------------------------===//
PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
PassInfo::PassInfo(StringRef arg, StringRef description,
const PassAllocatorFunction &allocator)
: PassRegistryEntry(
arg, description, buildDefaultRegistryFn(allocator),
@ -105,18 +109,23 @@ PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
void mlir::registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function) {
// TODO: We should use the 'arg' as the lookup key instead of the pass id.
TypeID passID = function()->getTypeID();
PassInfo passInfo(arg, description, passID, function);
passRegistry->try_emplace(passID, passInfo);
PassInfo passInfo(arg, description, function);
passRegistry->try_emplace(arg, passInfo);
// Verify that the registered pass has the same ID as any registered to this
// arg before it.
TypeID entryTypeID = function()->getTypeID();
auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
if (it->second != entryTypeID) {
llvm_unreachable("pass allocator creates a different pass than previously "
"registered");
}
}
/// Returns the pass info for the specified pass class or null if unknown.
const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
auto it = passRegistry->find(passID);
if (it == passRegistry->end())
return nullptr;
return &it->getSecond();
/// Returns the pass info for the specified pass argument or null if unknown.
const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
auto it = passRegistry->find(passArg);
return it == passRegistry->end() ? nullptr : &it->second;
}
//===----------------------------------------------------------------------===//
@ -433,12 +442,8 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
}
// If not, then this must be a specific pass name.
for (auto &passIt : *passRegistry) {
if (passIt.second.getPassArgument() == element.name) {
element.registryEntry = &passIt.second;
return success();
}
}
if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
return success();
// Emit an error for the unknown pass.
auto *rawLoc = element.name.data();

View file

@ -16,9 +16,11 @@ namespace {
struct TestModulePass
: public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
void runOnOperation() final {}
StringRef getArgument() const final { return "test-module-pass"; }
};
struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
void runOnFunction() final {}
StringRef getArgument() const final { return "test-function-pass"; }
};
class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
public:
@ -41,6 +43,7 @@ public:
}
void runOnFunction() final {}
StringRef getArgument() const final { return "test-options-pass"; }
ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Example list option")};
@ -56,6 +59,7 @@ public:
class TestCrashRecoveryPass
: public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
void runOnOperation() final { abort(); }
StringRef getArgument() const final { return "test-pass-crash"; }
};
/// A test pass that always fails to enable testing the failure recovery