[spirv] Add support for BitEnumAttr

Certain enum classes in SPIR-V, like function/loop control and memory
access, are bitmasks. This CL introduces a BitEnumAttr to properly
model this and drive auto-generation of verification code and utility
functions. We still store the attribute using an 32-bit IntegerAttr
for minimal memory footprint and easy (de)serialization. But utility
conversion functions are adjusted to inspect each bit and generate
"|"-concatenated strings for the bits; vice versa.

Each such enum class has a "None" case that means no bit is set. We
need special handling for "None". Because of this, the logic is not
general anymore. So right now the definition is placed in the SPIR-V
dialect. If later this turns out to be useful for other dialects,
then we can see how to properly adjust it and move to OpBase.td.

Added tests for SPV_MemoryAccess to check and demonstrate.

PiperOrigin-RevId: 269350620
This commit is contained in:
Lei Zhang 2019-09-16 09:22:43 -07:00 committed by A. Unique TensorFlower
parent cb3ecb5291
commit 6934a337f0
14 changed files with 444 additions and 54 deletions

View file

@ -9,9 +9,14 @@ mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVIntEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVBitEnums.h.inc -gen-spirv-enum-decls)
mlir_tablegen(SPIRVBitEnums.cpp.inc -gen-spirv-enum-defs)
add_public_tablegen_target(MLIRSPIRVBitEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)

View file

@ -185,7 +185,6 @@ def SPV_OpcodeAttr :
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
//===----------------------------------------------------------------------===//
// SPIR-V type definitions
//===----------------------------------------------------------------------===//
@ -231,6 +230,49 @@ class SPV_Optional<Type type> : Variadic<type>;
// TODO(ravishankarm): From 1.4, this should also include Composite type.
def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
//===----------------------------------------------------------------------===//
// SPIR-V BitEnum definition
//===----------------------------------------------------------------------===//
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of the bit that is set. It is the 32-bit integer with only
// one bit set.
class BitEnumAttrCase<string sym, int val> :
EnumAttrCaseInfo<sym, val>,
IntegerAttrBase<I32, "case " # sym> {
let predicate = CPred<
"$_self.cast<IntegerAttr>().getValue().getSExtValue() & " # val # "u">;
}
// A bit enum stored with 32-bit IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer to make sure only allowed bit are set.
class BitEnumAttr<string name, string description,
list<BitEnumAttrCase> cases> :
EnumAttrInfo<name, cases>, IntegerAttrBase<I32, description> {
let predicate = And<[
IntegerAttrBase<I32, "">.predicate,
// Make sure we don't have unknown bit set.
CPred<"!($_self.cast<IntegerAttr>().getValue().getZExtValue() & (~(" #
StrJoin<!foreach(case, cases, case.value # "u"), "|">.result #
")))">
]>;
let underlyingType = "uint32_t";
// We need to return a string because we may concatenate symbols for multiple
// bits together.
let symbolToStringFnRetType = "std::string";
// The string used to separate bit enum cases in strings.
string separator = "|";
// Turn off the autogen with EnumsGen. SPIR-V needs custom logic here and
// we will use our own autogen logic.
let skipAutoGen = 1;
}
//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//
@ -847,14 +889,14 @@ def SPV_ExecutionModelAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>;
def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>;
def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>;
def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>;
def SPV_FC_None : BitEnumAttrCase<"None", 0x0000>;
def SPV_FC_Inline : BitEnumAttrCase<"Inline", 0x0001>;
def SPV_FC_DontInline : BitEnumAttrCase<"DontInline", 0x0002>;
def SPV_FC_Pure : BitEnumAttrCase<"Pure", 0x0004>;
def SPV_FC_Const : BitEnumAttrCase<"Const", 0x0008>;
def SPV_FunctionControlAttr :
I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
]> {
let returnType = "::mlir::spirv::FunctionControl";
@ -932,19 +974,19 @@ def SPV_LinkageTypeAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_LC_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_LC_Unroll : I32EnumAttrCase<"Unroll", 0x0001>;
def SPV_LC_DontUnroll : I32EnumAttrCase<"DontUnroll", 0x0002>;
def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>;
def SPV_LC_DependencyLength : I32EnumAttrCase<"DependencyLength", 0x0008>;
def SPV_LC_MinIterations : I32EnumAttrCase<"MinIterations", 0x0010>;
def SPV_LC_MaxIterations : I32EnumAttrCase<"MaxIterations", 0x0020>;
def SPV_LC_IterationMultiple : I32EnumAttrCase<"IterationMultiple", 0x0040>;
def SPV_LC_PeelCount : I32EnumAttrCase<"PeelCount", 0x0080>;
def SPV_LC_PartialCount : I32EnumAttrCase<"PartialCount", 0x0100>;
def SPV_LC_None : BitEnumAttrCase<"None", 0x0000>;
def SPV_LC_Unroll : BitEnumAttrCase<"Unroll", 0x0001>;
def SPV_LC_DontUnroll : BitEnumAttrCase<"DontUnroll", 0x0002>;
def SPV_LC_DependencyInfinite : BitEnumAttrCase<"DependencyInfinite", 0x0004>;
def SPV_LC_DependencyLength : BitEnumAttrCase<"DependencyLength", 0x0008>;
def SPV_LC_MinIterations : BitEnumAttrCase<"MinIterations", 0x0010>;
def SPV_LC_MaxIterations : BitEnumAttrCase<"MaxIterations", 0x0020>;
def SPV_LC_IterationMultiple : BitEnumAttrCase<"IterationMultiple", 0x0040>;
def SPV_LC_PeelCount : BitEnumAttrCase<"PeelCount", 0x0080>;
def SPV_LC_PartialCount : BitEnumAttrCase<"PartialCount", 0x0100>;
def SPV_LoopControlAttr :
I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount
@ -954,16 +996,16 @@ def SPV_LoopControlAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;
def SPV_MA_Nontemporal : I32EnumAttrCase<"Nontemporal", 0x0004>;
def SPV_MA_MakePointerAvailable : I32EnumAttrCase<"MakePointerAvailable", 0x0008>;
def SPV_MA_MakePointerVisible : I32EnumAttrCase<"MakePointerVisible", 0x0010>;
def SPV_MA_NonPrivatePointer : I32EnumAttrCase<"NonPrivatePointer", 0x0020>;
def SPV_MA_None : BitEnumAttrCase<"None", 0x0000>;
def SPV_MA_Volatile : BitEnumAttrCase<"Volatile", 0x0001>;
def SPV_MA_Aligned : BitEnumAttrCase<"Aligned", 0x0002>;
def SPV_MA_Nontemporal : BitEnumAttrCase<"Nontemporal", 0x0004>;
def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008>;
def SPV_MA_MakePointerVisible : BitEnumAttrCase<"MakePointerVisible", 0x0010>;
def SPV_MA_NonPrivatePointer : BitEnumAttrCase<"NonPrivatePointer", 0x0020>;
def SPV_MemoryAccessAttr :
I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
SPV_MA_NonPrivatePointer

View file

@ -27,7 +27,8 @@
#include "mlir/IR/Types.h"
// Pull in all enum type definitions and utility function declarations
#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.h.inc"
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.h.inc"
#include <tuple>

View file

@ -787,6 +787,10 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
// List of all accepted cases
list<EnumAttrCaseInfo> enumerants = cases;
// Whether to skip automatically generating C++ enum class and utility
// functions for this enum attribute with EnumsGen.
bit skipAutoGen = 0;
// The following fields are only used by the EnumsGen backend to generate
// an enum class definition and conversion utility functions.
@ -824,9 +828,10 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
// corresponding string. It will have the following signature:
//
// ```c++
// llvm::StringRef <fn-name>(<qualified-enum-class-name>);
// <return-type> <fn-name>(<qualified-enum-class-name>);
// ```
string symbolToStringFnName = "stringify" # name;
string symbolToStringFnRetType = "llvm::StringRef";
// The name of the utility function that returns the max enum value used
// within the enum class. It will have the following signature:

View file

@ -151,6 +151,9 @@ public:
explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);
// Returns whether skipping auto-generation is requested.
bool skipAutoGen() const;
// Returns the enum class name.
StringRef getEnumClassName() const;
@ -172,6 +175,10 @@ public:
// corresponding string.
StringRef getSymbolToStringFnName() const;
// Returns the return type of the utility function that converts a symbol to
// the corresponding string.
StringRef getSymbolToStringFnRetType() const;
// Returns the name of the utilit function that returns the max enum value
// used within the enum class.
StringRef getMaxEnumValFnName() const;

View file

@ -11,7 +11,8 @@ add_llvm_library(MLIRSPIRV
add_dependencies(MLIRSPIRV
MLIRSPIRVOpsIncGen
MLIRSPIRVEnumsIncGen
MLIRSPIRVIntEnumsIncGen
MLIRSPIRVBitEnumsIncGen
MLIRSPIRVOpUtilsGen)
target_link_libraries(MLIRSPIRV

View file

@ -141,7 +141,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser,
return failure();
}
if (memoryAccessAttr == spirv::MemoryAccess::Aligned) {
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
Type i32Type = parser->getBuilder().getIntegerType(32);
@ -212,7 +212,7 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
<< memAccessVal;
}
if (*memAccess == spirv::MemoryAccess::Aligned) {
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return loadStoreOp.emitOpError("missing alignment value");
}

View file

@ -21,13 +21,15 @@
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
using namespace mlir::spirv;
// Pull in all enum utility function definitions
#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.cpp.inc"
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.cpp.inc"
//===----------------------------------------------------------------------===//
// ArrayType

View file

@ -170,6 +170,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
: EnumAttr(init->getDef()) {}
bool tblgen::EnumAttr::skipAutoGen() const {
return def->getValueAsBit("skipAutoGen");
}
StringRef tblgen::EnumAttr::getEnumClassName() const {
return def->getValueAsString("className");
}
@ -194,6 +198,10 @@ StringRef tblgen::EnumAttr::getSymbolToStringFnName() const {
return def->getValueAsString("symbolToStringFnName");
}
StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const {
return def->getValueAsString("symbolToStringFnRetType");
}
StringRef tblgen::EnumAttr::getMaxEnumValFnName() const {
return def->getValueAsString("maxEnumValFnName");
}

View file

@ -301,30 +301,84 @@ spv.module "Logical" "GLSL450" {
// spv.LoadOp
//===----------------------------------------------------------------------===//
// CHECK_LABEL: @simple_load
// CHECK-LABEL: @simple_load
func @simple_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load "Function" %0 : f32
// CHECK: spv.Load "Function" %{{.*}} : f32
%1 = spv.Load "Function" %0 : f32
return
}
// CHECK_LABEL: @volatile_load
// CHECK-LABEL: @load_none_access
func @load_none_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load "Function" %{{.*}} ["None"] : f32
%1 = spv.Load "Function" %0 ["None"] : f32
return
}
// CHECK-LABEL: @volatile_load
func @volatile_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load "Function" %0 ["Volatile"] : f32
// CHECK: spv.Load "Function" %{{.*}} ["Volatile"] : f32
%1 = spv.Load "Function" %0 ["Volatile"] : f32
return
}
// CHECK_LABEL: @aligned_load
// CHECK-LABEL: @aligned_load
func @aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load "Function" %0 ["Aligned", 4] : f32
// CHECK: spv.Load "Function" %{{.*}} ["Aligned", 4] : f32
%1 = spv.Load "Function" %0 ["Aligned", 4] : f32
return
}
// CHECK-LABEL: @volatile_aligned_load
func @volatile_aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load "Function" %{{.*}} ["Volatile|Aligned", 4] : f32
%1 = spv.Load "Function" %0 ["Volatile|Aligned", 4] : f32
return
}
// -----
// CHECK-LABEL: load_none_access
func @load_none_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["None"]
%1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
// CHECK-LABEL: volatile_load
func @volatile_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile"]
%1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
// CHECK-LABEL: aligned_load
func @aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Aligned", 4]
%1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
// CHECK-LABEL: volatile_aligned_load
func @volatile_aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile|Aligned", 4]
%1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
// -----
func @simple_load_missing_storageclass() -> () {
@ -408,6 +462,24 @@ func @load_unknown_memory_access() -> () {
// -----
func @load_unknown_memory_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{custom op 'spv.Load' invalid memory_access attribute specification: "Volatile|Something"}}
%1 = spv.Load "Function" %0 ["Volatile|Something"] : f32
return
}
// -----
func @load_unknown_memory_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{failed to satisfy constraint: valid SPIR-V MemoryAccess}}
%1 = "spv.Load"(%0) {memory_access = 0x80000000 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
// -----
func @aligned_load_incorrect_attributes() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{expected ']'}}

View file

@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
#include "EnumsGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
@ -127,9 +128,11 @@ static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName);
os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << " switch (val) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
@ -190,7 +193,8 @@ static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
<< "}\n\n";
}
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
void mlir::tblgen::emitEnumDecl(const Record &enumDef,
ExtraFnEmitter emitExtraFns, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
@ -198,6 +202,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef description = enumAttr.getDescription();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
@ -218,11 +223,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
"llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
underlyingType.empty() ? std::string("unsigned") : underlyingType);
}
os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName);
os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
strToSymFnName);
emitMaxValueFn(enumDef, os);
emitExtraFns(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
@ -234,9 +239,14 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
auto extraFnEmitter = [](const Record &enumDef, raw_ostream &os) {
emitMaxValueFn(enumDef, os);
};
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDecl(*def, os);
if (!EnumAttr(def).skipAutoGen())
mlir::tblgen::emitEnumDecl(*def, extraFnEmitter, os);
return false;
}
@ -265,7 +275,8 @@ static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDef(*def, os);
if (!EnumAttr(def).skipAutoGen())
emitEnumDef(*def, os);
return false;
}

View file

@ -0,0 +1,48 @@
//===- EnumsGen.h - MLIR enum utility generator -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines common utilities for enum generator.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
#define MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
#include "mlir/Support/LLVM.h"
namespace llvm {
class Record;
}
namespace mlir {
namespace tblgen {
using ExtraFnEmitter = llvm::function_ref<void(const llvm::Record &enumDef,
llvm::raw_ostream &os)>;
// Emits declarations for the given EnumAttr `enumDef` into `os`.
//
// This will emit a C++ enum class and string to symbol and symbol to string
// conversion utility declarations. Additional functions can be emitted via
// the `emitExtraFns` function.
void emitEnumDecl(const llvm::Record &enumDef, ExtraFnEmitter emitExtraFns,
llvm::raw_ostream &os);
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_

View file

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
#include "EnumsGen.h"
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
@ -49,6 +50,10 @@ using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
// Writes the following function to `os`:
// inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
static void emitGetOpcodeFunction(const Record *record, Operator const &op,
@ -397,6 +402,10 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
return false;
}
//===----------------------------------------------------------------------===//
// Op Utils AutoGen
//===----------------------------------------------------------------------===//
static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
os << formatv("template <typename EnumClass> inline constexpr StringRef "
"attributeName();\n");
@ -435,7 +444,7 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr,
static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Op Utilites", os);
auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr");
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
os << "#ifndef SPIRV_OP_UTILS_H_\n";
os << "#define SPIRV_OP_UTILS_H_\n";
emitEnumGetAttrNameFnDecl(os);
@ -449,7 +458,168 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
return false;
}
// Registers the enum utility generator to mlir-tblgen.
//===----------------------------------------------------------------------===//
// BitEnum AutoGen
//===----------------------------------------------------------------------===//
// Emits the following inline function for bit enums:
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv(
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
enumName, underlyingType)
<< "}\n";
}
static bool emitBitEnumDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("BitEnum Utility Declarations", os);
auto operatorsEmitter = [](const Record &enumDef, llvm::raw_ostream &os) {
return emitOperators(enumDef, os);
};
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
for (const auto *def : defs)
mlir::tblgen::emitEnumDecl(*def, operatorsEmitter, os);
return false;
}
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
enumAttr.getUnderlyingType());
os << " // Special case for all bits unset.\n";
os << " if (val == 0) return \"None\";\n\n";
os << " SmallVector<llvm::StringRef, 2> strs;\n";
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
"val &= ~{0}u; }\n",
val, enumerant.getSymbol());
}
// If we have unknown bit set, return an empty string to signal errors.
os << "\n if (val) return \"\";\n";
os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
os << "}\n\n";
}
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
os << formatv(" if (str == \"None\") return {0}::None;\n\n", enumName);
// Split the string to get symbols for all the bits.
os << " SmallVector<llvm::StringRef, 2> symbols;\n";
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
os << formatv(
" auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
enumerant.getValue());
}
os.indent(6) << ".Default(llvm::None);\n";
os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
os << " }\n";
os << formatv(" return static_cast<{0}>(val);\n", enumName);
os << "}\n\n";
}
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName, underlyingType);
os << formatv(" if (value == 0) return {0}::None;\n", enumName);
llvm::SmallVector<std::string, 8> values;
for (const auto &enumerant : enumerants) {
if (auto val = enumerant.getValue())
values.push_back(formatv("{0}u", val));
}
os << formatv(" if (value & ~({0})) return llvm::None;\n",
llvm::join(values, " | "));
os << formatv(" return static_cast<{0}>(value);\n", enumName);
os << "}\n";
}
static void emitBitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef cppNamespace = enumAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
emitSymToStrFnForBitEnum(enumDef, os);
emitStrToSymFnForBitEnum(enumDef, os);
emitUnderlyingToSymFnForBitEnum(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
os << "\n";
}
static bool emitBitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("BitEnum Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
for (const auto *def : defs)
emitBitEnumDef(*def, os);
return false;
}
//===----------------------------------------------------------------------===//
// Hook Registration
//===----------------------------------------------------------------------===//
static mlir::GenRegistration genSerialization(
"gen-spirv-serialization",
"Generate SPIR-V (de)serialization utilities and functions",
@ -463,3 +633,17 @@ static mlir::GenRegistration
[](const RecordKeeper &records, raw_ostream &os) {
return emitOpUtils(records, os);
});
static mlir::GenRegistration
genEnumDecls("gen-spirv-enum-decls",
"Generate SPIR-V bit enum utility declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBitEnumDecls(records, os);
});
static mlir::GenRegistration
genEnumDefs("gen-spirv-enum-defs",
"Generate SPIR-V bit enum utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBitEnumDefs(records, os);
});

View file

@ -132,16 +132,18 @@ def uniquify(lst, equality_fn):
def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen I32EnumAttr definition for the given operand kind.
"""Generates the TableGen EnumAttr definition for the given operand kind.
Returns:
- The operand kind's name
- A string containing the TableGen I32EnumAttr definition
- A string containing the TableGen EnumAttr definition
"""
if 'enumerants' not in operand_kind:
return '', ''
kind_name = operand_kind['kind']
is_bit_enum = operand_kind['category'] == 'BitEnum'
kind_category = 'Bit' if is_bit_enum else 'I32'
kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
kind_cases = [(case['enumerant'], case['value'])
for case in operand_kind['enumerants']]
@ -150,9 +152,10 @@ def gen_operand_kind_enum_attr(operand_kind):
# Generate the definition for each enum case
fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
'I32EnumAttrCase<"{symbol}", {value}>;'
'{category}EnumAttrCase<"{symbol}", {value}>;'
case_defs = [
fmt_str.format(
category=kind_category,
acronym=kind_acronym,
symbol=case[0],
value=case[1],
@ -174,12 +177,13 @@ def gen_operand_kind_enum_attr(operand_kind):
# Generate the enum attribute definition
enum_attr = 'def SPV_{name}Attr :\n '\
'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\
'{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\
' ]> {{\n'\
' let returnType = "::mlir::spirv::{name}";\n'\
' let convertFromStorage = '\
'"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
' let cppNamespace = "::mlir::spirv";\n}}'.format(
name=kind_name, cases=case_names)
name=kind_name, category=kind_category, cases=case_names)
return kind_name, case_defs + '\n\n' + enum_attr