diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index c9e734325879..538891e66889 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -131,6 +131,7 @@ def SPV_OC_OpULessThan : I32EnumAttrCase<"OpULessThan", 176>; def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; +def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; @@ -153,7 +154,7 @@ def SPV_OpcodeAttr : SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpReturn, SPV_OC_OpReturnValue + SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index d300725bd480..dc0d886fa888 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -1,3 +1,4 @@ +//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===// // // Copyright 2019 The MLIR Authors. // @@ -43,6 +44,11 @@ static inline StringRef decodeStringLiteral(ArrayRef words, return str; } +// Extracts the opcode from the given first word of a SPIR-V instruction. +static inline spirv::Opcode extractOpcode(uint32_t word) { + return static_cast(word & 0xffff); +} + namespace { /// A SPIR-V module serializer. /// @@ -176,6 +182,13 @@ private: /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); + //===--------------------------------------------------------------------===// + // Control flow + //===--------------------------------------------------------------------===// + + /// Processes a SPIR-V OpLabel instruction with the given `operands`. + LogicalResult processLabel(ArrayRef operands); + //===--------------------------------------------------------------------===// // Instruction //===--------------------------------------------------------------------===// @@ -195,6 +208,9 @@ private: sliceInstruction(spirv::Opcode &opcode, ArrayRef &operands, Optional expectedOpcode = llvm::None); + /// Returns the next instruction's opcode if exists. + Optional peekOpcode(); + /// Processes a SPIR-V instruction with the given `opcode` and `operands`. /// This method is the main entrance for handling SPIR-V instruction; it /// checks the instruction opcode and dispatches to the corresponding handler. @@ -581,10 +597,18 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { } } - // Create a new builder for building the body + // Create a new builder for building the body. OpBuilder funcBody(funcOp.getBody()); std::swap(funcBody, opBuilder); + // Make sure the first basic block, if exists, starts with an OpLabel + // instruction. + if (auto nextOpcode = peekOpcode()) { + if (*nextOpcode != spirv::Opcode::OpFunctionEnd && + *nextOpcode != spirv::Opcode::OpLabel) + return emitError(unknownLoc, "a basic block must start with OpLabel"); + } + spirv::Opcode opcode = spirv::Opcode::OpNop; ArrayRef instOperands; while (succeeded(sliceInstruction(opcode, instOperands, @@ -597,9 +621,12 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { if (opcode != spirv::Opcode::OpFunctionEnd) { return failure(); } + + // Process OpFunctionEnd. if (!instOperands.empty()) { return emitError(unknownLoc, "unexpected operands for OpFunctionEnd"); } + std::swap(funcBody, opBuilder); return success(); } @@ -1124,6 +1151,18 @@ LogicalResult Deserializer::processConstantNull(ArrayRef operands) { << resultType; } +//===----------------------------------------------------------------------===// +// Control flow +//===----------------------------------------------------------------------===// + +LogicalResult Deserializer::processLabel(ArrayRef operands) { + if (operands.size() != 1) { + return emitError(unknownLoc, "OpLabel should only have result "); + } + // TODO(antiagainst): support basic blocks and control flow properly. + return success(); +} + //===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// @@ -1173,12 +1212,18 @@ Deserializer::sliceInstruction(spirv::Opcode &opcode, if (nextOffset > binarySize) return emitError(unknownLoc, "insufficient words for the last instruction"); - opcode = static_cast(binary[curOffset] & 0xffff); + opcode = extractOpcode(binary[curOffset]); operands = binary.slice(curOffset + 1, wordCount - 1); curOffset = nextOffset; return success(); } +Optional Deserializer::peekOpcode() { + if (curOffset >= binary.size()) + return llvm::None; + return extractOpcode(binary[curOffset]); +} + LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef operands, bool deferInstructions) { @@ -1237,6 +1282,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processMemberDecoration(operands); case spirv::Opcode::OpFunction: return processFunction(operands); + case spirv::Opcode::OpLabel: + return processLabel(operands); default: break; } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 03973db0d953..43a1d08cf6ce 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -543,6 +543,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { } for (auto &b : op) { + // TODO(antiagainst): support basic blocks and control flow properly. + encodeInstructionInto(functions, spirv::Opcode::OpLabel, {getNextID()}); for (auto &op : b) { if (failed(processOperation(&op))) { return failure(); diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp index e4b3ee51d2c1..5262d5709590 100644 --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -111,11 +111,9 @@ protected: return id; } - uint32_t addFunctionEnd() { - auto id = nextID++; - addInstruction(spirv::Opcode::OpFunctionEnd, {id}); - return id; - } + void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } + + void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } protected: SmallVector binary; @@ -201,3 +199,29 @@ TEST_F(DeserializationTest, FunctionMissingParameterFailure) { ASSERT_EQ(llvm::None, deserialize()); expectDiagnostic("expected OpFunctionParameter instruction"); } + +TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { + addHeader(); + auto voidType = addVoidType(); + auto fnType = addFunctionType(voidType, {}); + addFunction(voidType, fnType); + // Missing OpLabel + addReturn(); + addFunctionEnd(); + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("a basic block must start with OpLabel"); +} + +TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { + addHeader(); + auto voidType = addVoidType(); + auto fnType = addFunctionType(voidType, {}); + addFunction(voidType, fnType); + addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel + addReturn(); + addFunctionEnd(); + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("OpLabel should only have result "); +}