[spirv] Fix the entry block to start with OpLabel

Each basic block in SPIR-V must start with an OpLabel instruction.
We don't support control flow yet, so this CL just makes sure that
the entry block follows this rule and is valid.

PiperOrigin-RevId: 265718841
This commit is contained in:
Lei Zhang 2019-08-27 10:50:58 -07:00 committed by A. Unique TensorFlower
parent 4ced99c085
commit 3af6b53381
4 changed files with 82 additions and 8 deletions

View file

@ -131,6 +131,7 @@ def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>;
def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
@ -153,7 +154,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpReturn, SPV_OC_OpReturnValue
SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";

View file

@ -1,3 +1,4 @@
//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -43,6 +44,11 @@ static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
return str;
}
// Extracts the opcode from the given first word of a SPIR-V instruction.
static inline spirv::Opcode extractOpcode(uint32_t word) {
return static_cast<spirv::Opcode>(word & 0xffff);
}
namespace {
/// A SPIR-V module serializer.
///
@ -176,6 +182,13 @@ private:
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
// Control flow
//===--------------------------------------------------------------------===//
/// Processes a SPIR-V OpLabel instruction with the given `operands`.
LogicalResult processLabel(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
// Instruction
//===--------------------------------------------------------------------===//
@ -195,6 +208,9 @@ private:
sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
Optional<spirv::Opcode> expectedOpcode = llvm::None);
/// Returns the next instruction's opcode if exists.
Optional<spirv::Opcode> peekOpcode();
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
/// This method is the main entrance for handling SPIR-V instruction; it
/// checks the instruction opcode and dispatches to the corresponding handler.
@ -581,10 +597,18 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
}
}
// Create a new builder for building the body
// Create a new builder for building the body.
OpBuilder funcBody(funcOp.getBody());
std::swap(funcBody, opBuilder);
// Make sure the first basic block, if exists, starts with an OpLabel
// instruction.
if (auto nextOpcode = peekOpcode()) {
if (*nextOpcode != spirv::Opcode::OpFunctionEnd &&
*nextOpcode != spirv::Opcode::OpLabel)
return emitError(unknownLoc, "a basic block must start with OpLabel");
}
spirv::Opcode opcode = spirv::Opcode::OpNop;
ArrayRef<uint32_t> instOperands;
while (succeeded(sliceInstruction(opcode, instOperands,
@ -597,9 +621,12 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (opcode != spirv::Opcode::OpFunctionEnd) {
return failure();
}
// Process OpFunctionEnd.
if (!instOperands.empty()) {
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
}
std::swap(funcBody, opBuilder);
return success();
}
@ -1124,6 +1151,18 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
<< resultType;
}
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
if (operands.size() != 1) {
return emitError(unknownLoc, "OpLabel should only have result <id>");
}
// TODO(antiagainst): support basic blocks and control flow properly.
return success();
}
//===----------------------------------------------------------------------===//
// Instruction
//===----------------------------------------------------------------------===//
@ -1173,12 +1212,18 @@ Deserializer::sliceInstruction(spirv::Opcode &opcode,
if (nextOffset > binarySize)
return emitError(unknownLoc, "insufficient words for the last instruction");
opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
opcode = extractOpcode(binary[curOffset]);
operands = binary.slice(curOffset + 1, wordCount - 1);
curOffset = nextOffset;
return success();
}
Optional<spirv::Opcode> Deserializer::peekOpcode() {
if (curOffset >= binary.size())
return llvm::None;
return extractOpcode(binary[curOffset]);
}
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands,
bool deferInstructions) {
@ -1237,6 +1282,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
default:
break;
}

View file

@ -543,6 +543,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
}
for (auto &b : op) {
// TODO(antiagainst): support basic blocks and control flow properly.
encodeInstructionInto(functions, spirv::Opcode::OpLabel, {getNextID()});
for (auto &op : b) {
if (failed(processOperation(&op))) {
return failure();

View file

@ -111,11 +111,9 @@ protected:
return id;
}
uint32_t addFunctionEnd() {
auto id = nextID++;
addInstruction(spirv::Opcode::OpFunctionEnd, {id});
return id;
}
void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
protected:
SmallVector<uint32_t, 5> binary;
@ -201,3 +199,29 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("expected OpFunctionParameter instruction");
}
TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
addHeader();
auto voidType = addVoidType();
auto fnType = addFunctionType(voidType, {});
addFunction(voidType, fnType);
// Missing OpLabel
addReturn();
addFunctionEnd();
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("a basic block must start with OpLabel");
}
TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
addHeader();
auto voidType = addVoidType();
auto fnType = addFunctionType(voidType, {});
addFunction(voidType, fnType);
addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
addReturn();
addFunctionEnd();
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("OpLabel should only have result <id>");
}