Add serialization and deserialization of FuncOps. To support this the

following SPIRV Instructions serializaiton/deserialization are added
as well

OpFunction
OpFunctionParameter
OpFunctionEnd
OpReturn

PiperOrigin-RevId: 257869806
This commit is contained in:
Mahesh Ravishankar 2019-07-12 14:32:20 -07:00 committed by Mehdi Amini
parent 63bc37c9c0
commit 9af156757d
6 changed files with 326 additions and 73 deletions

View file

@ -72,23 +72,29 @@ class SPV_OpCode<string name, int val> {
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpFMul, SPV_OC_OpReturn
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@ -294,6 +300,21 @@ def SPV_ExecutionModelAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>;
def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>;
def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>;
def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>;
def SPV_FunctionControlAttr :
I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
]> {
let returnType = "::mlir::spirv::FunctionControl";
let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1>;
def SPV_IF_Rgba16f : I32EnumAttrCase<"Rgba16f", 2>;
@ -352,6 +373,18 @@ def SPV_ImageFormatAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_LT_Export : I32EnumAttrCase<"Export", 0>;
def SPV_LT_Import : I32EnumAttrCase<"Import", 1>;
def SPV_LinkageTypeAttr :
I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
SPV_LT_Export, SPV_LT_Import
]> {
let returnType = "::mlir::spirv::LinkageType";
let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;

View file

@ -66,9 +66,20 @@ private:
/// Get type for a given result <id>
Type getType(uint32_t id) { return typeMap.lookup(id); }
/// Get Value associated with a result <id>
Value *getValue(uint32_t id) { return valueMap.lookup(id); }
// Check if a type is void
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
/// Processes SPIR-V module header.
LogicalResult processHeader();
/// Deserialize a single instruction. The |opcode| and |operands| are returned
/// after deserialization to the caller.
LogicalResult deserializeInstruction(spirv::Opcode &opcode,
ArrayRef<uint32_t> &operands);
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
LogicalResult processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands);
@ -77,6 +88,13 @@ private:
LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
/// Process SPIR-V instructions that dont have any operands
template <typename OpTy>
LogicalResult processNullaryInstruction(ArrayRef<uint32_t> operands);
/// Process function objects in binary
LogicalResult processFunction(ArrayRef<uint32_t> operands);
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
/// Initializes the `module` ModuleOp in this deserializer instance.
@ -102,6 +120,12 @@ private:
// result <id> to type mapping
DenseMap<uint32_t, Type> typeMap;
// result <id> to function mapping
DenseMap<uint32_t, Operation *> funcMap;
// result <id> to value mapping
DenseMap<uint32_t, Value *> valueMap;
};
} // namespace
@ -114,30 +138,11 @@ LogicalResult Deserializer::deserialize() {
if (failed(processHeader()))
return failure();
auto binarySize = binary.size();
curOffset = spirv::kHeaderWordCount;
while (curOffset < binarySize) {
// For each instruction, get its word count from the first word to slice it
// from the stream properly, and then dispatch to the instruction handler.
uint32_t wordCount = binary[curOffset] >> 16;
uint32_t opcode = binary[curOffset] & 0xffff;
if (wordCount == 0)
return emitError(unknownLoc, "word count cannot be zero");
uint32_t nextOffset = curOffset + wordCount;
if (nextOffset > binarySize)
return emitError(unknownLoc,
"insufficient words for the last instruction");
auto operands = binary.slice(curOffset + 1, wordCount - 1);
if (failed(
processInstruction(static_cast<spirv::Opcode>(opcode), operands)))
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
while (succeeded(deserializeInstruction(opcode, operands))) {
if (failed(processInstruction(opcode, operands)))
return failure();
curOffset = nextOffset;
}
return success();
@ -154,6 +159,32 @@ LogicalResult Deserializer::processHeader() {
return emitError(unknownLoc, "incorrect magic number");
// TODO(antiagainst): generator number, bound, schema
curOffset = spirv::kHeaderWordCount;
return success();
}
LogicalResult
Deserializer::deserializeInstruction(spirv::Opcode &opcode,
ArrayRef<uint32_t> &operands) {
auto binarySize = binary.size();
if (curOffset >= binarySize) {
return failure();
}
// For each instruction, get its word count from the first word to slice it
// from the stream properly, and then dispatch to the instruction handler.
uint32_t wordCount = binary[curOffset] >> 16;
opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
if (wordCount == 0)
return emitError(unknownLoc, "word count cannot be zero");
uint32_t nextOffset = curOffset + wordCount;
if (nextOffset > binarySize)
return emitError(unknownLoc, "insufficient words for the last instruction");
operands = binary.slice(curOffset + 1, wordCount - 1);
curOffset = nextOffset;
return success();
}
@ -174,7 +205,11 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
}
argTypes.push_back(ty);
}
typeMap[operands[0]] = FunctionType::get(argTypes, {returnType}, context);
ArrayRef<Type> returnTypes;
if (!isVoidType(returnType)) {
returnTypes = llvm::makeArrayRef(returnType);
}
typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
return success();
}
@ -205,6 +240,118 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return success();
}
template <typename OpTy>
LogicalResult
Deserializer::processNullaryInstruction(ArrayRef<uint32_t> operands) {
if (!operands.empty()) {
return emitError(unknownLoc) << stringifyOpcode(spirv::getOpcode<OpTy>())
<< " must have no operands, but found "
<< operands.size() << " operands";
}
opBuilder.create<OpTy>(unknownLoc);
return success();
}
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
// Get the result type
if (operands.size() != 4) {
return emitError(unknownLoc, "OpFunction must have 4 parameters");
}
Type resultType = getType(operands[0]);
if (!resultType) {
return emitError(unknownLoc, "unknown result type from <id> ")
<< operands[0];
}
if (funcMap.count(operands[1])) {
return emitError(unknownLoc, "duplicate function definition/declaration");
}
auto functionControl = spirv::symbolizeFunctionControl(operands[2]);
if (!functionControl) {
return emitError(unknownLoc, "unknown Function Control : ") << operands[2];
}
if (functionControl.getValue() != spirv::FunctionControl::None) {
/// TODO : Handle different function controls
return emitError(unknownLoc, "unhandled Function Control : '")
<< spirv::stringifyFunctionControl(functionControl.getValue())
<< "'";
}
Type fnType = getType(operands[3]);
if (!fnType || !fnType.isa<FunctionType>()) {
return emitError(unknownLoc, "unknown function type from <id> ")
<< operands[3];
}
auto functionType = fnType.cast<FunctionType>();
if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
(functionType.getNumResults() == 1 &&
functionType.getResult(0) != resultType)) {
return emitError(unknownLoc, "mismatch in function type ")
<< functionType << " and return type " << resultType << " specified";
}
/// TODO : The function name must be obtained from OpName eventually
std::string fnName = "spirv_fn_" + std::to_string(operands[2]);
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
ArrayRef<NamedAttribute>());
funcOp.addEntryBlock();
// Parse the op argument instructions
if (functionType.getNumInputs()) {
for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
auto argType = functionType.getInput(i);
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
if (failed(deserializeInstruction(opcode, operands))) {
return failure();
}
if (opcode != spirv::Opcode::OpFunctionParameter) {
return emitError(
unknownLoc,
"missing OpFunctionParameter instruction for argument ")
<< i;
}
if (operands.size() != 2) {
return emitError(
unknownLoc,
"expected result type and result <id> for OpFunctionParameter");
}
auto argDefinedType = getType(operands[0]);
if (argDefinedType || argDefinedType != argType) {
return emitError(unknownLoc,
"mismatch in argument type between function type "
"definition ")
<< functionType << " and argument type definition "
<< argDefinedType << " at argument " << i;
}
if (getValue(operands[1])) {
return emitError(unknownLoc, "duplicate definition of result <id> '")
<< operands[1];
}
auto argValue = funcOp.getArgument(i);
valueMap[operands[1]] = argValue;
}
}
// Create a new builder for building the body
OpBuilder funcBody(funcOp.getBody());
std::swap(funcBody, opBuilder);
spirv::Opcode opcode;
ArrayRef<uint32_t> instOperands;
while (succeeded(deserializeInstruction(opcode, instOperands)) &&
opcode != spirv::Opcode::OpFunctionEnd) {
if (failed(processInstruction(opcode, instOperands))) {
return failure();
}
}
std::swap(funcBody, opBuilder);
if (opcode != spirv::Opcode::OpFunctionEnd) {
return failure();
}
if (!instOperands.empty()) {
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
}
return success();
}
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands) {
switch (opcode) {
@ -213,6 +360,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeVoid:
case spirv::Opcode::OpTypeFunction:
return processType(opcode, operands);
case spirv::Opcode::OpReturn:
return processNullaryInstruction<spirv::ReturnOp>(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
default:
break;
}

View file

@ -22,6 +22,7 @@
#include "mlir/SPIRV/Serialization.h"
#include "SPIRVBinaryUtils.h"
#include "mlir/SPIRV/SPIRVDialect.h"
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
@ -87,13 +88,24 @@ private:
SmallVectorImpl<uint32_t> &operands);
// Main method to dispatch operation serialization
LogicalResult processOperation(Operation *op, uint32_t &opID);
LogicalResult processOperation(Operation *op);
// Methods to serialize individual operation types
LogicalResult processFuncOp(FuncOp op, uint32_t &funcID);
LogicalResult processFuncOp(FuncOp op);
// Serialize op that dont produce a value and have no operands, like
// spirv::ReturnOp
template <typename OpType> LogicalResult processNullaryOp(OpType op);
uint32_t getNextID() { return nextID++; }
Optional<uint32_t> findTypeID(Type type) const {
auto it = typeIDMap.find(type);
return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None));
}
Type voidType() { return mlir::NoneType::get(module.getContext()); }
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
private:
/// The SPIR-V module to be serialized.
spirv::ModuleOp module;
@ -114,11 +126,13 @@ private:
// TODO(antiagainst): debug instructions
SmallVector<uint32_t, 0> decorations;
SmallVector<uint32_t, 0> typesGlobalValues;
SmallVector<uint32_t, 0> functionDecls;
SmallVector<uint32_t, 0> functionDefns;
SmallVector<uint32_t, 0> functions;
// Map from type used in SPIR-V module to their IDs
DenseMap<Type, uint32_t> typeIDMap;
// Map from FuncOps to IDs
DenseMap<Operation *, uint32_t> funcIDMap;
};
} // namespace
@ -132,8 +146,7 @@ LogicalResult Serializer::serialize() {
// Iterate over the module body to serialze it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : module.getBlock()) {
uint32_t opID = 0;
if (failed(processOperation(&op, opID))) {
if (failed(processOperation(&op))) {
return failure();
}
}
@ -147,7 +160,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
extendedSets.size() + memoryModel.size() +
entryPoints.size() + executionModes.size() +
decorations.size() + typesGlobalValues.size() +
functionDecls.size() + functionDefns.size();
functions.size();
binary.clear();
binary.reserve(moduleSize);
@ -162,8 +175,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
binary.append(executionModes.begin(), executionModes.end());
binary.append(decorations.begin(), decorations.end());
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
binary.append(functionDecls.begin(), functionDecls.end());
binary.append(functionDefns.begin(), functionDefns.end());
binary.append(functions.begin(), functions.end());
}
LogicalResult Serializer::processHeader() {
@ -207,9 +219,9 @@ LogicalResult Serializer::processMemoryModel() {
LogicalResult Serializer::processType(Location loc, Type type,
uint32_t &typeID) {
auto it = typeIDMap.find(type);
if (it != typeIDMap.end()) {
typeID = it->second;
auto id = findTypeID(type);
if (id) {
typeID = id.getValue();
return success();
}
typeID = getNextID();
@ -230,7 +242,7 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
SmallVectorImpl<uint32_t> &operands) {
if (type.isa<NoneType>()) {
if (isVoidType(type)) {
typeEnum = spirv::Opcode::OpTypeVoid;
return success();
}
@ -246,11 +258,9 @@ Serializer::processFunctionType(Location loc, FunctionType type,
assert(type.getNumResults() <= 1 &&
"Serialization supports only a single return value");
uint32_t resultID = 0;
if (failed(processType(loc,
type.getNumResults() == 1
? type.getResult(0)
: mlir::NoneType::get(module.getContext()),
resultID))) {
if (failed(processType(
loc, type.getNumResults() == 1 ? type.getResult(0) : voidType(),
resultID))) {
return failure();
}
operands.push_back(resultID);
@ -264,21 +274,80 @@ Serializer::processFunctionType(Location loc, FunctionType type,
return success();
}
LogicalResult Serializer::processOperation(Operation *op, uint32_t &opID) {
opID = getNextID();
if ((isa<FuncOp>(op) && succeeded(processFuncOp(cast<FuncOp>(op), opID))) ||
isa<spirv::ModuleEndOp>(op)) {
LogicalResult Serializer::processOperation(Operation *op) {
if (isa<FuncOp>(op)) {
return processFuncOp(cast<FuncOp>(op));
} else if (isa<spirv::ReturnOp>(op)) {
return processNullaryOp(cast<spirv::ReturnOp>(op));
} else if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
/// TODO(ravishankarm) : Handle other ops
return op->emitError("unhandled operation serialization");
}
LogicalResult Serializer::processFuncOp(FuncOp op, uint32_t &funcID) {
uint32_t typeID = 0;
LogicalResult Serializer::processFuncOp(FuncOp op) {
uint32_t fnTypeID = 0;
// Generate type of the function
processType(op.getLoc(), op.getType(), typeID);
// TODO(ravishankarm) : Process Function body
processType(op.getLoc(), op.getType(), fnTypeID);
/// Add the function definition
SmallVector<uint32_t, 4> operands;
uint32_t resTypeID = 0;
auto resultTypes = op.getType().getResults();
if (resultTypes.size() > 1) {
return emitError(op.getLoc(),
"cannot serialize function with multiple return types");
}
if (failed(processType(op.getLoc(),
(resultTypes.empty() ? voidType() : resultTypes[0]),
resTypeID))) {
return failure();
}
operands.push_back(resTypeID);
auto funcID = getNextID();
funcIDMap[op.getOperation()] = funcID;
operands.push_back(funcID);
/// TODO : Support other function control options
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
operands.push_back(fnTypeID);
buildInstruction(spirv::Opcode::OpFunction, operands, functions);
// Declare the parameters
for (auto argType : op.getType().getInputs()) {
uint32_t argTypeID = 0;
if (failed(processType(op.getLoc(), argType, argTypeID))) {
return failure();
}
buildInstruction(spirv::Opcode::OpFunctionParameter,
{argTypeID, getNextID()}, functions);
}
// Process the body
if (!op.empty()) {
for (auto &b : op) {
for (auto &op : b) {
if (failed(processOperation(&op))) {
return failure();
}
}
}
}
// Insert Function End
buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions);
// If the function body is empty return an error
// TODO : Handle external functions
if (op.empty()) {
return emitError(op.getLoc(), "external function is unhandled");
}
return success();
}
template <typename OpType>
LogicalResult Serializer::processNullaryOp(OpType op) {
buildInstruction(spirv::getOpcode<OpType>(), ArrayRef<uint32_t>(), functions);
return success();
}

View file

@ -2,13 +2,11 @@
// CHECK-LABEL: func @spirv_module
// CHECK: spv.module "Logical" "VulkanKHR" {
// CHECK-NEXT: func @spirv_fn_0() {
// CHECK-NEXT: spv.Return
// CHECK-NEXT: }
// CHECK-NEXT: } attributes {major_version = 1 : i32, minor_version = 0 : i32}
// TODO(ravishankarm) : The output produced is not correct, since it
// doesnt get the function body. The serialization doesnt handle
// functions yet. Change the CHECK once it does, to make sure the
// function is reproduced
func @spirv_module() -> () {
spv.module "Logical" "VulkanKHR" {
func @foo() -> () {

View file

@ -44,10 +44,12 @@ using mlir::tblgen::Operator;
static void emitGetOpcodeFunction(const llvm::Record &record,
Operator const &op, raw_ostream &os) {
if (record.getValueAsInt("hasOpcode")) {
os << formatv("template <> constexpr inline uint32_t getOpcode<{0}>()",
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
"getOpcode<{0}>()",
op.getQualCppClassName())
<< " {\n return static_cast<uint32_t>("
<< formatv("Opcode::Op{0});\n}\n", record.getValueAsString("opName"));
<< " {\n "
<< formatv("return ::mlir::spirv::Opcode::Op{0};\n}\n",
record.getValueAsString("opName"));
}
}
@ -56,7 +58,8 @@ static bool emitSerializationUtils(const RecordKeeper &recordKeeper,
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os);
/// Define the function to get the opcode
os << "template <typename OpClass> inline constexpr uint32_t getOpcode();\n";
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
"getOpcode();\n";
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
for (const auto *def : defs) {
Operator op(def);

View file

@ -29,7 +29,6 @@
# in SPIR-V
set -e
set -x
current_file="$(readlink -f "$0")"
current_dir="$(dirname "$current_file")"