Use an Identifier instead of an OperationName internally for OpPassManager identification (NFC)

This allows to defers the check for traits to the execution instead of forcing it on the pipeline creation.
In particular, this is making our pipeline creation tolerant to dialects not being loaded in the context yet.

Reviewed By: rriddle, GMNGeoffrey

Differential Revision: https://reviews.llvm.org/D86915
This commit is contained in:
Mehdi Amini 2020-09-02 20:09:07 +00:00
parent 553bfc8fa1
commit 1284dc34ab
6 changed files with 97 additions and 49 deletions

View file

@ -9,12 +9,12 @@
#ifndef MLIR_PASS_PASSINSTRUMENTATION_H_
#define MLIR_PASS_PASSINSTRUMENTATION_H_
#include "mlir/IR/Identifier.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
namespace mlir {
class Operation;
class OperationName;
class Pass;
namespace detail {
@ -43,13 +43,13 @@ public:
/// A callback to run before a pass pipeline is executed. This function takes
/// the name of the operation type being operated on, and information related
/// to the parent that spawned this pipeline.
virtual void runBeforePipeline(const OperationName &name,
virtual void runBeforePipeline(Identifier name,
const PipelineParentInfo &parentInfo) {}
/// A callback to run after a pass pipeline has executed. This function takes
/// the name of the operation type being operated on, and information related
/// to the parent that spawned this pipeline.
virtual void runAfterPipeline(const OperationName &name,
virtual void runAfterPipeline(Identifier name,
const PipelineParentInfo &parentInfo) {}
/// A callback to run before a pass is executed. This function takes a pointer
@ -90,12 +90,12 @@ public:
/// See PassInstrumentation::runBeforePipeline for details.
void
runBeforePipeline(const OperationName &name,
runBeforePipeline(Identifier name,
const PassInstrumentation::PipelineParentInfo &parentInfo);
/// See PassInstrumentation::runAfterPipeline for details.
void
runAfterPipeline(const OperationName &name,
runAfterPipeline(Identifier name,
const PassInstrumentation::PipelineParentInfo &parentInfo);
/// See PassInstrumentation::runBeforePass for details.

View file

@ -26,9 +26,9 @@ class Any;
namespace mlir {
class AnalysisManager;
class Identifier;
class MLIRContext;
class ModuleOp;
class OperationName;
class Operation;
class Pass;
class PassInstrumentation;
@ -47,7 +47,7 @@ struct OpPassManagerImpl;
/// other OpPassManagers or the top-level PassManager.
class OpPassManager {
public:
OpPassManager(OperationName name, bool verifyPasses);
OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
OpPassManager(OpPassManager &&rhs);
OpPassManager(const OpPassManager &rhs);
~OpPassManager();
@ -70,10 +70,10 @@ public:
/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(const OperationName &nestedName);
OpPassManager &nest(Identifier nestedName);
OpPassManager &nest(StringRef nestedName);
template <typename OpT> OpPassManager &nest() {
return nest(OpT::getOperationName());
return nest(Identifier::get(OpT::getOperationName(), getContext()));
}
/// Add the given pass to this pass manager. If this pass has a concrete
@ -93,7 +93,7 @@ public:
MLIRContext *getContext() const;
/// Return the operation name that this pass manager operates on.
const OperationName &getOpName() const;
Identifier getOpName() const;
/// Returns the internal implementation instance.
detail::OpPassManagerImpl &getImpl();

View file

@ -92,17 +92,17 @@ void VerifierPass::runOnOperation() {
namespace mlir {
namespace detail {
struct OpPassManagerImpl {
OpPassManagerImpl(OperationName name, bool verifyPasses)
: name(name), verifyPasses(verifyPasses) {}
OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
: name(name), context(ctx), verifyPasses(verifyPasses) {}
/// Merge the passes of this pass manager into the one provided.
void mergeInto(OpPassManagerImpl &rhs);
/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &nest(const OperationName &nestedName);
OpPassManager &nest(Identifier nestedName);
OpPassManager &nest(StringRef nestedName) {
return nest(OperationName(nestedName, getContext()));
return nest(Identifier::get(nestedName, getContext()));
}
/// Add the given pass to this pass manager. If this pass has a concrete
@ -118,12 +118,13 @@ struct OpPassManagerImpl {
void splitAdaptorPasses();
/// Return an instance of the context.
MLIRContext *getContext() const {
return name.getAbstractOperation()->dialect.getContext();
}
MLIRContext *getContext() const { return context; }
/// The name of the operation that passes of this pass manager operate on.
OperationName name;
Identifier name;
/// The current context for this pass manager
MLIRContext *context;
/// Flag that specifies if the IR should be verified after each pass has run.
bool verifyPasses : 1;
@ -141,8 +142,8 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) {
passes.clear();
}
OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) {
OpPassManager nested(nestedName, verifyPasses);
OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
OpPassManager nested(nestedName, getContext(), verifyPasses);
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
@ -152,7 +153,7 @@ void OpPassManagerImpl::addPass(std::unique_ptr<Pass> pass) {
// If this pass runs on a different operation than this pass manager, then
// implicitly nest a pass manager for this operation.
auto passOpName = pass->getOpName();
if (passOpName && passOpName != name.getStringRef())
if (passOpName && passOpName != name.strref())
return nest(*passOpName).addPass(std::move(pass));
passes.emplace_back(std::move(pass));
@ -239,19 +240,14 @@ void OpPassManagerImpl::splitAdaptorPasses() {
// OpPassManager
//===----------------------------------------------------------------------===//
OpPassManager::OpPassManager(OperationName name, bool verifyPasses)
: impl(new OpPassManagerImpl(name, verifyPasses)) {
assert(name.getAbstractOperation() &&
"OpPassManager can only operate on registered operations");
assert(name.getAbstractOperation()->hasProperty(
OperationProperty::IsolatedFromAbove) &&
"OpPassManager only supports operating on operations marked as "
"'IsolatedFromAbove'");
}
OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
bool verifyPasses)
: impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
rhs.impl->verifyPasses));
for (auto &pass : rhs.impl->passes)
impl->passes.emplace_back(pass->clone());
return *this;
@ -275,7 +271,7 @@ OpPassManager::const_pass_iterator OpPassManager::end() const {
/// Nest a new operation pass manager for the given operation kind under this
/// pass manager.
OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
OpPassManager &OpPassManager::nest(Identifier nestedName) {
return impl->nest(nestedName);
}
OpPassManager &OpPassManager::nest(StringRef nestedName) {
@ -298,7 +294,7 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
/// Return the operation name that this pass manager operates on.
const OperationName &OpPassManager::getOpName() const { return impl->name; }
Identifier OpPassManager::getOpName() const { return impl->name; }
/// Prints out the given passes as the textual representation of a pipeline.
static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
@ -336,6 +332,14 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
AnalysisManager am) {
if (!op->getName().getAbstractOperation())
return op->emitOpError()
<< "trying to schedule a pass on an unregistered operation";
if (!op->getName().getAbstractOperation()->hasProperty(
OperationProperty::IsolatedFromAbove))
return op->emitOpError() << "trying to schedule a pass on an operation not "
"marked as 'IsolatedFromAbove'";
pass->passState.emplace(op, am);
// Instrument before the pass has run.
@ -385,7 +389,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
/// Find an operation pass manager that can operate on an operation of the given
/// type, or nullptr if one does not exist.
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
const OperationName &name) {
Identifier name) {
auto it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
return it == mgrs.end() ? nullptr : &*it;
@ -417,8 +421,8 @@ void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
// After coalescing, sort the pass managers within rhs by name.
llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
[](const OpPassManager *lhs, const OpPassManager *rhs) {
return lhs->getOpName().getStringRef().compare(
rhs->getOpName().getStringRef());
return lhs->getOpName().strref().compare(
rhs->getOpName().strref());
});
}
@ -450,7 +454,7 @@ void OpToOpPassAdaptor::runOnOperationImpl() {
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
for (auto &op : block) {
auto *mgr = findPassManagerFor(mgrs, op.getName());
auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
if (!mgr)
continue;
@ -494,8 +498,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
for (auto &op : block) {
// Add this operation iff the name matches the any of the pass managers.
if (findPassManagerFor(mgrs, op.getName()))
// Add this operation iff the name matches any of the pass managers.
if (findPassManagerFor(mgrs, op.getName().getIdentifier()))
opAMPairs.emplace_back(&op, am.nest(&op));
}
}
@ -531,7 +535,8 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
// Get the pass manager for this operation and execute it.
auto &it = opAMPairs[nextID];
auto *pm = findPassManagerFor(pms, it.first->getName());
auto *pm =
findPassManagerFor(pms, it.first->getName().getIdentifier());
assert(pm && "expected valid pass manager for operation");
if (instrumentor)
@ -732,7 +737,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
//===----------------------------------------------------------------------===//
PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
: OpPassManager(OperationName(ModuleOp::getOperationName(), ctx),
: OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
verifyPasses),
passTiming(false), localReproducer(false) {}
@ -870,7 +875,7 @@ PassInstrumentor::~PassInstrumentor() {}
/// See PassInstrumentation::runBeforePipeline for details.
void PassInstrumentor::runBeforePipeline(
const OperationName &name,
Identifier name,
const PassInstrumentation::PipelineParentInfo &parentInfo) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : impl->instrumentations)
@ -879,7 +884,7 @@ void PassInstrumentor::runBeforePipeline(
/// See PassInstrumentation::runAfterPipeline for details.
void PassInstrumentor::runAfterPipeline(
const OperationName &name,
Identifier name,
const PassInstrumentation::PipelineParentInfo &parentInfo) {
llvm::sys::SmartScopedLock<true> instrumentationLock(impl->mutex);
for (auto &instr : llvm::reverse(impl->instrumentations))

View file

@ -116,7 +116,7 @@ static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
// Print each of the children passes.
for (OpPassManager &mgr : mgrs) {
auto name = ("'" + mgr.getOpName().getStringRef() + "' Pipeline").str();
auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
printPassEntry(os, indent, name);
for (Pass &pass : mgr.getPasses())
printPass(indent + 2, &pass);

View file

@ -165,9 +165,9 @@ struct PassTiming : public PassInstrumentation {
~PassTiming() override { print(); }
/// Setup the instrumentation hooks.
void runBeforePipeline(const OperationName &name,
void runBeforePipeline(Identifier name,
const PipelineParentInfo &parentInfo) override;
void runAfterPipeline(const OperationName &name,
void runAfterPipeline(Identifier name,
const PipelineParentInfo &parentInfo) override;
void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); }
void runAfterPass(Pass *pass, Operation *) override;
@ -242,15 +242,15 @@ struct PassTiming : public PassInstrumentation {
};
} // end anonymous namespace
void PassTiming::runBeforePipeline(const OperationName &name,
void PassTiming::runBeforePipeline(Identifier name,
const PipelineParentInfo &parentInfo) {
// We don't actually want to time the pipelines, they gather their total
// from their held passes.
getTimer(name.getAsOpaquePointer(), TimerKind::Pipeline,
[&] { return ("'" + name.getStringRef() + "' Pipeline").str(); });
[&] { return ("'" + name.strref() + "' Pipeline").str(); });
}
void PassTiming::runAfterPipeline(const OperationName &name,
void PassTiming::runAfterPipeline(Identifier name,
const PipelineParentInfo &parentInfo) {
// Pop the timer for the pipeline.
auto tid = llvm::get_threadid();

View file

@ -74,4 +74,47 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
}
}
namespace {
struct InvalidPass : Pass {
InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {}
StringRef getName() const override { return "Invalid Pass"; }
void runOnOperation() override {}
/// A clone method to create a copy of this pass.
std::unique_ptr<Pass> clonePass() const override {
return std::make_unique<InvalidPass>(
*static_cast<const InvalidPass *>(this));
}
};
} // anonymous namespace
TEST(PassManagerTest, InvalidPass) {
MLIRContext context;
// Create a module
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
// Add a single "invalid_op" operation
OpBuilder builder(&module->getBodyRegion());
OperationState state(UnknownLoc::get(&context), "invalid_op");
builder.insert(Operation::create(state));
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
std::unique_ptr<Diagnostic> diagnostic;
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
diagnostic.reset(new Diagnostic(std::move(diag)));
});
// Instantiate and run our pass.
PassManager pm(&context);
pm.addPass(std::make_unique<InvalidPass>());
LogicalResult result = pm.run(module.get());
EXPECT_TRUE(failed(result));
ASSERT_TRUE(diagnostic.get() != nullptr);
EXPECT_EQ(
diagnostic->str(),
"'invalid_op' op trying to schedule a pass on an unregistered operation");
}
} // end namespace