llvm/mlir/tools/mlir-tblgen/PassGen.cpp
River Riddle 5e50dd048e [mlir] Rework the implementation of TypeID
This commit restructures how TypeID is implemented to ideally avoid
the current problems related to shared libraries. This is done by changing
the "implicit" fallback path to use the name of the type, instead of using
a static template variable (which breaks shared libraries). The major downside to this
is that it adds some additional initialization costs for the implicit path. Given the
use of type names for uniqueness in the fallback, we also no longer allow types
defined in anonymous namespaces to have an implicit TypeID. To simplify defining
an ID for these classes, a new `MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` macro
was added to allow for explicitly defining a TypeID directly on an internal class.

To help identify when types are using the fallback, `-debug-only=typeid` can be
used to log which types are using implicit ids.

This change generally only requires changes to the test passes, which are all defined
in anonymous namespaces, and thus can't use the fallback any longer.

Differential Revision: https://reviews.llvm.org/D122775
2022-04-04 13:52:26 -07:00

218 lines
8 KiB
C++

//===- PassGen.cpp - MLIR pass registration generator ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PassGen uses the description of passes to generate base classes for passes
// and command line registration.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
/// Command-line category under which the options of this generator are listed.
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
/// The `name` option: the name of the pass group being generated. It is
/// spliced into the group registration function emitted by `emitRegistration`
/// (`register<name>Passes`).
static llvm::cl::opt<std::string>
groupName("name", llvm::cl::desc("The name of this group of passes"),
llvm::cl::cat(passGenCat));
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate the start of a pass base class.
///
/// The snippet deliberately ends at `protected:` with the class still open:
/// `emitPassDecl` appends the option and statistic members and then the
/// closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary of the pass, returned by the generated `getDescription`.
/// {4}: The dependent dialects registration.
///
/// NOTE(review): literal open braces in this template are written both as
/// `{{` (the formatv escape) and as a bare `{`; presumably formatv passes the
/// bare ones through unchanged because another `{` precedes the next `}` --
/// confirm before editing the template.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
template <typename DerivedT>
class {0}Base : public {1} {
public:
using Base = {0}Base;
{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &other) : {1}(other) {{}
/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("{2}");
}
::llvm::StringRef getArgument() const override { return "{2}"; }
::llvm::StringRef getDescription() const override { return "{3}"; }
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("{0}");
}
::llvm::StringRef getName() const override { return "{0}"; }
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}
/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)
protected:
)";
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
///
/// {0}: The dependent dialect, spliced in as the template argument of
///      `registry.insert`.
const char *const dialectRegistrationTemplate = R"(
registry.insert<{0}>();
)";
/// Emit the declarations for each of the pass options.
///
/// Every option of `pass` becomes one brace-initialized member of type
/// `::mlir::Pass::Option<T>` or `::mlir::Pass::ListOption<T>` (chosen by
/// `isListOption`), written to `os` on its own line. The initializer always
/// carries `*this`, the option argument and a `::llvm::cl::desc`; a
/// `::llvm::cl::init(...)` and any additional flags are appended only when
/// the option provides them.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
  for (const PassOption &opt : pass.getOptions()) {
    os.indent(2) << "::mlir::Pass::"
                 << (opt.isListOption() ? "ListOption" : "Option");

    os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
                        opt.getType(), opt.getCppVariableName(),
                        opt.getArgument(), opt.getDescription());
    // Stream the contained value, not the Optional wrapper, so this line does
    // not depend on an Optional stream overload and matches the handling of
    // the additional flags below.
    if (Optional<StringRef> defaultVal = opt.getDefaultValue())
      os << ", ::llvm::cl::init(" << *defaultVal << ")";
    if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
      os << ", " << *additionalFlags;
    os << "};\n";
  }
}
/// Emit one `::mlir::Pass::Statistic` member declaration per statistic of
/// `pass`. Each member is brace-initialized with `this`, the statistic name
/// and the statistic description.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
  // formatv template for a single statistic member:
  //   {0}: C++ variable name, {1}: statistic name, {2}: description.
  const char *const statisticDeclFormat =
      " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n";

  for (const PassStatistic &statistic : pass.getStatistics())
    os << llvm::formatv(statisticDeclFormat, statistic.getCppVariableName(),
                        statistic.getName(), statistic.getDescription());
}
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
for (StringRef dependentDialect : pass.getDependentDialects())
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
}
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
}
/// Emit the base class declarations for all of the given passes. The output
/// is wrapped in an `#ifdef GEN_PASS_CLASSES` region, and the macro is
/// undefined again at the end of that region.
static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
  os << "#ifdef GEN_PASS_CLASSES\n";

  for (const Pass &currentPass : passes)
    emitPassDecl(currentPass, os);

  os << "#undef GEN_PASS_CLASSES\n"
     << "#endif // GEN_PASS_CLASSES\n";
}
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate a pass registration.
///
/// Produces an inline `register<Def>Pass()` function that calls
/// `::mlir::registerPass` with a lambda returning the constructed pass.
/// Literal open braces are written as `{{`, the formatv escape.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}Pass() {{
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
return {1};
});
}
)";
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the function header is part of this snippet; the body is left open.
/// `emitRegistration` appends one `register<Def>Pass();` call per pass and
/// then the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//
inline void register{0}Passes() {{
)";
/// Emit the registration code for the given passes, wrapped in an
/// `#ifdef GEN_PASS_REGISTRATION` region: first one `register<Def>Pass()`
/// function per pass, then a `register<Group>Passes()` function that calls
/// each of them. The macro is undefined again at the end of the region.
static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
  os << "#ifdef GEN_PASS_REGISTRATION\n";

  // Individual registration functions.
  for (const Pass &currentPass : passes) {
    StringRef passDefName = currentPass.getDef()->getName();
    os << llvm::formatv(passRegistrationCode, passDefName,
                        currentPass.getConstructor());
  }

  // Group registration function: the template opens the function, the loop
  // fills in its body, and the trailing "}" closes it.
  os << llvm::formatv(passGroupRegistrationCode, groupName);
  for (const Pass &currentPass : passes) {
    StringRef passDefName = currentPass.getDef()->getName();
    os << " register" << passDefName << "Pass();\n";
  }
  os << "}\n";

  os << "#undef GEN_PASS_REGISTRATION\n"
     << "#endif // GEN_PASS_REGISTRATION\n";
}
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
/// Entry point of the generator: writes the "autogenerated" banner, collects
/// every record derived from `PassBase`, and emits first the pass base
/// classes and then the registration code for them.
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";

  // Wrap each matching record in a `Pass`.
  const auto &passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
  std::vector<Pass> collectedPasses;
  for (const auto *passDef : passDefs)
    collectedPasses.emplace_back(passDef);

  emitPassDecls(collectedPasses, os);
  emitRegistration(collectedPasses, os);
}
/// Registers this generator under the `gen-pass-decls` flag; the callback
/// forwards to `emitDecls`.
static mlir::GenRegistration genRegister(
    "gen-pass-decls", "Generate pass declarations",
    [](const llvm::RecordKeeper &recordKeeper, raw_ostream &output) {
      emitDecls(recordKeeper, output);
      // NOTE(review): presumably `false` signals success to the tblgen
      // driver -- confirm against mlir/TableGen/GenInfo.h.
      return false;
    });