diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td index ded9920ed7f6..deea880774cd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVArithmeticOps.td @@ -36,6 +36,13 @@ class SPV_ArithmeticOp; +class SPV_UnaryArithmeticOp traits = []> : + // Operand type same as result type. + SPV_UnaryOp; + // ----- def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> { @@ -67,7 +74,7 @@ def SPV_FAddOp : SPV_ArithmeticOp<"FAdd", SPV_Float, [Commutative]> { // ----- -def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> { +def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float, []> { let summary = "Floating-point division of Operand 1 divided by Operand 2."; let description = [{ @@ -78,6 +85,7 @@ def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> { Results are computed per component. The resulting value is undefined if Operand 2 is 0. + ### Custom assembly form ``` {.ebnf} float-scalar-vector-type ::= float-type | @@ -97,7 +105,7 @@ def SPV_FDivOp : SPV_ArithmeticOp<"FDiv", SPV_Float> { // ----- -def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float> { +def SPV_FModOp : SPV_ArithmeticOp<"FMod", SPV_Float, []> { let summary = [{ The floating-point remainder whose sign matches the sign of Operand 2. }]; @@ -162,7 +170,36 @@ def SPV_FMulOp : SPV_ArithmeticOp<"FMul", SPV_Float, [Commutative]> { // ----- -def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> { +def SPV_FNegateOp : SPV_UnaryArithmeticOp<"FNegate", SPV_Float, []> { + let summary = "Floating-point subtract of Operand from zero."; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + The type of Operand must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + float-scalar-vector-type ::= float-type | + `vector<` integer-literal `x` float-type `>` + fmul-op ::= `spv.FNegate` ssa-use `:` float-scalar-vector-type + ``` + + For example: + + ``` + %1 = spv.FNegate %0 : f32 + %3 = spv.FNegate %2 : vector<4xf32> + ``` + }]; +} + +// ----- + +def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float, []> { let summary = [{ The floating-point remainder whose sign matches the sign of Operand 1. }]; @@ -197,7 +234,7 @@ def SPV_FRemOp : SPV_ArithmeticOp<"FRem", SPV_Float> { // ----- -def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float> { +def SPV_FSubOp : SPV_ArithmeticOp<"FSub", SPV_Float, []> { let summary = "Floating-point subtraction of Operand 2 from Operand 1."; let description = [{ @@ -299,7 +336,7 @@ def SPV_IMulOp : SPV_ArithmeticOp<"IMul", SPV_Integer, [Commutative]> { // ----- -def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> { +def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer, []> { let summary = "Integer subtraction of Operand 2 from Operand 1."; let description = [{ @@ -335,7 +372,7 @@ def SPV_ISubOp : SPV_ArithmeticOp<"ISub", SPV_Integer> { // ----- -def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> { +def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer, []> { let summary = "Signed-integer division of Operand 1 divided by Operand 2."; let description = [{ @@ -368,7 +405,7 @@ def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> { // ----- -def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer> { +def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer, []> { let summary = [{ Signed remainder operation for the remainder whose sign matches the sign of Operand 2. @@ -405,7 +442,7 @@ def SPV_SModOp : SPV_ArithmeticOp<"SMod", SPV_Integer> { // ----- -def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer> { +def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer, []> { let summary = [{ Signed remainder operation for the remainder whose sign matches the sign of Operand 1. @@ -442,7 +479,7 @@ def SPV_SRemOp : SPV_ArithmeticOp<"SRem", SPV_Integer> { // ----- -def SPV_UDivOp : SPV_ArithmeticOp<"UDiv", SPV_Integer> { +def SPV_UDivOp : SPV_ArithmeticOp<"UDiv", SPV_Integer, []> { let summary = "Unsigned-integer division of Operand 1 divided by Operand 2."; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 59b29cc353ba..081c6dfbd2f6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -116,6 +116,7 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>; def SPV_OC_OpISub : I32EnumAttrCase<"OpISub", 130>; @@ -178,16 +179,16 @@ def SPV_OpcodeAttr : SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, - SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, - SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, - SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFNegate, + SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, + SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, + SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, @@ -1234,6 +1235,23 @@ class SPV_Op traits = []> : bit autogenSerialization = 1; } +class SPV_UnaryOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return ::parseUnaryOp(parser, result); }]; + let printer = [{ return ::printUnaryOp(getOperation(), p); }]; + // No additional verification needed in addition to the ODS-generated ones. + let verifier = [{ return success(); }]; +} + class SPV_BinaryOp traits = []> : SPV_Op { @@ -1241,9 +1259,11 @@ class SPV_BinaryOp:$operand1, SPV_ScalarOrVectorOf:$operand2 ); + let results = (outs SPV_ScalarOrVectorOf:$result ); + let parser = [{ return impl::parseBinaryOp(parser, result); }]; let printer = [{ return impl::printBinaryOp(getOperation(), p); }]; // No additional verification needed in addition to the ODS-generated ones. diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td index ba5e900469d8..49e6b5051ba3 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGLSLOps.td @@ -49,9 +49,9 @@ class SPV_GLSLUnaryOp:$result ); - let parser = [{ return parseGLSLUnaryOp(parser, result); }]; + let parser = [{ return parseUnaryOp(parser, result); }]; - let printer = [{ return printGLSLUnaryOp(getOperation(), p); }]; + let printer = [{ return printUnaryOp(getOperation(), p); }]; let verifier = [{ return success(); }]; } @@ -152,4 +152,4 @@ def SPV_GLSLFMaxOp : SPV_GLSLBinaryArithmaticOp<"FMax", 40, SPV_Float> { // ----- -#endif // SPIRV_GLSL_OPS \ No newline at end of file +#endif // SPIRV_GLSL_OPS diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 3890e0e8115f..1d637f8a30ee 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -69,23 +69,6 @@ static LogicalResult extractValueFromConstOp(Operation *op, return success(); } -static ParseResult parseBinaryLogicalOp(OpAsmParser &parser, - OperationState &result) { - SmallVector ops; - Type type; - if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands)) { - return failure(); - } - // Result must be a scalar or vector of boolean type. - Type resultType = parser.getBuilder().getIntegerType(1); - if (auto opsType = type.dyn_cast()) { - resultType = VectorType::get(opsType.getNumElements(), resultType); - } - result.addTypes(resultType); - return success(); -} - template static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser &parser, @@ -149,19 +132,6 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, return parser.parseRSquare(); } -// Parses an op that has no inputs and no outputs. -static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { - if (parser.parseOptionalAttributeDict(state.attributes)) - return failure(); - return success(); -} - -static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", " - << *logicalOp->getOperand(1); - printer << " : " << logicalOp->getOperand(0)->getType(); -} - template static void printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer, @@ -260,12 +230,6 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value *ptr, return success(); } -// Prints an op that has no inputs and no outputs. -static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName(); - printer.printOptionalAttrDict(op->getAttrs()); -} - static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state) { auto builtInName = @@ -352,6 +316,62 @@ static Attribute extractCompositeElement(Attribute composite, return {}; } +//===----------------------------------------------------------------------===// +// Common parsers and printers +//===----------------------------------------------------------------------===// + +// Parses an op that has no inputs and no outputs. +static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { + if (parser.parseOptionalAttributeDict(state.attributes)) + return failure(); + return success(); +} + +// Prints an op that has no inputs and no outputs. +static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + printer.printOptionalAttrDict(op->getAttrs()); +} + +static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) { + OpAsmParser::OperandType operandInfo; + Type type; + if (parser.parseOperand(operandInfo) || parser.parseColonType(type) || + parser.resolveOperands(operandInfo, type, state.operands)) { + return failure(); + } + state.addTypes(type); + return success(); +} + +static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { + printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : " + << unaryOp->getOperand(0)->getType(); +} + +static ParseResult parseBinaryLogicalOp(OpAsmParser &parser, + OperationState &result) { + SmallVector ops; + Type type; + if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) || + parser.resolveOperands(ops, type, result.operands)) { + return failure(); + } + // Result must be a scalar or vector of boolean type. + Type resultType = parser.getBuilder().getIntegerType(1); + if (auto opsType = type.dyn_cast()) { + resultType = VectorType::get(opsType.getNumElements(), resultType); + } + result.addTypes(resultType); + return success(); +} + +static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { + printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", " + << *logicalOp->getOperand(1); + printer << " : " << logicalOp->getOperand(0)->getType(); +} + //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// @@ -1057,27 +1077,6 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { return success(); } -//===----------------------------------------------------------------------===// -// spv.GLSL.UnaryOp -//===----------------------------------------------------------------------===// - -static ParseResult parseGLSLUnaryOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type type; - if (parser.parseOperand(operandInfo) || parser.parseColonType(type) || - parser.resolveOperands(operandInfo, type, state.operands)) { - return failure(); - } - state.addTypes(type); - return success(); -} - -static void printGLSLUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { - printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : " - << unaryOp->getOperand(0)->getType(); -} - //===----------------------------------------------------------------------===// // spv.globalVariable //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir index 0389a169b795..ff55e8120fec 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/arithmetic-ops.mlir @@ -21,6 +21,11 @@ spv.module "Logical" "GLSL450" { %0 = spv.FMod %arg0, %arg1 : vector<4xf32> spv.Return } + func @fnegate(%arg0 : vector<4xf32>) { + // CHECK: {{%.*}} = spv.FNegate {{%.*}} : vector<4xf32> + %0 = spv.FNegate %arg0 : vector<4xf32> + spv.Return + } func @fsub(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) { // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : vector<4xf32> %0 = spv.FSub %arg0, %arg1 : vector<4xf32> diff --git a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir index 9369962173e3..e1b9c1b8411d 100644 --- a/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/arithmetic-ops.mlir @@ -78,6 +78,18 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> { // ----- +//===----------------------------------------------------------------------===// +// spv.FNegate +//===----------------------------------------------------------------------===// + +func @fnegate_scalar(%arg: f32) -> f32 { + // CHECK: spv.FNegate + %0 = spv.FNegate %arg : f32 + return %0 : f32 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.FRem //===----------------------------------------------------------------------===// diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 6595931eeeda..e505b4054c50 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -389,9 +389,9 @@ def get_op_definition(instruction, doc, existing_info, inst_category): '}}\n' opname = instruction['opname'][2:] - category_args = existing_info.get('category_args', None) - if category_args is None: - category_args = ', ' + category_args = existing_info.get('category_args', '') + # Make sure we have ', ' to separate the category arguments from traits + category_args = category_args.rstrip(', ') + ', ' summary, description = doc.split('\n', 1) wrapper = textwrap.TextWrapper(