Miscellaneous fixes to SPIR-V Deserializer (details below).

1) Process and ignore the following debug instructions: OpSource,
OpSourceContinued, OpSourceExtension, OpString, OpModuleProcessed.
2) While processing OpTypeInt instruction, ignore the signedness
specification. Currently MLIR doesnt make a distinction between signed
and unsigned integer types.
3) Process and ignore BufferBlock decoration (similar to Buffer
decoration). StructType needs to be enhanced to track this attribute
since its needed for proper validation checks.
4) Report better error for unhandled instruction during
deserialization.

PiperOrigin-RevId: 271057060
This commit is contained in:
Mahesh Ravishankar 2019-09-24 22:50:16 -07:00 committed by A. Unique TensorFlower
parent 03db422359
commit 3a4bee0fe1
3 changed files with 45 additions and 18 deletions

View file

@ -73,8 +73,12 @@ class SPV_OpCode<string name, int val> {
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpSourceContinued : I32EnumAttrCase<"OpSourceContinued", 2>;
def SPV_OC_OpSource : I32EnumAttrCase<"OpSource", 3>;
def SPV_OC_OpSourceExtension : I32EnumAttrCase<"OpSourceExtension", 4>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>;
def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>;
def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>;
def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>;
def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>;
@ -157,18 +161,20 @@ def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpExtension,
SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel,
SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray,
SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpNop, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
@ -186,7 +192,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue
SPV_OC_OpReturnValue, SPV_OC_OpModuleProcessed
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";

View file

@ -624,11 +624,17 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
typeDecorations[words[0]] = static_cast<uint32_t>(words[2]);
break;
case spirv::Decoration::Block:
case spirv::Decoration::BufferBlock:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
// Block decoration does not affect spv.struct type.
// Block decoration does not affect spv.struct type, but is still stored for
// verification.
// TODO: Update StructType to contain this information since
// it is needed for many validation rules.
decorations[words[0]].set(opBuilder.getIdentifier(attrName),
opBuilder.getUnitAttr());
break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
@ -985,9 +991,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return emitError(
unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
}
if (operands[2] == 0) {
return emitError(unknownLoc, "unhandled unsigned OpTypeInt");
}
// TODO: Ignoring the signedness right now. Need to handle this effectively
// in the MLIR representation.
typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
break;
case spirv::Opcode::OpTypeFloat: {
@ -1787,6 +1792,14 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
break;
case spirv::Opcode::OpName:
return processName(operands);
case spirv::Opcode::OpModuleProcessed:
case spirv::Opcode::OpString:
case spirv::Opcode::OpSource:
case spirv::Opcode::OpSourceContinued:
case spirv::Opcode::OpSourceExtension:
// TODO: This is debug information embedded in the binary which should be
// translated into the spv.module.
return success();
case spirv::Opcode::OpTypeVoid:
case spirv::Opcode::OpTypeBool:
case spirv::Opcode::OpTypeInt:

View file

@ -512,10 +512,18 @@ static void finalizeDispatchDeserializationFn(StringRef opcode,
os << " default:\n";
os << " ;\n";
os << " }\n";
os << formatv(
" return emitError(unknownLoc, \"unhandled deserialization of \") << "
"spirv::stringifyOpcode({0});\n",
opcode);
StringRef opcodeVar("opcodeString");
os << formatv(" auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar,
opcode);
os << formatv(" if (!{0}.empty()) {{\n", opcodeVar);
os << formatv(" return emitError(unknownLoc, \"unhandled deserialization "
"of \") << {0};\n",
opcodeVar);
os << " } else {\n";
os << formatv(" return emitError(unknownLoc, \"unhandled opcode \") << "
"static_cast<uint32_t>({0});\n",
opcode);
os << " }\n";
os << "}\n";
}