From d423d4a3387a90a5b6634f087bdef838566008d3 Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Wed, 30 Oct 2019 14:41:26 -0700 Subject: [PATCH] [spirv] Add cast operations This CL added op definitions for a few cast operations: * OpConvertFToU * OpConvertFToS * OpConvertSToF * OpConvertUToF * OpUConvert * OpSConvert * OpFConvert Also moved the definition of spv.Bitcast to the new file. Closes tensorflow/mlir#208 and tensorflow/mlir#174 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/208 from denis0x0D:sandbox/cast_ops 79bc9b37398aafddee6cf6beb301807988fe67f9 PiperOrigin-RevId: 277587891 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 11 +- .../mlir/Dialect/SPIRV/SPIRVCastOps.td | 336 ++++++++++++++++++ mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 57 +-- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 56 +-- .../Dialect/SPIRV/Serialization/cast-ops.mlir | 48 ++- mlir/test/Dialect/SPIRV/ops.mlir | 150 +++++++- mlir/utils/spirv/define_inst.sh | 4 +- 7 files changed, 564 insertions(+), 98 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index a6446e0d8661..6dbf28f877c6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -117,6 +117,13 @@ def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; +def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; +def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; +def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; +def SPV_OC_OpConvertUToF : I32EnumAttrCase<"OpConvertUToF", 112>; +def SPV_OC_OpUConvert : I32EnumAttrCase<"OpUConvert", 113>; +def SPV_OC_OpSConvert : I32EnumAttrCase<"OpSConvert", 114>; +def SPV_OC_OpFConvert : I32EnumAttrCase<"OpFConvert", 115>; def SPV_OC_OpBitcast : I32EnumAttrCase<"OpBitcast", 124>; def SPV_OC_OpFNegate : I32EnumAttrCase<"OpFNegate", 127>; def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>; @@ -192,7 +199,9 @@ def SPV_OpcodeAttr : SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpBitcast, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpConvertFToU, + SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td new file mode 100644 index 000000000000..1e90707008d2 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -0,0 +1,336 @@ +//===-- SPIRVCastOps.td - MLIR SPIR-V Cast Ops -------*- tablegen -*-------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains cast ops for the SPIR-V dialect. It corresponds +// to "3.32.11. Convertion Instructions" of the SPIR-V specification. +// +//===----------------------------------------------------------------------===// + +#ifdef SPIRV_CAST_OPS +#else +#define SPIRV_CAST_OPS + +#ifdef SPIRV_BASE +#else +include "mlir/SPIRV/SPIRVBase.td" +#endif // SPIRV_BASE + +class SPV_CastOp traits = []> : + SPV_Op { + let arguments = (ins + SPV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOf:$result + ); + + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + let verifier = [{ return verifyCastOp(this->getOperation()); }]; +} + +// ----- + +def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { + let summary = "Bit pattern-preserving type conversion."; + + let description = [{ + Result Type must be an OpTypePointer, or a scalar or vector of + numerical-type. + + Operand must have a type of OpTypePointer, or a scalar or vector of + numerical-type. It must be a different type than Result Type. + + If either Result Type or Operand is a pointer, the other must be a + pointer (diverges from the SPIR-V spec). + + If Result Type has a different number of components than Operand, the + total number of bits in Result Type must equal the total number of bits + in Operand. Let L be the type, either Result Type or Operand’s type, + that has the larger number of components. Let S be the other type, with + the smaller number of components. The number of components in L must be + an integer multiple of the number of components in S. The first + component (that is, the only or lowest-numbered component) of S maps to + the first components of L, and so on, up to the last component of S + mapping to the last components of L. Within this mapping, any single + component of S (mapping to multiple components of L) maps its lower- + ordered bits to the lower-numbered components of L. + + ### Custom assembly form + + ``` {.ebnf} + bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.Bitcast %0 : f32 to i32 + %1 = spv.Bitcast %0 : vector<2xf32> to i64 + %1 = spv.Bitcast %0 : !spv.ptr to !spv.ptr + ``` + }]; + + let arguments = (ins + SPV_ScalarOrVectorOrPtr:$operand + ); + + let results = (outs + SPV_ScalarOrVectorOrPtr:$result + ); + + let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; + let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; +} + +// ----- + +def SPV_ConvertFToSOp : SPV_CastOp<"ConvertFToS", SPV_Integer, SPV_Float, []> { + let summary = [{ + Convert value numerically from floating point to signed integer, with + round toward 0.0. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + convert-f-to-s-op ::= ssa-id `=` `spv.ConvertFToSOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertFToS %0 : f32 to i32 + %3 = spv.ConvertFToS %2 : vector<3xf32> to vector<3xi32> + ``` + }]; +} + +// ----- + +def SPV_ConvertFToUOp : SPV_CastOp<"ConvertFToU", SPV_Integer, SPV_Float, []> { + let summary = [{ + Convert value numerically from floating point to unsigned integer, with + round toward 0.0. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + convert-f-to-u-op ::= ssa-id `=` `spv.ConvertFToUOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertFToU %0 : f32 to i32 + %3 = spv.ConvertFToU %2 : vector<3xf32> to vector<3xi32> + ``` + }]; +} + +// ----- + +def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> { + let summary = [{ + Convert value numerically from signed integer to floating point. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Signed Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + convert-s-to-f-op ::= ssa-id `=` `spv.ConvertSToFOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertSToF %0 : i32 to f32 + %3 = spv.ConvertSToF %2 : vector<3xi32> to vector<3xf32> + ``` + }]; +} + +// ----- + +def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> { + let summary = [{ + Convert value numerically from unsigned integer to floating point. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Unsigned Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + convert-u-to-f-op ::= ssa-id `=` `spv.ConvertUToFOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.ConvertUToF %0 : i32 to f32 + %3 = spv.ConvertUToF %2 : vector<3xi32> to vector<3xf32> + ``` + }]; +} + +// ----- + +def SPV_FConvertOp : SPV_CastOp<"FConvert", SPV_Float, SPV_Float, []> { + let summary = [{ + Convert value numerically from one floating-point width to another + width. + }]; + + let description = [{ + Result Type must be a scalar or vector of floating-point type. + + Float Value must be a scalar or vector of floating-point type. It must + have the same number of components as Result Type. The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + f-convert-op ::= ssa-id `=` `spv.FConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.FConvertOp %0 : f32 to f64 + %3 = spv.FConvertOp %2 : vector<3xf32> to vector<3xf64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +// ----- + +def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> { + let summary = [{ + Convert signed width. This is either a truncate or a sign extend. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type. + + Signed Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + s-convert-op ::= ssa-id `=` `spv.SConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.SConvertOp %0 : i32 to i64 + %3 = spv.SConvertOp %2 : vector<3xi32> to vector<3xi64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +// ----- + +def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> { + let summary = [{ + Convert unsigned width. This is either a truncate or a zero extend. + }]; + + let description = [{ + Result Type must be a scalar or vector of integer type, whose Signedness + operand is 0. + + Unsigned Value must be a scalar or vector of integer type. It must have + the same number of components as Result Type. The component width + cannot equal the component width in Result Type. + + Results are computed per component. + + ### Custom assembly form + + ``` {.ebnf} + u-convert-op ::= ssa-id `=` `spv.UConvertOp` ssa-use + `:` operand-type `to` result-type + ``` + + For example: + + ``` + %1 = spv.UConvertOp %0 : i32 to i64 + %3 = spv.UConvertOp %2 : vector<3xi32> to vector<3xi64> + ``` + }]; + + let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; +} + +#endif // SPIRV_CAST_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 1ffb352b5616..4e310c67107e 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -45,6 +45,11 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td" include "mlir/Dialect/SPIRV/SPIRVBitOps.td" #endif // SPIRV_BIT_OPS +#ifdef SPIRV_CAST_OPS +#else +include "mlir/Dialect/SPIRV/SPIRVCastOps.td" +#endif // SPIRV_CAST_OPS + #ifdef SPIRV_CONTROLFLOW_OPS #else include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" @@ -133,58 +138,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { // ----- -def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { - let summary = "Bit pattern-preserving type conversion."; - - let description = [{ - Result Type must be an OpTypePointer, or a scalar or vector of - numerical-type. - - Operand must have a type of OpTypePointer, or a scalar or vector of - numerical-type. It must be a different type than Result Type. - - If either Result Type or Operand is a pointer, the other must be a - pointer (diverges from the SPIR-V spec). - - If Result Type has a different number of components than Operand, the - total number of bits in Result Type must equal the total number of bits - in Operand. Let L be the type, either Result Type or Operand’s type, - that has the larger number of components. Let S be the other type, with - the smaller number of components. The number of components in L must be - an integer multiple of the number of components in S. The first - component (that is, the only or lowest-numbered component) of S maps to - the first components of L, and so on, up to the last component of S - mapping to the last components of L. Within this mapping, any single - component of S (mapping to multiple components of L) maps its lower- - ordered bits to the lower-numbered components of L. - - ### Custom assembly form - - ``` {.ebnf} - bitcast-op ::= ssa-id `=` `spv.Bitcast` ssa-use - `from` operand-type `to` result-type - ``` - - For example: - - ``` - %1 = spv.Bitcast %0 from f32 to i32 - %1 = spv.Bitcast %0 from vector<2xf32> to i64 - %1 = spv.Bitcast %0 from !spv.ptr to !spv.ptr - ``` - }]; - - let arguments = (ins - SPV_ScalarOrVectorOrPtr:$operand - ); - - let results = (outs - SPV_ScalarOrVectorOrPtr:$result - ); -} - -// ----- - def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let summary = "Extract a part of a composite object."; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index c47d4b870278..2c6dfd6a9961 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -154,6 +154,40 @@ printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer, elidedAttrs.push_back(spirv::attributeName()); } +static LogicalResult verifyCastOp(Operation *op, + bool requireSameBitWidth = true) { + Type operandType = op->getOperand(0)->getType(); + Type resultType = op->getResult(0)->getType(); + + // ODS checks that result type and operand type have the same shape. + if (auto vectorType = operandType.dyn_cast()) { + operandType = vectorType.getElementType(); + resultType = resultType.cast().getElementType(); + } + + auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); + auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); + auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; + + if (requireSameBitWidth) { + if (!isSameBitWidth) { + return op->emitOpError( + "expected the same bit widths for operand type and result " + "type, but provided ") + << operandType << " and " << resultType; + } + return success(); + } + + if (isSameBitWidth) { + return op->emitOpError( + "expected the different bit widths for operand type and result " + "type, but provided ") + << operandType << " and " << resultType; + } + return success(); +} + template static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) { // ODS checks for attributes values. Just need to verify that if the @@ -636,28 +670,6 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) { // spv.BitcastOp //===----------------------------------------------------------------------===// -static ParseResult parseBitcastOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType operandInfo; - Type operandType, resultType; - if (parser.parseOperand(operandInfo) || parser.parseKeyword("from") || - parser.parseType(operandType) || parser.parseKeyword("to") || - parser.parseType(resultType)) { - return failure(); - } - if (parser.resolveOperands(operandInfo, operandType, state.operands)) { - return failure(); - } - state.addTypes(resultType); - return success(); -} - -static void print(spirv::BitcastOp bitcastOp, OpAsmPrinter &printer) { - printer << spirv::BitcastOp::getOperationName() << ' '; - printer.printOperand(bitcastOp.operand()); - printer << " from " << bitcastOp.operand()->getType() << " to " - << bitcastOp.result()->getType(); -} - static LogicalResult verify(spirv::BitcastOp bitcastOp) { // TODO: The SPIR-V spec validation rules are different for different // versions. diff --git a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir index 5e488f7dc775..f3f8ddf65a1f 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cast-ops.mlir @@ -1,9 +1,49 @@ -// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s spv.module "Logical" "GLSL450" { - func @fmul(%arg0 : f32) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32 - %0 = spv.Bitcast %arg0 from f32 to i32 + func @bit_cast(%arg0 : f32) { + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32 + %0 = spv.Bitcast %arg0 : f32 to i32 spv.Return } } + +// ----- + +spv.module "Logical" "GLSL450" { + func @convert_f_to_s(%arg0 : f32) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : f32 to i32 + %0 = spv.ConvertFToS %arg0 : f32 to i32 + spv.ReturnValue %0 : i32 + } + func @convert_f_to_u(%arg0 : f32) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f32 to i32 + %0 = spv.ConvertFToU %arg0 : f32 to i32 + spv.ReturnValue %0 : i32 + } + func @convert_s_to_f(%arg0 : i32) -> f32 { + // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i32 to f32 + %0 = spv.ConvertSToF %arg0 : i32 to f32 + spv.ReturnValue %0 : f32 + } + func @convert_u_to_f(%arg0 : i32) -> f32 { + // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i32 to f32 + %0 = spv.ConvertUToF %arg0 : i32 to f32 + spv.ReturnValue %0 : f32 + } + func @f_convert(%arg0 : f32) -> f64 { + // CHECK: {{%.*}} = spv.FConvert {{%.*}} : f32 to f64 + %0 = spv.FConvert %arg0 : f32 to f64 + spv.ReturnValue %0 : f64 + } + func @s_convert(%arg0 : i32) -> i64 { + // CHECK: {{%.*}} = spv.SConvert {{%.*}} : i32 to i64 + %0 = spv.SConvert %arg0 : i32 to i64 + spv.ReturnValue %0 : i64 + } + func @u_convert(%arg0 : i32) -> i64 { + // CHECK: {{%.*}} = spv.UConvert {{%.*}} : i32 to i64 + %0 = spv.UConvert %arg0 : i32 to i64 + spv.ReturnValue %0 : i64 + } +} diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 639e10faed97..5d763acfd716 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -125,38 +125,38 @@ func @access_chain_invalid_accessing_type(%index0 : i32) -> () { //===----------------------------------------------------------------------===// func @cast1(%arg0 : f32) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from f32 to i32 - %0 = spv.Bitcast %arg0 from f32 to i32 + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32 + %0 = spv.Bitcast %arg0 : f32 to i32 return } func @cast2(%arg0 : vector<2xf32>) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to vector<2xi32> - %0 = spv.Bitcast %arg0 from vector<2xf32> to vector<2xi32> + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : vector<2xf32> to vector<2xi32> + %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32> return } func @cast3(%arg0 : vector<2xf32>) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<2xf32> to i64 - %0 = spv.Bitcast %arg0 from vector<2xf32> to i64 + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : vector<2xf32> to i64 + %0 = spv.Bitcast %arg0 : vector<2xf32> to i64 return } func @cast4(%arg0 : !spv.ptr) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr to !spv.ptr - %0 = spv.Bitcast %arg0 from !spv.ptr to !spv.ptr + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : !spv.ptr to !spv.ptr + %0 = spv.Bitcast %arg0 : !spv.ptr to !spv.ptr return } func @cast5(%arg0 : !spv.ptr) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from !spv.ptr to !spv.ptr, Function> - %0 = spv.Bitcast %arg0 from !spv.ptr to !spv.ptr, Function> + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : !spv.ptr to !spv.ptr, Function> + %0 = spv.Bitcast %arg0 : !spv.ptr to !spv.ptr, Function> return } func @cast6(%arg0 : vector<4xf32>) { - // CHECK: {{%.*}} = spv.Bitcast {{%.*}} from vector<4xf32> to vector<2xi64> - %0 = spv.Bitcast %arg0 from vector<4xf32> to vector<2xi64> + // CHECK: {{%.*}} = spv.Bitcast {{%.*}} : vector<4xf32> to vector<2xi64> + %0 = spv.Bitcast %arg0 : vector<4xf32> to vector<2xi64> return } @@ -164,7 +164,7 @@ func @cast6(%arg0 : vector<4xf32>) { func @cast1(%arg0 : f32) { // expected-error @+1 {{result type must be different from operand type}} - %0 = spv.Bitcast %arg0 from f32 to f32 + %0 = spv.Bitcast %arg0 : f32 to f32 return } @@ -172,7 +172,7 @@ func @cast1(%arg0 : f32) { func @cast1(%arg0 : f32) { // expected-error @+1 {{mismatch in result type bitwidth 64 and operand type bitwidth 32}} - %0 = spv.Bitcast %arg0 from f32 to i64 + %0 = spv.Bitcast %arg0 : f32 to i64 return } @@ -180,7 +180,7 @@ func @cast1(%arg0 : f32) { func @cast1(%arg0 : vector<2xf32>) { // expected-error @+1 {{mismatch in result type bitwidth 96 and operand type bitwidth 64}} - %0 = spv.Bitcast %arg0 from vector<2xf32> to vector<3xf32> + %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<3xf32> return } @@ -188,7 +188,7 @@ func @cast1(%arg0 : vector<2xf32>) { func @cast3(%arg0 : !spv.ptr) { // expected-error @+1 {{unhandled bit cast conversion from pointer type to non-pointer type}} - %0 = spv.Bitcast %arg0 from !spv.ptr to i64 + %0 = spv.Bitcast %arg0 : !spv.ptr to i64 return } @@ -196,7 +196,7 @@ func @cast3(%arg0 : !spv.ptr) { func @cast3(%arg0 : i64) { // expected-error @+1 {{unhandled bit cast conversion from non-pointer type to pointer type}} - %0 = spv.Bitcast %arg0 from i64 to !spv.ptr + %0 = spv.Bitcast %arg0 : i64 to !spv.ptr return } @@ -366,6 +366,122 @@ func @control_barrier_1() -> () { // ----- +//===----------------------------------------------------------------------===// +// spv.ConvertFToS +//===----------------------------------------------------------------------===// + +func @convert_f_to_s_scalar(%arg0 : f32) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToS {{%.*}} : f32 to i32 + %0 = spv.ConvertFToS %arg0 : f32 to i32 + spv.ReturnValue %0 : i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.ConvertFToU +//===----------------------------------------------------------------------===// + +func @convert_f_to_u_scalar(%arg0 : f32) -> i32 { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : f32 to i32 + %0 = spv.ConvertFToU %arg0 : f32 to i32 + spv.ReturnValue %0 : i32 +} + +// ----- + +func @convert_f_to_u_vector(%arg0 : vector<3xf32>) -> vector<3xi32> { + // CHECK: {{%.*}} = spv.ConvertFToU {{%.*}} : vector<3xf32> to vector<3xi32> + %0 = spv.ConvertFToU %arg0 : vector<3xf32> to vector<3xi32> + spv.ReturnValue %0 : vector<3xi32> +} + +// ----- + +func @convert_f_to_u_scalar_invalid(%arg0 : f16) -> i32 { + // expected-error @+1 {{expected the same bit widths for operand type and result type, but provided 'f16' and 'i32'}} + %0 = spv.ConvertFToU %arg0 : f16 to i32 + spv.ReturnValue %0 : i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.ConvertSToF +//===----------------------------------------------------------------------===// + +func @convert_s_to_f_scalar(%arg0 : i32) -> f32 { + // CHECK: {{%.*}} = spv.ConvertSToF {{%.*}} : i32 to f32 + %0 = spv.ConvertSToF %arg0 : i32 to f32 + spv.ReturnValue %0 : f32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.ConvertUToF +//===----------------------------------------------------------------------===// + +func @convert_u_to_f_scalar(%arg0 : i32) -> f32 { + // CHECK: {{%.*}} = spv.ConvertUToF {{%.*}} : i32 to f32 + %0 = spv.ConvertUToF %arg0 : i32 to f32 + spv.ReturnValue %0 : f32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.FConvert +//===----------------------------------------------------------------------===// + +func @f_convert_scalar(%arg0 : f32) -> f64 { + // CHECK: {{%.*}} = spv.FConvert {{%.*}} : f32 to f64 + %0 = spv.FConvert %arg0 : f32 to f64 + spv.ReturnValue %0 : f64 +} + +// ----- + +func @f_convert_vector(%arg0 : vector<3xf32>) -> vector<3xf64> { + // CHECK: {{%.*}} = spv.FConvert {{%.*}} : vector<3xf32> to vector<3xf64> + %0 = spv.FConvert %arg0 : vector<3xf32> to vector<3xf64> + spv.ReturnValue %0 : vector<3xf64> +} + +// ----- + +func @f_convert_vector(%arg0 : f32) -> f32 { + // expected-error @+1 {{expected the different bit widths for operand type and result type, but provided 'f32' and 'f32'}} + %0 = spv.FConvert %arg0 : f32 to f32 + spv.ReturnValue %0 : f32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.SConvert +//===----------------------------------------------------------------------===// + +func @s_convert_scalar(%arg0 : i32) -> i64 { + // CHECK: {{%.*}} = spv.SConvert {{%.*}} : i32 to i64 + %0 = spv.SConvert %arg0 : i32 to i64 + spv.ReturnValue %0 : i64 +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.UConvert +//===----------------------------------------------------------------------===// + +func @u_convert_scalar(%arg0 : i32) -> i64 { + // CHECK: {{%.*}} = spv.UConvert {{%.*}} : i32 to i64 + %0 = spv.UConvert %arg0 : i32 to i64 + spv.ReturnValue %0 : i64 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.ExecutionMode //===----------------------------------------------------------------------===// diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh index c45341ad4837..3508c4f9b4f8 100755 --- a/mlir/utils/spirv/define_inst.sh +++ b/mlir/utils/spirv/define_inst.sh @@ -35,13 +35,13 @@ file_name=$1 inst_category=$2 case $inst_category in - Op | ArithmeticOp | LogicalOp | ControlFlowOp | StructureOp) + Op | ArithmeticOp | LogicalOp | CastOp | ControlFlowOp | StructureOp) ;; *) echo "Usage : " $0 " ()*" echo " is the file name of MLIR SPIR-V op definitions spec" echo " must be one of " \ - "(Op|ArithmeticOp|LogicalOp|ControlFlowOp|StructureOp)" + "(Op|ArithmeticOp|LogicalOp|CastOp|ControlFlowOp|StructureOp)" exit 1; ;; esac