[spirv] Add SPV_UnaryOp and spv.FNegate

This CL also moves common parsers and printers to the
same section in SPIRVOps.cpp.

PiperOrigin-RevId: 271233546
This commit is contained in:
Lei Zhang 2019-09-25 16:34:37 -07:00 committed by A. Unique TensorFlower
parent a2bce652af
commit ae13c28f3f
7 changed files with 155 additions and 82 deletions

View file

@ -36,6 +36,13 @@ class SPV_ArithmeticOp<string mnemonic, Type type,
!listconcat(traits, !listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>; [NoSideEffect, SameOperandsAndResultType])>;
class SPV_UnaryArithmeticOp<string mnemonic, Type type,
list<OpTrait> traits = []> :
// Operand type same as result type.
SPV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
[NoSideEffect, SameOperandsAndResultType])>;
// ----- // -----
def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> { def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> {
@ -67,7 +74,7 @@ def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> {
// ----- // -----
def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> { def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float, []> {
let summary = "Floating-point division of Operand 1 divided by Operand 2."; let summary = "Floating-point division of Operand 1 divided by Operand 2.";
let description = [{ let description = [{
@ -78,6 +85,7 @@ def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> {
Results are computed per component. The resulting value is undefined Results are computed per component. The resulting value is undefined
if Operand 2 is 0. if Operand 2 is 0.
### Custom assembly form ### Custom assembly form
``` {.ebnf} ``` {.ebnf}
float-scalar-vector-type ::= float-type | float-scalar-vector-type ::= float-type |
@ -97,7 +105,7 @@ def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> {
// ----- // -----
def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float> { def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float, []> {
let summary = [{ let summary = [{
The floating-point remainder whose sign matches the sign of Operand 2. The floating-point remainder whose sign matches the sign of Operand 2.
}]; }];
@ -162,7 +170,36 @@ def SPV_FMulOp : SPV_ArithmeticOp<"FMul", SPV_Float, [Commutative]> {
// ----- // -----
def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> { def SPV_FNegateOp : SPV_UnaryArithmeticOp<"FNegate", SPV_Float, []> {
let summary = "Floating-point subtract of Operand from zero.";
let description = [{
Result Type must be a scalar or vector of floating-point type.
The type of Operand must be the same as Result Type.
Results are computed per component.
### Custom assembly form
``` {.ebnf}
float-scalar-vector-type ::= float-type |
`vector<` integer-literal `x` float-type `>`
fmul-op ::= `spv.FNegate` ssa-use `:` float-scalar-vector-type
```
For example:
```
%1 = spv.FNegate %0 : f32
%3 = spv.FNegate %2 : vector<4xf32>
```
}];
}
// -----
def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float, []> {
let summary = [{ let summary = [{
The floating-point remainder whose sign matches the sign of Operand 1. The floating-point remainder whose sign matches the sign of Operand 1.
}]; }];
@ -197,7 +234,7 @@ def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> {
// ----- // -----
def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float> { def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float, []> {
let summary = "Floating-point subtraction of Operand 2 from Operand 1."; let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
let description = [{ let description = [{
@ -299,7 +336,7 @@ def SPV_IMulOp : SPV_ArithmeticOp<"IMul", SPV_Integer, [Commutative]> {
// ----- // -----
def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> { def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer, []> {
let summary = "Integer subtraction of Operand 2 from Operand 1."; let summary = "Integer subtraction of Operand 2 from Operand 1.";
let description = [{ let description = [{
@ -335,7 +372,7 @@ def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> {
// ----- // -----
def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> { def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer, []> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2."; let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
let description = [{ let description = [{
@ -368,7 +405,7 @@ def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> {
// ----- // -----
def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer> { def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer, []> {
let summary = [{ let summary = [{
Signed remainder operation for the remainder whose sign matches the sign Signed remainder operation for the remainder whose sign matches the sign
of Operand 2. of Operand 2.
@ -405,7 +442,7 @@ def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer> {
// ----- // -----
def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer> { def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer, []> {
let summary = [{ let summary = [{
Signed remainder operation for the remainder whose sign matches the sign Signed remainder operation for the remainder whose sign matches the sign
of Operand 1. of Operand 1.
@ -442,7 +479,7 @@ def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer> {
// ----- // -----
def SPV_UDivOp : SPV_ArithmeticOp<"UDiv", SPV_Integer> { def SPV_UDivOp : SPV_ArithmeticOp<"UDiv", SPV_Integer, []> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
let description = [{ let description = [{

View file

@ -116,6 +116,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>; def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>;
@ -178,16 +179,16 @@ def SPV_OpcodeAttr :
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFNegate,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
@ -1234,6 +1235,23 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
bit autogenSerialization = 1; bit autogenSerialization = 1;
} }
class SPV_UnaryOp<string mnemonic, Type resultType, Type operandType,
list<OpTrait> traits = []> :
SPV_Op<mnemonic, traits> {
let arguments = (ins
SPV_ScalarOrVectorOf<operandType>:$operand
);
let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result
);
let parser = [{ return ::parseUnaryOp(parser, result); }];
let printer = [{ return ::printUnaryOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones.
let verifier = [{ return success(); }];
}
class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType, class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
list<OpTrait> traits = []> : list<OpTrait> traits = []> :
SPV_Op<mnemonic, traits> { SPV_Op<mnemonic, traits> {
@ -1241,9 +1259,11 @@ class SPV_BinaryOp<string mnemonic, Type resultType, Type operandsType,
SPV_ScalarOrVectorOf<operandsType>:$operand1, SPV_ScalarOrVectorOf<operandsType>:$operand1,
SPV_ScalarOrVectorOf<operandsType>:$operand2 SPV_ScalarOrVectorOf<operandsType>:$operand2
); );
let results = (outs let results = (outs
SPV_ScalarOrVectorOf<resultType>:$result SPV_ScalarOrVectorOf<resultType>:$result
); );
let parser = [{ return impl::parseBinaryOp(parser, result); }]; let parser = [{ return impl::parseBinaryOp(parser, result); }];
let printer = [{ return impl::printBinaryOp(getOperation(), p); }]; let printer = [{ return impl::printBinaryOp(getOperation(), p); }];
// No additional verification needed in addition to the ODS-generated ones. // No additional verification needed in addition to the ODS-generated ones.

View file

@ -49,9 +49,9 @@ class SPV_GLSLUnaryOp<string mnemonic, int opcode, Type resultType,
SPV_ScalarOrVectorOf<resultType>:$result SPV_ScalarOrVectorOf<resultType>:$result
); );
let parser = [{ return parseGLSLUnaryOp(parser, result); }]; let parser = [{ return parseUnaryOp(parser, result); }];
let printer = [{ return printGLSLUnaryOp(getOperation(), p); }]; let printer = [{ return printUnaryOp(getOperation(), p); }];
let verifier = [{ return success(); }]; let verifier = [{ return success(); }];
} }
@ -152,4 +152,4 @@ def SPV_GLSLFMaxOp : SPV_GLSLBinaryArithmaticOp<"FMax", 40, SPV_Float> {
// ----- // -----
#endif // SPIRV_GLSL_OPS #endif // SPIRV_GLSL_OPS

View file

@ -69,23 +69,6 @@ static LogicalResult extractValueFromConstOp(Operation *op,
return success(); return success();
} }
static ParseResult parseBinaryLogicalOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
parser.resolveOperands(ops, type, result.operands)) {
return failure();
}
// Result must be a scalar or vector of boolean type.
Type resultType = parser.getBuilder().getIntegerType(1);
if (auto opsType = type.dyn_cast<VectorType>()) {
resultType = VectorType::get(opsType.getNumElements(), resultType);
}
result.addTypes(resultType);
return success();
}
template <typename EnumClass> template <typename EnumClass>
static ParseResult static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser, parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
@ -149,19 +132,6 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
return parser.parseRSquare(); return parser.parseRSquare();
} }
// Parses an op that has no inputs and no outputs.
static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
if (parser.parseOptionalAttributeDict(state.attributes))
return failure();
return success();
}
static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
<< *logicalOp->getOperand(1);
printer << " : " << logicalOp->getOperand(0)->getType();
}
template <typename LoadStoreOpTy> template <typename LoadStoreOpTy>
static void static void
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer, printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
@ -260,12 +230,6 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr,
return success(); return success();
} }
// Prints an op that has no inputs and no outputs.
static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
printer << op->getName();
printer.printOptionalAttrDict(op->getAttrs());
}
static ParseResult parseVariableDecorations(OpAsmParser &parser, static ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state) { OperationState &state) {
auto builtInName = auto builtInName =
@ -352,6 +316,62 @@ static Attribute extractCompositeElement(Attribute composite,
return {}; return {};
} }
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
// Parses an op that has no inputs and no outputs.
static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
if (parser.parseOptionalAttributeDict(state.attributes))
return failure();
return success();
}
// Prints an op that has no inputs and no outputs.
static void printNoIOOp(Operation *op, OpAsmPrinter &printer) {
printer << op->getName();
printer.printOptionalAttrDict(op->getAttrs());
}
static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
OpAsmParser::OperandType operandInfo;
Type type;
if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
parser.resolveOperands(operandInfo, type, state.operands)) {
return failure();
}
state.addTypes(type);
return success();
}
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
<< unaryOp->getOperand(0)->getType();
}
static ParseResult parseBinaryLogicalOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type type;
if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
parser.resolveOperands(ops, type, result.operands)) {
return failure();
}
// Result must be a scalar or vector of boolean type.
Type resultType = parser.getBuilder().getIntegerType(1);
if (auto opsType = type.dyn_cast<VectorType>()) {
resultType = VectorType::get(opsType.getNumElements(), resultType);
}
result.addTypes(resultType);
return success();
}
static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", "
<< *logicalOp->getOperand(1);
printer << " : " << logicalOp->getOperand(0)->getType();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.AccessChainOp // spv.AccessChainOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1057,27 +1077,6 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// spv.GLSL.UnaryOp
//===----------------------------------------------------------------------===//
static ParseResult parseGLSLUnaryOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operandInfo;
Type type;
if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
parser.resolveOperands(operandInfo, type, state.operands)) {
return failure();
}
state.addTypes(type);
return success();
}
static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
<< unaryOp->getOperand(0)->getType();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.globalVariable // spv.globalVariable
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -21,6 +21,11 @@ spv.module "Logical" "GLSL450" {
%0 = spv.FMod %arg0, %arg1 : vector<4xf32> %0 = spv.FMod %arg0, %arg1 : vector<4xf32>
spv.Return spv.Return
} }
func @fnegate(%arg0 : vector<4xf32>) {
// CHECK: {{%.*}} = spv.FNegate {{%.*}} : vector<4xf32>
%0 = spv.FNegate %arg0 : vector<4xf32>
spv.Return
}
func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) {
// CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : vector<4xf32> // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : vector<4xf32>
%0 = spv.FSub %arg0, %arg1 : vector<4xf32> %0 = spv.FSub %arg0, %arg1 : vector<4xf32>

View file

@ -78,6 +78,18 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
// ----- // -----
//===----------------------------------------------------------------------===//
// spv.FNegate
//===----------------------------------------------------------------------===//
func @fnegate_scalar(%arg: f32) -> f32 {
// CHECK: spv.FNegate
%0 = spv.FNegate %arg : f32
return %0 : f32
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.FRem // spv.FRem
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View file

@ -389,9 +389,9 @@ def get_op_definition(instruction, doc, existing_info, inst_category):
'}}\n' '}}\n'
opname = instruction['opname'][2:] opname = instruction['opname'][2:]
category_args = existing_info.get('category_args', None) category_args = existing_info.get('category_args', '')
if category_args is None: # Make sure we have ', ' to separate the category arguments from traits
category_args = ', ' category_args = category_args.rstrip(', ') + ', '
summary, description = doc.split('\n', 1) summary, description = doc.split('\n', 1)
wrapper = textwrap.TextWrapper( wrapper = textwrap.TextWrapper(