Add support for (de)serialization of SPIR-V Op Decorations

All non-argument attributes specified for an operation are treated as
decorations on the result value and (de)serialized using OpDecorate
instruction. An error is generated if an attribute is not an argument,
and the name doesn't correspond to a Decoration enum. Name of the
attributes that represent decoerations are to be the snake-case-ified
version of the Decoration name.
Add utility methods to convert to snake-case and camel-case.

PiperOrigin-RevId: 260792638
This commit is contained in:
Mahesh Ravishankar 2019-07-30 14:14:28 -07:00 committed by A. Unique TensorFlower
parent 3b207d3691
commit 1de519a753
10 changed files with 427 additions and 50 deletions

View file

@ -196,6 +196,97 @@ def SPV_AddressingModelAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>;
def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>;
def SPV_D_Block : I32EnumAttrCase<"Block", 2>;
def SPV_D_BufferBlock : I32EnumAttrCase<"BufferBlock", 3>;
def SPV_D_RowMajor : I32EnumAttrCase<"RowMajor", 4>;
def SPV_D_ColMajor : I32EnumAttrCase<"ColMajor", 5>;
def SPV_D_ArrayStride : I32EnumAttrCase<"ArrayStride", 6>;
def SPV_D_MatrixStride : I32EnumAttrCase<"MatrixStride", 7>;
def SPV_D_GLSLShared : I32EnumAttrCase<"GLSLShared", 8>;
def SPV_D_GLSLPacked : I32EnumAttrCase<"GLSLPacked", 9>;
def SPV_D_CPacked : I32EnumAttrCase<"CPacked", 10>;
def SPV_D_BuiltIn : I32EnumAttrCase<"BuiltIn", 11>;
def SPV_D_NoPerspective : I32EnumAttrCase<"NoPerspective", 13>;
def SPV_D_Flat : I32EnumAttrCase<"Flat", 14>;
def SPV_D_Patch : I32EnumAttrCase<"Patch", 15>;
def SPV_D_Centroid : I32EnumAttrCase<"Centroid", 16>;
def SPV_D_Sample : I32EnumAttrCase<"Sample", 17>;
def SPV_D_Invariant : I32EnumAttrCase<"Invariant", 18>;
def SPV_D_Restrict : I32EnumAttrCase<"Restrict", 19>;
def SPV_D_Aliased : I32EnumAttrCase<"Aliased", 20>;
def SPV_D_Volatile : I32EnumAttrCase<"Volatile", 21>;
def SPV_D_Constant : I32EnumAttrCase<"Constant", 22>;
def SPV_D_Coherent : I32EnumAttrCase<"Coherent", 23>;
def SPV_D_NonWritable : I32EnumAttrCase<"NonWritable", 24>;
def SPV_D_NonReadable : I32EnumAttrCase<"NonReadable", 25>;
def SPV_D_Uniform : I32EnumAttrCase<"Uniform", 26>;
def SPV_D_UniformId : I32EnumAttrCase<"UniformId", 27>;
def SPV_D_SaturatedConversion : I32EnumAttrCase<"SaturatedConversion", 28>;
def SPV_D_Stream : I32EnumAttrCase<"Stream", 29>;
def SPV_D_Location : I32EnumAttrCase<"Location", 30>;
def SPV_D_Component : I32EnumAttrCase<"Component", 31>;
def SPV_D_Index : I32EnumAttrCase<"Index", 32>;
def SPV_D_Binding : I32EnumAttrCase<"Binding", 33>;
def SPV_D_DescriptorSet : I32EnumAttrCase<"DescriptorSet", 34>;
def SPV_D_Offset : I32EnumAttrCase<"Offset", 35>;
def SPV_D_XfbBuffer : I32EnumAttrCase<"XfbBuffer", 36>;
def SPV_D_XfbStride : I32EnumAttrCase<"XfbStride", 37>;
def SPV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr", 38>;
def SPV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>;
def SPV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40>;
def SPV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41>;
def SPV_D_NoContraction : I32EnumAttrCase<"NoContraction", 42>;
def SPV_D_InputAttachmentIndex : I32EnumAttrCase<"InputAttachmentIndex", 43>;
def SPV_D_Alignment : I32EnumAttrCase<"Alignment", 44>;
def SPV_D_MaxByteOffset : I32EnumAttrCase<"MaxByteOffset", 45>;
def SPV_D_AlignmentId : I32EnumAttrCase<"AlignmentId", 46>;
def SPV_D_MaxByteOffsetId : I32EnumAttrCase<"MaxByteOffsetId", 47>;
def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469>;
def SPV_D_NoUnsignedWrap : I32EnumAttrCase<"NoUnsignedWrap", 4470>;
def SPV_D_ExplicitInterpAMD : I32EnumAttrCase<"ExplicitInterpAMD", 4999>;
def SPV_D_OverrideCoverageNV : I32EnumAttrCase<"OverrideCoverageNV", 5248>;
def SPV_D_PassthroughNV : I32EnumAttrCase<"PassthroughNV", 5250>;
def SPV_D_ViewportRelativeNV : I32EnumAttrCase<"ViewportRelativeNV", 5252>;
def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>;
def SPV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271>;
def SPV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272>;
def SPV_D_PerTaskNV : I32EnumAttrCase<"PerTaskNV", 5273>;
def SPV_D_PerVertexNV : I32EnumAttrCase<"PerVertexNV", 5285>;
def SPV_D_NonUniformEXT : I32EnumAttrCase<"NonUniformEXT", 5300>;
def SPV_D_RestrictPointerEXT : I32EnumAttrCase<"RestrictPointerEXT", 5355>;
def SPV_D_AliasedPointerEXT : I32EnumAttrCase<"AliasedPointerEXT", 5356>;
def SPV_D_CounterBuffer : I32EnumAttrCase<"CounterBuffer", 5634>;
def SPV_D_UserSemantic : I32EnumAttrCase<"UserSemantic", 5635>;
def SPV_D_UserTypeGOOGLE : I32EnumAttrCase<"UserTypeGOOGLE", 5636>;
def SPV_DecorationAttr :
I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample,
SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant,
SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform,
SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location,
SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset,
SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode,
SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction,
SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset,
SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap,
SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV,
SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV,
SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV,
SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT,
SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer,
SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE
]> {
let returnType = "::mlir::spirv::Decoration";
let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_D_1D : I32EnumAttrCase<"1D", 0>;
def SPV_D_2D : I32EnumAttrCase<"2D", 1>;
def SPV_D_3D : I32EnumAttrCase<"3D", 2>;

View file

@ -0,0 +1,81 @@
//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains string utility functions used within MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_SUPPORT_STRINGEXTRAS_H
#define MLIR_SUPPORT_STRINGEXTRAS_H
#include "llvm/ADT/StringExtras.h"
namespace mlir {
/// Converts a string to snake-case from camel-case by replacing all uppercase
/// letters with '_' followed by the letter in lowercase, except if the
/// uppercase letter is the first character of the string.
inline std::string convertToSnakeCase(llvm::StringRef input) {
std::string snakeCase;
snakeCase.reserve(input.size());
for (auto c : input) {
if (std::isupper(c)) {
if (!snakeCase.empty() && snakeCase.back() != '_') {
snakeCase.push_back('_');
}
snakeCase.push_back(llvm::toLower(c));
} else {
snakeCase.push_back(c);
}
}
return snakeCase;
}
/// Converts a string from camel-case to snake_case by replacing all occurences
/// of '_' followed by a lowercase letter with the letter in
/// uppercase. Optionally allow capitalization of the first letter (if it is a
/// lowercase letter)
inline std::string convertToCamelCase(llvm::StringRef input,
bool capitalizeFirst = false) {
if (input.empty()) {
return "";
}
std::string output;
output.reserve(input.size());
size_t pos = 0;
if (capitalizeFirst && std::islower(input[pos])) {
output.push_back(llvm::toUpper(input[pos]));
pos++;
}
while (pos < input.size()) {
auto cur = input[pos];
if (cur == '_') {
if (pos && (pos + 1 < input.size())) {
if (std::islower(input[pos + 1])) {
output.push_back(llvm::toUpper(input[pos + 1]));
pos += 2;
continue;
}
}
}
output.push_back(cur);
pos++;
}
return output;
}
} // namespace mlir
#endif // MLIR_SUPPORT_STRINGEXTRAS_H

View file

@ -26,13 +26,12 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/StringExtras.h"
using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBindingAttrName[] = "binding";
static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
@ -67,8 +66,7 @@ static LogicalResult extractValueFromConstOp(Operation *op,
}
template <typename EnumClass>
static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
OperationState *state) {
static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) {
Attribute attrVal;
SmallVector<NamedAttribute, 1> attr;
auto loc = parser->getCurrentLocation();
@ -89,6 +87,15 @@ static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
<< " attribute specification: " << attrVal;
}
value = attrOptional.getValue();
return success();
}
template <typename EnumClass>
static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
OperationState *state) {
if (parseEnumAttribute(value, parser)) {
return failure();
}
state->addAttribute(
spirv::attributeName<EnumClass>(),
parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value)));
@ -601,7 +608,7 @@ static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) {
spirv::StorageClass storageClass;
OpAsmParser::OperandType ptrInfo;
Type elementType;
if (parseEnumAttribute(storageClass, parser, state) ||
if (parseEnumAttribute(storageClass, parser) ||
parser->parseOperand(ptrInfo) ||
parseMemoryAccessAttributes(parser, state) ||
parser->parseOptionalAttributeDict(state->attributes) ||
@ -813,7 +820,7 @@ static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) {
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
auto loc = parser->getCurrentLocation();
Type elementType;
if (parseEnumAttribute(storageClass, parser, state) ||
if (parseEnumAttribute(storageClass, parser) ||
parser->parseOperandList(operandInfo, 2) ||
parseMemoryAccessAttributes(parser, state) || parser->parseColon() ||
parser->parseType(elementType)) {
@ -873,13 +880,17 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
// Parse optional descriptor binding
Attribute set, binding;
auto descriptorSetName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
parser->parseAttribute(set, i32Type, kDescriptorSetAttrName,
parser->parseAttribute(set, i32Type, descriptorSetName,
state->attributes) ||
parser->parseComma() ||
parser->parseAttribute(binding, i32Type, kBindingAttrName,
parser->parseAttribute(binding, i32Type, bindingName,
state->attributes) ||
parser->parseRParen())
return failure();
@ -931,12 +942,17 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
}
// Print optional descriptor binding
auto set = varOp.getAttrOfType<IntegerAttr>(kDescriptorSetAttrName);
auto binding = varOp.getAttrOfType<IntegerAttr>(kBindingAttrName);
if (set && binding) {
elidedAttrs.push_back(kDescriptorSetAttrName);
elidedAttrs.push_back(kBindingAttrName);
*printer << " bind(" << set.getInt() << ", " << binding.getInt() << ")";
auto descriptorSetName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
if (descriptorSet && binding) {
elidedAttrs.push_back(descriptorSetName);
elidedAttrs.push_back(bindingName);
*printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
<< ")";
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);

View file

@ -27,6 +27,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
@ -80,6 +81,9 @@ private:
/// Process SPIR-V OpName with `operands`
LogicalResult processName(ArrayRef<uint32_t> operands);
/// Method to process an OpDecorate instruction.
LogicalResult processDecoration(ArrayRef<uint32_t> words);
/// Processes the SPIR-V function at the current `offset` into `binary`.
/// The operands to the OpFunction instruction is passed in as ``operands`.
/// This method processes each instruction inside the function and dispatches
@ -196,6 +200,9 @@ private:
// Result <id> to name mapping.
DenseMap<uint32_t, StringRef> nameMap;
// Result <id> to decorations mapping.
DenseMap<uint32_t, NamedAttributeList> decorations;
// List of instructions that are processed in a defered fashion (after an
// initial processing of the entire binary). Some operations like
// OpEntryPoint, and OpExecutionMode use forward references to function
@ -285,6 +292,37 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// TODO : This function should also be auto-generated. For now, since only a
// few decorations are processed/handled in a meaningful manner, going with a
// manual implementation.
if (words.size() < 2) {
return emitError(
unknownLoc, "OpDecorate must have at least result <id> and Decoration");
}
auto decorationName =
stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
if (decorationName.empty()) {
return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
}
auto attrName = convertToSnakeCase(decorationName);
switch (static_cast<spirv::Decoration>(words[1])) {
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
return emitError(unknownLoc, "OpDecorate with ")
<< decorationName << " needs a single integer literal";
}
decorations[words[0]].set(
opBuilder.getIdentifier(attrName),
opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
return success();
}
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
// Get the result type
if (operands.size() != 4) {
@ -830,6 +868,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processConstantBool(false, operands);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
default:
@ -839,6 +879,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
}
namespace {
template <>
LogicalResult
Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {

View file

@ -27,6 +27,7 @@
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/raw_ostream.h"
@ -127,6 +128,10 @@ private:
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
/// Process attributes that translate to decorations on the result <id>
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
@ -319,6 +324,34 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
return failure();
}
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.first.strref();
auto decorationName = mlir::convertToCamelCase(attrName, true);
auto decoration = spirv::symbolizeDecoration(decorationName);
if (!decoration) {
return emitError(
loc, "non-argument attributes expected to have snake-case-ified "
"decoration name, unhandled attribute with name : ")
<< attrName;
}
SmallVector<uint32_t, 1> args;
args.push_back(resultID);
args.push_back(static_cast<uint32_t>(decoration.getValue()));
switch (decoration.getValue()) {
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
args.push_back(intAttr.getValue().getZExtValue());
break;
}
return emitError(loc, "expected integer attribute for ") << attrName;
default:
return emitError(loc, "unhandled decoration ") << decorationName;
}
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
LogicalResult Serializer::processFuncOp(FuncOp op) {
uint32_t fnTypeID = 0;
// Generate type of the function.

View file

@ -1,11 +1,11 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
// CHECK: {{%.*}} = spv.Variable : !spv.ptr<f32, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable : !spv.ptr<f32, Output>
// CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr<f32, Output>
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
%2 = spv.Variable : !spv.ptr<f32, Input>
%3 = spv.Variable : !spv.ptr<f32, Output>
%2 = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
%3 = spv.Variable bind(0, 1): !spv.ptr<f32, Output>
}
return
}

View file

@ -20,9 +20,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@ -39,6 +41,8 @@ using llvm::raw_string_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::SMLoc;
using llvm::StringRef;
using llvm::Twine;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::NamedAttribute;
@ -90,7 +94,8 @@ static void emitAttributeSerialization(const Attribute &attr,
os << " }\n";
}
static void emitSerializationFunction(const Record *record, const Operator &op,
static void emitSerializationFunction(const Record *attrClass,
const Record *record, const Operator &op,
raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
@ -101,21 +106,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
op.getQualCppClassName())
<< " {\n";
os << " SmallVector<uint32_t, 4> operands;\n";
os << " SmallVector<StringRef, 2> elidedAttrs;\n";
// Serialize result information
if (op.getNumResults() == 1) {
os << " {\n";
os << " uint32_t typeID = 0;\n";
os << " if (failed(processType(op.getLoc(), "
"op.getResult()->getType(), typeID))) {\n";
os << " return failure();\n";
os << " }\n";
os << " operands.push_back(typeID);\n";
/// Create an SSA result <id> for the op
os << " auto resultID = getNextID();\n";
os << " valueIDMap[op.getResult()] = resultID;\n";
os << " operands.push_back(resultID);\n";
os << " uint32_t resultTypeID = 0;\n";
os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
"{\n";
os << " return failure();\n";
os << " }\n";
os << " operands.push_back(resultTypeID);\n";
// Create an SSA result <id> for the op
os << " auto resultID = getNextID();\n";
os << " valueIDMap[op.getResult()] = resultID;\n";
os << " operands.push_back(resultID);\n";
} else if (op.getNumResults() != 0) {
PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
}
@ -140,6 +144,7 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
emitAttributeSerialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
record->getLoc(), "op", "operands", attr->name, os);
os << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
}
os << " }\n";
}
@ -147,6 +152,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
os << formatv(" encodeInstructionInto("
"functions, spirv::getOpcode<{0}>(), operands);\n",
op.getQualCppClassName());
if (op.getNumResults() == 1) {
// All non-argument attributes translated into OpDecorate instruction
os << " for (auto attr : op.getAttrs()) {\n";
os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
"attr.first.is(elided); })) {\n";
os << " continue;\n";
os << " }\n";
os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
os << " return failure();";
os << " }\n";
os << " }\n";
}
os << " return success();\n";
os << "}\n\n";
}
@ -196,7 +215,8 @@ static void emitAttributeDeserialization(
}
}
static void emitDeserializationFunction(const Record *record,
static void emitDeserializationFunction(const Record *attrClass,
const Record *record,
const Operator &op, raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
@ -292,8 +312,19 @@ static void emitDeserializationFunction(const Record *record,
"operands, attributes); (void)op;\n",
op.getQualCppClassName());
if (hasResult) {
os << " valueMap[valueID] = op.getResult();\n";
os << " valueMap[valueID] = op.getResult();\n\n";
}
// Import decorations parsed
if (op.getNumResults() == 1) {
os << " if (decorations.count(valueID)) {\n";
os << " auto decorationAttrs = decorations[valueID];\n";
os << " for (auto attr : decorationAttrs.getAttrs()) {\n";
os << " op.setAttr(attr.first, attr.second);\n";
os << " }\n";
os << " }\n";
}
os << " return success();\n";
os << "}\n\n";
}
@ -330,6 +361,7 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
utilsString;
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
auto attrClass = recordKeeper.getClass("Attr");
declareOpcodeFn(utils);
initDispatchSerializationFn(dSerFn);
@ -341,9 +373,9 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
}
Operator op(def);
emitGetOpcodeFunction(def, op, utils);
emitSerializationFunction(def, op, serFn);
emitSerializationFunction(attrClass, def, op, serFn);
emitSerializationDispatch(op, dSerFn);
emitDeserializationFunction(def, op, deserFn);
emitDeserializationFunction(attrClass, def, op, deserFn);
emitDeserializationDispatch(op, def, dDesFn);
}
finalizeDispatchSerializationFn(dSerFn);
@ -378,21 +410,6 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
"SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
}
std::string convertSnakeCase(llvm::StringRef inputString) {
std::string snakeCase;
for (auto c : inputString) {
if (c >= 'A' && c <= 'Z') {
if (!snakeCase.empty()) {
snakeCase.push_back('_');
}
snakeCase.push_back((c - 'A') + 'a');
} else {
snakeCase.push_back(c);
}
}
return snakeCase;
}
static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
raw_ostream &os) {
auto enumName = enumAttr.getEnumClassName();
@ -400,7 +417,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
<< " {\n";
os << " "
<< formatv("static constexpr const char attrName[] = \"{0}\";\n",
convertSnakeCase(enumName));
mlir::convertToSnakeCase(enumName));
os << " return attrName;\n";
os << "}\n";
}

View file

@ -2,6 +2,7 @@ add_mlir_unittest(MLIRIRTests
AttributeTest.cpp
DialectTest.cpp
OperationSupportTest.cpp
StringExtrasTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE

View file

@ -0,0 +1,74 @@
//===- StringExtras.cpp - Tests for utility methods in StringExtras.h -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Support/StringExtras.h"
#include "gtest/gtest.h"
using namespace mlir;
static void testConvertToSnakeCase(llvm::StringRef input,
llvm::StringRef expected) {
EXPECT_EQ(convertToSnakeCase(input), expected.str());
}
TEST(StringExtras, ConvertToSnakeCase) {
testConvertToSnakeCase("OpName", "op_name");
testConvertToSnakeCase("opName", "op_name");
testConvertToSnakeCase("_OpName", "_op_name");
testConvertToSnakeCase("Op_Name", "op_name");
testConvertToSnakeCase("", "");
testConvertToSnakeCase("A", "a");
testConvertToSnakeCase("_", "_");
testConvertToSnakeCase("a", "a");
testConvertToSnakeCase("op_name", "op_name");
testConvertToSnakeCase("_op_name", "_op_name");
testConvertToSnakeCase("__op_name", "__op_name");
testConvertToSnakeCase("op__name", "op__name");
}
template <bool capitalizeFirst>
static void testConvertToCamelCase(llvm::StringRef input,
llvm::StringRef expected) {
EXPECT_EQ(convertToCamelCase(input, capitalizeFirst), expected.str());
}
TEST(StringExtras, ConvertToCamelCase) {
testConvertToCamelCase<false>("op_name", "opName");
testConvertToCamelCase<false>("_op_name", "_opName");
testConvertToCamelCase<false>("__op_name", "_OpName");
testConvertToCamelCase<false>("op__name", "op_Name");
testConvertToCamelCase<false>("", "");
testConvertToCamelCase<false>("A", "A");
testConvertToCamelCase<false>("_", "_");
testConvertToCamelCase<false>("a", "a");
testConvertToCamelCase<false>("OpName", "OpName");
testConvertToCamelCase<false>("opName", "opName");
testConvertToCamelCase<false>("_OpName", "_OpName");
testConvertToCamelCase<false>("Op_Name", "Op_Name");
testConvertToCamelCase<true>("op_name", "OpName");
testConvertToCamelCase<true>("_op_name", "_opName");
testConvertToCamelCase<true>("__op_name", "_OpName");
testConvertToCamelCase<true>("op__name", "Op_Name");
testConvertToCamelCase<true>("", "");
testConvertToCamelCase<true>("A", "A");
testConvertToCamelCase<true>("_", "_");
testConvertToCamelCase<true>("a", "A");
testConvertToCamelCase<true>("OpName", "OpName");
testConvertToCamelCase<true>("_OpName", "_OpName");
testConvertToCamelCase<true>("Op_Name", "Op_Name");
testConvertToCamelCase<true>("opName", "OpName");
}

View file

@ -109,6 +109,28 @@ def split_list_into_sublists(items, offset):
return chuncks
def uniquify(lst, equality_fn):
"""Returns a list after pruning duplicate elements.
Arguments:
- lst: List whose elements are to be uniqued.
- equality_fn: Function used to compare equality between elements of the
list.
Returns:
- A list with all duplicated removed. The order of elements is same as the
original list, with only the first occurence of duplicates retained.
"""
keys = set()
unique_lst = []
for elem in lst:
key = equality_fn(elem)
if equality_fn(key) not in keys:
unique_lst.append(elem)
keys.add(key)
return unique_lst
def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen I32EnumAttr definition for the given operand kind.
@ -123,6 +145,7 @@ def gen_operand_kind_enum_attr(operand_kind):
kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
kind_cases = [(case['enumerant'], case['value'])
for case in operand_kind['enumerants']]
kind_cases = uniquify(kind_cases, lambda x: x[1])
max_len = max([len(symbol) for (symbol, _) in kind_cases])
# Generate the definition for each enum case