From 2f7bb1e25f1be3ee5c6dce72aa1eced4c823f123 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 30 Sep 2019 10:40:07 -0700 Subject: [PATCH] Add support for Logical Ops in SPIR-V dialect Add operations corresponding to OpLogicalAnd, OpLogicalNot, OpLogicalEqual, OpLogicalNotEqual and OpLogicalOr instructions in SPIR-V dialect. This needs changes to class hierarchy in SPIR-V TableGen files to split SPIRVLogicalOp into SPIRVLogicalUnaryOp and SPIRVLogicalBinaryOp. All derived classes of SPIRVLogicalOp are updated accordingly. Update the spirv dialect generation script to 1) Allow specifying base class to use for instruction spec generation and file name to generate the specification in separately. 2) Use the existing descriptions for operations. 3) Update define_inst.sh to also invoke define_opcode.sh to also define the corresponding SPIR-V instruction opcode enum. PiperOrigin-RevId: 272014876 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 19 +- .../mlir/Dialect/SPIRV/SPIRVLogicalOps.td | 224 ++++++++++++++++-- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 36 ++- mlir/test/Dialect/SPIRV/ops.mlir | 107 +++++++++ mlir/utils/spirv/define_inst.sh | 12 +- mlir/utils/spirv/gen_spirv_dialect.py | 59 +++-- 6 files changed, 393 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 1440f75026b8..6b1d20d62cbc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -132,6 +132,11 @@ def SPV_OC_OpSRem : I32EnumAttrCase<"OpSRem", 138>; def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; +def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; +def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; +def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; +def SPV_OC_OpLogicalAnd : I32EnumAttrCase<"OpLogicalAnd", 167>; +def SPV_OC_OpLogicalNot : I32EnumAttrCase<"OpLogicalNot", 168>; def SPV_OC_OpSelect : I32EnumAttrCase<"OpSelect", 169>; def SPV_OC_OpIEqual : I32EnumAttrCase<"OpIEqual", 170>; def SPV_OC_OpINotEqual : I32EnumAttrCase<"OpINotEqual", 171>; @@ -184,12 +189,14 @@ def SPV_OpcodeAttr : SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td index 1e9a5478e43a..b3b7df69f12d 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLogicalOps.td @@ -29,20 +29,30 @@ include "mlir/SPIRV/SPIRVBase.td" #endif // SPIRV_BASE -class SPV_LogicalOp traits = []> : // Result type is SPV_Bool. SPV_BinaryOp { - let parser = [{ return ::parseBinaryLogicalOp(parser, result); }]; - let printer = [{ return ::printBinaryLogicalOp(getOperation(), p); }]; + let parser = [{ return ::parseLogicalBinaryOp(parser, result); }]; + let printer = [{ return ::printLogicalOp(getOperation(), p); }]; +} + +class SPV_LogicalUnaryOp traits = []> : + // Result type is SPV_Bool. + SPV_UnaryOp { + let parser = [{ return ::parseLogicalUnaryOp(parser, result); }]; + let printer = [{ return ::printLogicalOp(getOperation(), p); }]; } // ----- -def SPV_FOrdEqualOp : SPV_LogicalOp<"FOrdEqual", SPV_Float, [Commutative]> { +def SPV_FOrdEqualOp : SPV_LogicalBinaryOp<"FOrdEqual", SPV_Float, [Commutative]> { let summary = "Floating-point comparison for being ordered and equal."; let description = [{ @@ -73,7 +83,7 @@ def SPV_FOrdEqualOp : SPV_LogicalOp<"FOrdEqual", SPV_Float, [Commutative]> { // ----- -def SPV_FOrdGreaterThanOp : SPV_LogicalOp<"FOrdGreaterThan", SPV_Float, []> { +def SPV_FOrdGreaterThanOp : SPV_LogicalBinaryOp<"FOrdGreaterThan", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are ordered and Operand 1 is greater than Operand 2. @@ -107,7 +117,7 @@ def SPV_FOrdGreaterThanOp : SPV_LogicalOp<"FOrdGreaterThan", SPV_Float, []> { // ----- -def SPV_FOrdGreaterThanEqualOp : SPV_LogicalOp<"FOrdGreaterThanEqual", SPV_Float, []> { +def SPV_FOrdGreaterThanEqualOp : SPV_LogicalBinaryOp<"FOrdGreaterThanEqual", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are ordered and Operand 1 is greater than or equal to Operand 2. @@ -141,7 +151,7 @@ def SPV_FOrdGreaterThanEqualOp : SPV_LogicalOp<"FOrdGreaterThanEqual", SPV_Float // ----- -def SPV_FOrdLessThanOp : SPV_LogicalOp<"FOrdLessThan", SPV_Float, []> { +def SPV_FOrdLessThanOp : SPV_LogicalBinaryOp<"FOrdLessThan", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are ordered and Operand 1 is less than Operand 2. @@ -175,7 +185,7 @@ def SPV_FOrdLessThanOp : SPV_LogicalOp<"FOrdLessThan", SPV_Float, []> { // ----- -def SPV_FOrdLessThanEqualOp : SPV_LogicalOp<"FOrdLessThanEqual", SPV_Float, []> { +def SPV_FOrdLessThanEqualOp : SPV_LogicalBinaryOp<"FOrdLessThanEqual", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are ordered and Operand 1 is less than or equal to Operand 2. @@ -209,7 +219,7 @@ def SPV_FOrdLessThanEqualOp : SPV_LogicalOp<"FOrdLessThanEqual", SPV_Float, []> // ----- -def SPV_FOrdNotEqualOp : SPV_LogicalOp<"FOrdNotEqual", SPV_Float, [Commutative]> { +def SPV_FOrdNotEqualOp : SPV_LogicalBinaryOp<"FOrdNotEqual", SPV_Float, [Commutative]> { let summary = "Floating-point comparison for being ordered and not equal."; let description = [{ @@ -240,7 +250,7 @@ def SPV_FOrdNotEqualOp : SPV_LogicalOp<"FOrdNotEqual", SPV_Float, [Commutative]> // ----- -def SPV_FUnordEqualOp : SPV_LogicalOp<"FUnordEqual", SPV_Float, [Commutative]> { +def SPV_FUnordEqualOp : SPV_LogicalBinaryOp<"FUnordEqual", SPV_Float, [Commutative]> { let summary = "Floating-point comparison for being unordered or equal."; let description = [{ @@ -271,7 +281,7 @@ def SPV_FUnordEqualOp : SPV_LogicalOp<"FUnordEqual", SPV_Float, [Commutative]> { // ----- -def SPV_FUnordGreaterThanOp : SPV_LogicalOp<"FUnordGreaterThan", SPV_Float, []> { +def SPV_FUnordGreaterThanOp : SPV_LogicalBinaryOp<"FUnordGreaterThan", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are unordered or Operand 1 is greater than Operand 2. @@ -305,7 +315,7 @@ def SPV_FUnordGreaterThanOp : SPV_LogicalOp<"FUnordGreaterThan", SPV_Float, []> // ----- -def SPV_FUnordGreaterThanEqualOp : SPV_LogicalOp<"FUnordGreaterThanEqual", SPV_Float, []> { +def SPV_FUnordGreaterThanEqualOp : SPV_LogicalBinaryOp<"FUnordGreaterThanEqual", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are unordered or Operand 1 is greater than or equal to Operand 2. @@ -339,7 +349,7 @@ def SPV_FUnordGreaterThanEqualOp : SPV_LogicalOp<"FUnordGreaterThanEqual", SPV_F // ----- -def SPV_FUnordLessThanOp : SPV_LogicalOp<"FUnordLessThan", SPV_Float, []> { +def SPV_FUnordLessThanOp : SPV_LogicalBinaryOp<"FUnordLessThan", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are unordered or Operand 1 is less than Operand 2. @@ -373,7 +383,7 @@ def SPV_FUnordLessThanOp : SPV_LogicalOp<"FUnordLessThan", SPV_Float, []> { // ----- -def SPV_FUnordLessThanEqualOp : SPV_LogicalOp<"FUnordLessThanEqual", SPV_Float, []> { +def SPV_FUnordLessThanEqualOp : SPV_LogicalBinaryOp<"FUnordLessThanEqual", SPV_Float, []> { let summary = [{ Floating-point comparison if operands are unordered or Operand 1 is less than or equal to Operand 2. @@ -407,7 +417,7 @@ def SPV_FUnordLessThanEqualOp : SPV_LogicalOp<"FUnordLessThanEqual", SPV_Float, // ----- -def SPV_FUnordNotEqualOp : SPV_LogicalOp<"FUnordNotEqual", SPV_Float, [Commutative]> { +def SPV_FUnordNotEqualOp : SPV_LogicalBinaryOp<"FUnordNotEqual", SPV_Float, [Commutative]> { let summary = "Floating-point comparison for being unordered or not equal."; let description = [{ @@ -438,7 +448,7 @@ def SPV_FUnordNotEqualOp : SPV_LogicalOp<"FUnordNotEqual", SPV_Float, [Commutati // ----- -def SPV_IEqualOp : SPV_LogicalOp<"IEqual", SPV_Integer, [Commutative]> { +def SPV_IEqualOp : SPV_LogicalBinaryOp<"IEqual", SPV_Integer, [Commutative]> { let summary = "Integer comparison for equality."; let description = [{ @@ -469,7 +479,7 @@ def SPV_IEqualOp : SPV_LogicalOp<"IEqual", SPV_Integer, [Commutative]> { // ----- -def SPV_INotEqualOp : SPV_LogicalOp<"INotEqual", SPV_Integer, [Commutative]> { +def SPV_INotEqualOp : SPV_LogicalBinaryOp<"INotEqual", SPV_Integer, [Commutative]> { let summary = "Integer comparison for inequality."; let description = [{ @@ -500,7 +510,168 @@ def SPV_INotEqualOp : SPV_LogicalOp<"INotEqual", SPV_Integer, [Commutative]> { // ----- -def SPV_SGreaterThanOp : SPV_LogicalOp<"SGreaterThan", SPV_Integer, []> { +def SPV_LogicalAndOp : SPV_LogicalBinaryOp<"LogicalAnd", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if both Operand 1 and Operand 2 are true. Result is false + if either Operand 1 or Operand 2 are false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + logical-and ::= `spv.LogicalAnd` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalAnd %0, %1 : i1 + %2 = spv.LogicalAnd %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalEqualOp : SPV_LogicalBinaryOp<"LogicalEqual", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if Operand 1 and Operand 2 have the same value. Result is + false if Operand 1 and Operand 2 have different values. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + logical-equal ::= `spv.LogicalEqual` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalEqual %0, %1 : i1 + %2 = spv.LogicalEqual %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalNotOp : SPV_LogicalUnaryOp<"LogicalNot", SPV_Bool, []> { + let summary = [{ + Result is true if Operand is false. Result is false if Operand is true. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + logical-not ::= `spv.LogicalNot` ssa-use `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalNot %0 : i1 + %2 = spv.LogicalNot %0 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalNotEqualOp : SPV_LogicalBinaryOp<"LogicalNotEqual", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if Operand 1 and Operand 2 have different values. Result + is false if Operand 1 and Operand 2 have the same value. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + logical-not-equal ::= `spv.LogicalNotEqual` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalNotEqual %0, %1 : i1 + %2 = spv.LogicalNotEqual %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_LogicalOrOp : SPV_LogicalBinaryOp<"LogicalOr", SPV_Bool, [Commutative]> { + let summary = [{ + Result is true if either Operand 1 or Operand 2 is true. Result is false + if both Operand 1 and Operand 2 are false. + }]; + + let description = [{ + Result Type must be a scalar or vector of Boolean type. + + The type of Operand 1 must be the same as Result Type. + + The type of Operand 2 must be the same as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + logical-or ::= `spv.LogicalOr` ssa-use `,` ssa-use + `:` operand-type + ``` + + For example: + + ``` + %2 = spv.LogicalOr %0, %1 : i1 + %2 = spv.LogicalOr %0, %1 : vector<4xi1> + ``` + }]; +} + +// ----- + +def SPV_SGreaterThanOp : SPV_LogicalBinaryOp<"SGreaterThan", SPV_Integer, []> { let summary = [{ Signed-integer comparison if Operand 1 is greater than Operand 2. }]; @@ -533,7 +704,7 @@ def SPV_SGreaterThanOp : SPV_LogicalOp<"SGreaterThan", SPV_Integer, []> { // ----- -def SPV_SGreaterThanEqualOp : SPV_LogicalOp<"SGreaterThanEqual", SPV_Integer, []> { +def SPV_SGreaterThanEqualOp : SPV_LogicalBinaryOp<"SGreaterThanEqual", SPV_Integer, []> { let summary = [{ Signed-integer comparison if Operand 1 is greater than or equal to Operand 2. @@ -567,7 +738,7 @@ def SPV_SGreaterThanEqualOp : SPV_LogicalOp<"SGreaterThanEqual", SPV_Integer, [] // ----- -def SPV_SLessThanOp : SPV_LogicalOp<"SLessThan", SPV_Integer, []> { +def SPV_SLessThanOp : SPV_LogicalBinaryOp<"SLessThan", SPV_Integer, []> { let summary = [{ Signed-integer comparison if Operand 1 is less than Operand 2. }]; @@ -600,7 +771,7 @@ def SPV_SLessThanOp : SPV_LogicalOp<"SLessThan", SPV_Integer, []> { // ----- -def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> { +def SPV_SLessThanEqualOp : SPV_LogicalBinaryOp<"SLessThanEqual", SPV_Integer, []> { let summary = [{ Signed-integer comparison if Operand 1 is less than or equal to Operand 2. @@ -634,7 +805,7 @@ def SPV_SLessThanEqualOp : SPV_LogicalOp<"SLessThanEqual", SPV_Integer, []> { // ----- -def SPV_SelectOp : SPV_Op<"Select", []> { +def SPV_SelectOp : SPV_Op<"Select", [NoSideEffect]> { let summary = [{ Select between two objects. Before version 1.4, results are only computed per component. @@ -691,7 +862,7 @@ def SPV_SelectOp : SPV_Op<"Select", []> { // ----- -def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> { +def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> { let summary = [{ Unsigned-integer comparison if Operand 1 is greater than Operand 2. }]; @@ -724,7 +895,7 @@ def SPV_UGreaterThanOp : SPV_LogicalOp<"UGreaterThan", SPV_Integer, []> { // ----- -def SPV_UGreaterThanEqualOp : SPV_LogicalOp<"UGreaterThanEqual", SPV_Integer, []> { +def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> { let summary = [{ Unsigned-integer comparison if Operand 1 is greater than or equal to Operand 2. @@ -758,7 +929,7 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalOp<"UGreaterThanEqual", SPV_Integer, [] // ----- -def SPV_ULessThanOp : SPV_LogicalOp<"ULessThan", SPV_Integer, []> { +def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> { let summary = [{ Unsigned-integer comparison if Operand 1 is less than Operand 2. }]; @@ -791,7 +962,8 @@ def SPV_ULessThanOp : SPV_LogicalOp<"ULessThan", SPV_Integer, []> { // ----- -def SPV_ULessThanEqualOp : SPV_LogicalOp<"ULessThanEqual", SPV_Integer, []> { +def SPV_ULessThanEqualOp : + SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> { let summary = [{ Unsigned-integer comparison if Operand 1 is less than or equal to Operand 2. diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 408d365c250c..f6ae3e4af2b6 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -367,7 +367,28 @@ static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) { << unaryOp->getOperand(0)->getType(); } -static ParseResult parseBinaryLogicalOp(OpAsmParser &parser, +/// Result of a logical op must be a scalar or vector of boolean type. +static Type getUnaryOpResultType(Builder &builder, Type operandType) { + Type resultType = builder.getIntegerType(1); + if (auto vecType = operandType.dyn_cast()) { + return VectorType::get(vecType.getNumElements(), resultType); + } + return resultType; +} + +static ParseResult parseLogicalUnaryOp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operandInfo; + Type type; + if (parser.parseOperand(operandInfo) || parser.parseColonType(type) || + parser.resolveOperand(operandInfo, type, state.operands)) { + return failure(); + } + state.addTypes(getUnaryOpResultType(parser.getBuilder(), type)); + return success(); +} + +static ParseResult parseLogicalBinaryOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; @@ -375,18 +396,13 @@ static ParseResult parseBinaryLogicalOp(OpAsmParser &parser, parser.resolveOperands(ops, type, result.operands)) { return failure(); } - // Result must be a scalar or vector of boolean type. - Type resultType = parser.getBuilder().getIntegerType(1); - if (auto opsType = type.dyn_cast()) { - resultType = VectorType::get(opsType.getNumElements(), resultType); - } - result.addTypes(resultType); + result.addTypes(getUnaryOpResultType(parser.getBuilder(), type)); return success(); } -static void printBinaryLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { - printer << logicalOp->getName() << ' ' << *logicalOp->getOperand(0) << ", " - << *logicalOp->getOperand(1); +static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) { + printer << logicalOp->getName() << ' '; + printer.printOperands(logicalOp->getOperands()); printer << " : " << logicalOp->getOperand(0)->getType(); } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 8c4b0fa2c58b..d24015a8def6 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -606,6 +606,113 @@ spv.module "Logical" "GLSL450" { // ----- +//===----------------------------------------------------------------------===// +// spv.LogicalAnd +//===----------------------------------------------------------------------===// + +func @logicalBinary(%arg0 : i1, %arg1 : i1, %arg2 : i1) +{ + // CHECK: [[TMP:%.*]] = spv.LogicalAnd {{%.*}}, {{%.*}} : i1 + %0 = spv.LogicalAnd %arg0, %arg1 : i1 + // CHECK: {{%.*}} = spv.LogicalAnd [[TMP]], {{%.*}} : i1 + %1 = spv.LogicalAnd %0, %arg2 : i1 + return +} + +func @logicalBinary2(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) +{ + // CHECK: {{%.*}} = spv.LogicalAnd {{%.*}}, {{%.*}} : vector<4xi1> + %0 = spv.LogicalAnd %arg0, %arg1 : vector<4xi1> + return +} + +// ----- + +func @logicalBinary(%arg0 : i1, %arg1 : i1) +{ + // expected-error @+2 {{expected ':'}} + %0 = spv.LogicalAnd %arg0, %arg1 + return +} + +// ----- + +func @logicalBinary(%arg0 : i1, %arg1 : i1) +{ + // expected-error @+2 {{expected non-function type}} + %0 = spv.LogicalAnd %arg0, %arg1 : + return +} + +// ----- + +func @logicalBinary(%arg0 : i1, %arg1 : i1) +{ + // expected-error @+1 {{custom op 'spv.LogicalAnd' expected 2 operands}} + %0 = spv.LogicalAnd %arg0 : i1 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.LogicalNot +//===----------------------------------------------------------------------===// + +func @logicalUnary(%arg0 : i1, %arg1 : i1) +{ + // CHECK: [[TMP:%.*]] = spv.LogicalNot {{%.*}} : i1 + %0 = spv.LogicalNot %arg0 : i1 + // CHECK: {{%.*}} = spv.LogicalNot [[TMP]] : i1 + %1 = spv.LogicalNot %0 : i1 + return +} + +func @logicalUnary2(%arg0 : vector<4xi1>) +{ + // CHECK: {{%.*}} = spv.LogicalNot {{%.*}} : vector<4xi1> + %0 = spv.LogicalNot %arg0 : vector<4xi1> + return +} + +// ----- + +func @logicalUnary(%arg0 : i1) +{ + // expected-error @+2 {{expected ':'}} + %0 = spv.LogicalNot %arg0 + return +} + +// ----- + +func @logicalUnary(%arg0 : i1) +{ + // expected-error @+2 {{expected non-function type}} + %0 = spv.LogicalNot %arg0 : + return +} + +// ----- + +func @logicalUnary(%arg0 : i1) +{ + // expected-error @+1 {{expected SSA operand}} + %0 = spv.LogicalNot : i1 + return +} + +// ----- + +func @logicalUnary(%arg0 : i32) +{ + // expected-error @+1 {{'spv.LogicalNot' op operand #0 must be 1-bit integer or vector of 1-bit integer values of length 2/3/4, but got 'i32'}} + %0 = spv.LogicalNot %arg0 : i32 + return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.MemoryBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh index 55e2fa0ed9b2..328862b3d2ca 100755 --- a/mlir/utils/spirv/define_inst.sh +++ b/mlir/utils/spirv/define_inst.sh @@ -1,5 +1,4 @@ #!/bin/bash - # Copyright 2019 The MLIR Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,19 +29,21 @@ # ./define_inst.sh LogicalOp OpFOrdEqual set -e -inst_category=$1 +file_name=$1 +inst_category=$2 case $inst_category in Op | ArithmeticOp | LogicalOp | ControlFlowOp | StructureOp) ;; *) - echo "Usage : " $0 " ()*" + echo "Usage : " $0 " ()*" echo " must be one of " \ "(Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp)" exit 1; ;; esac +shift shift current_file="$(readlink -f "$0")" @@ -50,5 +51,8 @@ current_dir="$(dirname "$current_file")" python3 ${current_dir}/gen_spirv_dialect.py \ --op-td-path \ - ${current_dir}/../../include/mlir/Dialect/SPIRV/SPIRV${inst_category}s.td \ + ${current_dir}/../../include/mlir/Dialect/SPIRV/${file_name} \ --inst-category $inst_category --new-inst "$@" + +${current_dir}/define_opcodes.sh "$@" + diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index e505b4054c50..ff096df0116b 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -364,7 +364,22 @@ def map_spec_operand_to_ods_argument(operand): return '{}:${}'.format(arg_type, name) -def get_op_definition(instruction, doc, existing_info, inst_category): +def get_description(text, assembly): + """Generates the description for the given SPIR-V instruction. + + Arguments: + - text: Textual description of the operation as string. + - assembly: Custom Assembly format with example as string. + + Returns: + - A string that corresponds to the description of the Tablegen op. + """ + fmt_str = ('{text}\n\n ### Custom assembly ' 'form\n{assembly}}}];\n') + return fmt_str.format( + text=text, assembly=assembly) + + +def get_op_definition(instruction, doc, existing_info): """Generates the TableGen op definition for the given SPIR-V instruction. Arguments: @@ -379,8 +394,8 @@ def get_op_definition(instruction, doc, existing_info, inst_category): fmt_str = ('def SPV_{opname}Op : ' 'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> ' '{{\n let summary = {summary};\n\n let description = ' - '[{{\n{description}\n\n ### Custom assembly ' - 'form\n{assembly}}}];\n') + '[{{\n{description}}}];\n') + inst_category = existing_info.get('inst_category', 'Op') if inst_category == 'Op': fmt_str +='\n let arguments = (ins{args});\n\n'\ ' let results = (outs{results});\n' @@ -393,7 +408,7 @@ def get_op_definition(instruction, doc, existing_info, inst_category): # Make sure we have ', ' to separate the category arguments from traits category_args = category_args.rstrip(', ') + ', ' - summary, description = doc.split('\n', 1) + summary, text = doc.split('\n', 1) wrapper = textwrap.TextWrapper( width=76, initial_indent=' ', subsequent_indent=' ') @@ -405,10 +420,10 @@ def get_op_definition(instruction, doc, existing_info, inst_category): else: summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary)) - # Wrap description - description = description.split('\n') - description = [wrapper.fill(line) for line in description if line] - description = '\n\n'.join(description) + # Wrap text + text = text.split('\n') + text = [wrapper.fill(line) for line in text if line] + text = '\n\n'.join(text) operands = instruction.get('operands', []) @@ -433,8 +448,8 @@ def get_op_definition(instruction, doc, existing_info, inst_category): # Prepend and append whitespace for formatting arguments = '\n {}\n '.format(arguments) - assembly = existing_info.get('assembly', None) - if assembly is None: + description = existing_info.get('description', None) + if description is None: assembly = '\n ``` {.ebnf}\n'\ ' [TODO]\n'\ ' ```\n\n'\ @@ -442,6 +457,7 @@ def get_op_definition(instruction, doc, existing_info, inst_category): ' ```\n'\ ' [TODO]\n'\ ' ```\n ' + description = get_description(text, assembly) return fmt_str.format( opname=opname, @@ -450,7 +466,6 @@ def get_op_definition(instruction, doc, existing_info, inst_category): traits=existing_info.get('traits', ''), summary=summary, description=description, - assembly=assembly, args=arguments, results=results, extras=existing_info.get('extras', '')) @@ -493,6 +508,14 @@ def extract_td_op_info(op_def): assert len(opname) == 1, 'more than one ops in the same section!' opname = opname[0] + # Get instruction category + inst_category = [ + o[4:] for o in re.findall('SPV_\w+Op', + op_def.split(':', 1)[1]) + ] + assert len(inst_category) <= 1, 'more than one ops in the same section!' + inst_category = inst_category[0] if len(inst_category) == 1 else 'Op' + # Get category_args op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0] opstringname, rest = get_string_between(op_tmpl_params, '"', '"') @@ -501,9 +524,9 @@ def extract_td_op_info(op_def): # Get traits traits, _ = get_string_between(rest, '[', ']') - # Get custom assembly form - assembly, rest = get_string_between(op_def, '### Custom assembly form\n', - '}];\n') + # Get description + description, rest = get_string_between(op_def, 'let description = [{\n', + '}];\n') # Get arguments args, rest = get_string_between(rest, ' let arguments = (ins', ');\n') @@ -518,9 +541,10 @@ def extract_td_op_info(op_def): return { # Prefix with 'Op' to make it consistent with SPIR-V spec 'opname': 'Op{}'.format(opname), + 'inst_category': inst_category, 'category_args': category_args, 'traits': traits, - 'assembly': assembly, + 'description': description, 'arguments': args, 'results': results, 'extras': extras @@ -567,7 +591,7 @@ def update_td_op_definitions(path, instructions, docs, filter_list, inst for inst in instructions if inst['opname'] == opname) op_defs.append( get_op_definition(instruction, docs[opname], - op_info_dict.get(opname, {}), inst_category)) + op_info_dict.get(opname, {}))) # Substitute the old op definitions op_defs = [header] + op_defs + [footer] @@ -622,8 +646,7 @@ if __name__ == '__main__': type=str, default='Op', help='SPIR-V instruction category used for choosing '\ - 'a suitable .td file and TableGen common base '\ - 'class to define this op') + 'the TableGen base class to define this op') args = cli_parser.parse_args()