Add BuiltIn EnumAttr to SPIR-V dialect

Generate the EnumAttr to represent BuiltIns in SPIR-V dialect. The
builtIn can be specified as a StringAttr with value being the
name of the builtin. Extend Decoration (de)serialization to handle
BuiltIns.
Also fix an error in the SPIR-V dialect generator script.

PiperOrigin-RevId: 263596624
This commit is contained in:
Mahesh Ravishankar 2019-08-15 10:52:24 -07:00 committed by A. Unique TensorFlower
parent 30e9c2fe4f
commit d71915420b
8 changed files with 204 additions and 10 deletions

View file

@ -221,6 +221,142 @@ def SPV_AddressingModelAttr :
let cppNamespace = "::mlir::spirv";
}
def SPV_BI_Position : I32EnumAttrCase<"Position", 0>;
def SPV_BI_PointSize : I32EnumAttrCase<"PointSize", 1>;
def SPV_BI_ClipDistance : I32EnumAttrCase<"ClipDistance", 3>;
def SPV_BI_CullDistance : I32EnumAttrCase<"CullDistance", 4>;
def SPV_BI_VertexId : I32EnumAttrCase<"VertexId", 5>;
def SPV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6>;
def SPV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7>;
def SPV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8>;
def SPV_BI_Layer : I32EnumAttrCase<"Layer", 9>;
def SPV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10>;
def SPV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11>;
def SPV_BI_TessLevelInner : I32EnumAttrCase<"TessLevelInner", 12>;
def SPV_BI_TessCoord : I32EnumAttrCase<"TessCoord", 13>;
def SPV_BI_PatchVertices : I32EnumAttrCase<"PatchVertices", 14>;
def SPV_BI_FragCoord : I32EnumAttrCase<"FragCoord", 15>;
def SPV_BI_PointCoord : I32EnumAttrCase<"PointCoord", 16>;
def SPV_BI_FrontFacing : I32EnumAttrCase<"FrontFacing", 17>;
def SPV_BI_SampleId : I32EnumAttrCase<"SampleId", 18>;
def SPV_BI_SamplePosition : I32EnumAttrCase<"SamplePosition", 19>;
def SPV_BI_SampleMask : I32EnumAttrCase<"SampleMask", 20>;
def SPV_BI_FragDepth : I32EnumAttrCase<"FragDepth", 22>;
def SPV_BI_HelperInvocation : I32EnumAttrCase<"HelperInvocation", 23>;
def SPV_BI_NumWorkgroups : I32EnumAttrCase<"NumWorkgroups", 24>;
def SPV_BI_WorkgroupSize : I32EnumAttrCase<"WorkgroupSize", 25>;
def SPV_BI_WorkgroupId : I32EnumAttrCase<"WorkgroupId", 26>;
def SPV_BI_LocalInvocationId : I32EnumAttrCase<"LocalInvocationId", 27>;
def SPV_BI_GlobalInvocationId : I32EnumAttrCase<"GlobalInvocationId", 28>;
def SPV_BI_LocalInvocationIndex : I32EnumAttrCase<"LocalInvocationIndex", 29>;
def SPV_BI_WorkDim : I32EnumAttrCase<"WorkDim", 30>;
def SPV_BI_GlobalSize : I32EnumAttrCase<"GlobalSize", 31>;
def SPV_BI_EnqueuedWorkgroupSize : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>;
def SPV_BI_GlobalOffset : I32EnumAttrCase<"GlobalOffset", 33>;
def SPV_BI_GlobalLinearId : I32EnumAttrCase<"GlobalLinearId", 34>;
def SPV_BI_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 36>;
def SPV_BI_SubgroupMaxSize : I32EnumAttrCase<"SubgroupMaxSize", 37>;
def SPV_BI_NumSubgroups : I32EnumAttrCase<"NumSubgroups", 38>;
def SPV_BI_NumEnqueuedSubgroups : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>;
def SPV_BI_SubgroupId : I32EnumAttrCase<"SubgroupId", 40>;
def SPV_BI_SubgroupLocalInvocationId : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>;
def SPV_BI_VertexIndex : I32EnumAttrCase<"VertexIndex", 42>;
def SPV_BI_InstanceIndex : I32EnumAttrCase<"InstanceIndex", 43>;
def SPV_BI_SubgroupEqMask : I32EnumAttrCase<"SubgroupEqMask", 4416>;
def SPV_BI_SubgroupGeMask : I32EnumAttrCase<"SubgroupGeMask", 4417>;
def SPV_BI_SubgroupGtMask : I32EnumAttrCase<"SubgroupGtMask", 4418>;
def SPV_BI_SubgroupLeMask : I32EnumAttrCase<"SubgroupLeMask", 4419>;
def SPV_BI_SubgroupLtMask : I32EnumAttrCase<"SubgroupLtMask", 4420>;
def SPV_BI_BaseVertex : I32EnumAttrCase<"BaseVertex", 4424>;
def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>;
def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426>;
def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438>;
def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440>;
def SPV_BI_BaryCoordNoPerspAMD : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>;
def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>;
def SPV_BI_BaryCoordNoPerspSampleAMD : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>;
def SPV_BI_BaryCoordSmoothAMD : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>;
def SPV_BI_BaryCoordSmoothCentroidAMD : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>;
def SPV_BI_BaryCoordSmoothSampleAMD : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>;
def SPV_BI_BaryCoordPullModelAMD : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>;
def SPV_BI_FragStencilRefEXT : I32EnumAttrCase<"FragStencilRefEXT", 5014>;
def SPV_BI_ViewportMaskNV : I32EnumAttrCase<"ViewportMaskNV", 5253>;
def SPV_BI_SecondaryPositionNV : I32EnumAttrCase<"SecondaryPositionNV", 5257>;
def SPV_BI_SecondaryViewportMaskNV : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>;
def SPV_BI_PositionPerViewNV : I32EnumAttrCase<"PositionPerViewNV", 5261>;
def SPV_BI_ViewportMaskPerViewNV : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>;
def SPV_BI_FullyCoveredEXT : I32EnumAttrCase<"FullyCoveredEXT", 5264>;
def SPV_BI_TaskCountNV : I32EnumAttrCase<"TaskCountNV", 5274>;
def SPV_BI_PrimitiveCountNV : I32EnumAttrCase<"PrimitiveCountNV", 5275>;
def SPV_BI_PrimitiveIndicesNV : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>;
def SPV_BI_ClipDistancePerViewNV : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>;
def SPV_BI_CullDistancePerViewNV : I32EnumAttrCase<"CullDistancePerViewNV", 5278>;
def SPV_BI_LayerPerViewNV : I32EnumAttrCase<"LayerPerViewNV", 5279>;
def SPV_BI_MeshViewCountNV : I32EnumAttrCase<"MeshViewCountNV", 5280>;
def SPV_BI_MeshViewIndicesNV : I32EnumAttrCase<"MeshViewIndicesNV", 5281>;
def SPV_BI_BaryCoordNV : I32EnumAttrCase<"BaryCoordNV", 5286>;
def SPV_BI_BaryCoordNoPerspNV : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>;
def SPV_BI_FragSizeEXT : I32EnumAttrCase<"FragSizeEXT", 5292>;
def SPV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountEXT", 5293>;
def SPV_BI_LaunchIdNV : I32EnumAttrCase<"LaunchIdNV", 5319>;
def SPV_BI_LaunchSizeNV : I32EnumAttrCase<"LaunchSizeNV", 5320>;
def SPV_BI_WorldRayOriginNV : I32EnumAttrCase<"WorldRayOriginNV", 5321>;
def SPV_BI_WorldRayDirectionNV : I32EnumAttrCase<"WorldRayDirectionNV", 5322>;
def SPV_BI_ObjectRayOriginNV : I32EnumAttrCase<"ObjectRayOriginNV", 5323>;
def SPV_BI_ObjectRayDirectionNV : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>;
def SPV_BI_RayTminNV : I32EnumAttrCase<"RayTminNV", 5325>;
def SPV_BI_RayTmaxNV : I32EnumAttrCase<"RayTmaxNV", 5326>;
def SPV_BI_InstanceCustomIndexNV : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>;
def SPV_BI_ObjectToWorldNV : I32EnumAttrCase<"ObjectToWorldNV", 5330>;
def SPV_BI_WorldToObjectNV : I32EnumAttrCase<"WorldToObjectNV", 5331>;
def SPV_BI_HitTNV : I32EnumAttrCase<"HitTNV", 5332>;
def SPV_BI_HitKindNV : I32EnumAttrCase<"HitKindNV", 5333>;
def SPV_BI_IncomingRayFlagsNV : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>;
def SPV_BI_WarpsPerSMNV : I32EnumAttrCase<"WarpsPerSMNV", 5374>;
def SPV_BI_SMCountNV : I32EnumAttrCase<"SMCountNV", 5375>;
def SPV_BI_WarpIDNV : I32EnumAttrCase<"WarpIDNV", 5376>;
def SPV_BI_SMIDNV : I32EnumAttrCase<"SMIDNV", 5377>;
def SPV_BuiltInAttr :
I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices,
SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId,
SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth,
SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize,
SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId,
SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize,
SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId,
SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups,
SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId,
SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex,
SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask,
SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex,
SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex,
SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD,
SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD,
SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD,
SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV,
SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV,
SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT,
SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV,
SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV,
SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV,
SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT,
SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV,
SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV,
SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV,
SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV,
SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV,
SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV
]> {
let returnType = "::mlir::spirv::BuiltIn";
let convertFromStorage = "static_cast<::mlir::spirv::BuiltIn>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}
def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>;
def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>;
def SPV_D_Block : I32EnumAttrCase<"Block", 2>;

View file

@ -1192,11 +1192,13 @@ def SPV_VariableOp : SPV_Op<"Variable", []> {
``` {.ebnf}
variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)?
(`bind(` integer-literal, integer-literal `)`)?
(`built_in(` string-literal `)`)?
attribute-dict? `:` spirv-pointer-type
```
where `init` specifies initializer and `bind` specifies the descriptor set
and binding number.
where `init` specifies initializer and `bind` specifies the
descriptor set and binding number. `built_in` specifies SPIR-V
BuiltIn decoration associated with the op.
For example:
@ -1206,6 +1208,7 @@ def SPV_VariableOp : SPV_Op<"Variable", []> {
%1 = spv.Variable : !spv.ptr<f32, Function>
%2 = spv.Variable init(%0): !spv.ptr<f32, Private>
%3 = spv.Variable init(%0) bind(1, 2): !spv.ptr<f32, Uniform>
%3 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Uniform>
```
}];

View file

@ -898,13 +898,15 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
return failure();
}
// Parse optional descriptor binding
Attribute set, binding;
auto descriptorSetName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Attribute set, binding;
// Parse optional descriptor binding
auto descriptorSetName = convertToSnakeCase(
stringifyDecoration(spirv::Decoration::DescriptorSet));
auto bindingName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
parser->parseAttribute(set, i32Type, descriptorSetName,
@ -912,8 +914,21 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
parser->parseComma() ||
parser->parseAttribute(binding, i32Type, bindingName,
state->attributes) ||
parser->parseRParen())
parser->parseRParen()) {
return failure();
}
} else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
Attribute builtIn;
if (parser->parseLParen() ||
parser->parseAttribute(builtIn, Type(), builtInName,
state->attributes) ||
parser->parseRParen()) {
return failure();
}
if (!builtIn.isa<StringAttr>()) {
return parser->emitError(parser->getCurrentLocation(),
"expected string value for built_in attribute");
}
}
// Parse other attributes
@ -975,6 +990,14 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
<< ")";
}
// Print BuiltIn attribute if present
auto builtInName =
convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) {
*printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
elidedAttrs.push_back(builtInName);
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
*printer << " : " << varOp.getType();
}

View file

@ -321,6 +321,15 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
opBuilder.getIdentifier(attrName),
opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
break;
case spirv::Decoration::BuiltIn:
if (words.size() != 3) {
return emitError(unknownLoc, "OpDecorate with ")
<< decorationName << " needs a single integer literal";
}
decorations[words[0]].set(opBuilder.getIdentifier(attrName),
opBuilder.getStringAttr(stringifyBuiltIn(
static_cast<spirv::BuiltIn>(words[2]))));
break;
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}

View file

@ -349,6 +349,17 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
break;
}
return emitError(loc, "expected integer attribute for ") << attrName;
case spirv::Decoration::BuiltIn:
if (auto strAttr = attr.second.dyn_cast<StringAttr>()) {
auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
if (enumVal) {
args.push_back(static_cast<uint32_t>(enumVal.getValue()));
break;
}
return emitError(loc, "invalid ")
<< attrName << " attribute " << strAttr.getValue();
}
return emitError(loc, "expected string attribute for ") << attrName;
default:
return emitError(loc, "unhandled decoration ") << decorationName;
}

View file

@ -2,10 +2,14 @@
// CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr<f32, Output>
// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
%2 = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
%3 = spv.Variable bind(0, 1): !spv.ptr<f32, Output>
%4 = spv.Variable {built_in = "GlobalInvocationId"} : !spv.ptr<vector<3xi32>, Input>
%5 = spv.Variable built_in("GlobalInvocationId") : !spv.ptr<vector<3xi32>, Input>
}
return
}

View file

@ -985,6 +985,14 @@ func @variable_init_bind() -> () {
return
}
func @variable_builtin() -> () {
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
%1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
// CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr<vector<3xi32>, Input>
%2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr<vector<3xi32>, Input>
return
}
// -----
func @expect_ptr_result_type(%arg0: f32) -> () {

View file

@ -125,7 +125,7 @@ def uniquify(lst, equality_fn):
unique_lst = []
for elem in lst:
key = equality_fn(elem)
if equality_fn(key) not in keys:
if key not in keys:
unique_lst.append(elem)
keys.add(key)
return unique_lst