[mlir][spirv] Convert math.ctlz to spv.GLSL.FindUMsb

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D127582
This commit is contained in:
Lei Zhang 2022-06-13 13:01:53 -04:00
parent f1c84d0ff0
commit cc020a2236
6 changed files with 174 additions and 3 deletions

View file

@ -1221,4 +1221,20 @@ def SPV_GLSLFMixOp :
let hasVerifier = 0;
}
// GLSL extended instruction FindUMsb (opcode 75), modeled as a unary
// arithmetic op. The SPV_Int32 constraint restricts operands and results to
// 32-bit integer scalars or vectors, matching the limitation stated in the
// description below.
def SPV_GLSLFindUMsbOp : SPV_GLSLUnaryArithmeticOp<"FindUMsb", 75, SPV_Int32> {
let summary = "Unsigned-integer most-significant bit";
let description = [{
Results in the bit number of the most-significant 1-bit in the binary
representation of Value. If Value is 0, the result is -1.
Result Type and the type of Value must both be integer scalar or
integer vector types. Result Type and operand types must have the
same number of components with the same component width. Results are
computed per component.
This instruction is currently limited to 32-bit width components.
}];
}
#endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS

View file

@ -16,12 +16,35 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Materializes `value` as a spirv::ConstantOp of the given `type`, which must
/// be either i32 or a vector with i32 elements; for vectors the value is
/// replicated across every element. Yields a null Value for any other type.
static Value getScalarOrVectorI32Constant(Type type, int value,
                                          OpBuilder &builder, Location loc) {
  // Scalar case: a plain i32 integer attribute.
  if (type.isInteger(32)) {
    auto scalarAttr = builder.getI32IntegerAttr(value);
    return builder.create<spirv::ConstantOp>(loc, type, scalarAttr);
  }

  // Anything that is not a vector of i32 at this point is unsupported.
  auto vecTy = type.dyn_cast<VectorType>();
  if (!vecTy || !vecTy.getElementType().isInteger(32))
    return nullptr;

  // Vector case: one copy of `value` per element.
  SmallVector<int> splat(vecTy.getNumElements(), value);
  auto vectorAttr = builder.getI32VectorAttr(splat);
  return builder.create<spirv::ConstantOp>(loc, type, vectorAttr);
}
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@ -92,6 +115,42 @@ class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
}
};
/// Lowers math.ctlz to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for counting leading zeros. If the
/// Shader capability is supported, GLSL FindUMsb can be leveraged instead: the
/// result is computed as 31 minus the value FindUMsb produces.
class CountLeadingZerosPattern final
    : public OpConversionPattern<math::CountLeadingZerosOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = getTypeConverter()->convertType(countOp.getType());
    if (!resultType)
      return failure();

    // Only 32-bit integers (scalar, or as vector elements) are handled for
    // now; any other type leaves the width at zero and is rejected below.
    unsigned width = 0;
    if (auto vecTy = resultType.dyn_cast<VectorType>())
      width = vecTy.getElementTypeBitWidth();
    else if (resultType.isa<IntegerType>())
      width = resultType.getIntOrFloatBitWidth();
    if (width != 32)
      return failure();

    Location loc = countOp.getLoc();
    // The constant is created ahead of FindUMsb so the emitted op order stays
    // constant, FindUMsb, ISub.
    Value thirtyOne =
        getScalarOrVectorI32Constant(resultType, 31, rewriter, loc);
    Value msbIndex =
        rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());

    // FindUMsb indexes bits from the least significant end, so the number of
    // leading zeros is 31 - index.
    rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, thirtyOne, msbIndex);
    return success();
  }
};
/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for exp(x)-1. Explicitly lower to
@ -148,7 +207,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// GLSL patterns
patterns
.add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
.add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
ExpM1OpPattern<spirv::GLSLExpOp>,
spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,

View file

@ -36,6 +36,17 @@ void ConvertMathToSPIRVPass::runOnOperation() {
SPIRVTypeConverter typeConverter(targetAttr);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
// in patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return Optional<Value>(cast.getResult(0));
};
typeConverter.addSourceMaterialization(addUnrealizedCast);
typeConverter.addTargetMaterialization(addUnrealizedCast);
target->addLegalOp<UnrealizedConversionCastOp>();
RewritePatternSet patterns(context);
populateMathToSPIRVPatterns(typeConverter, patterns);

View file

@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>> } {
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, #spv.resource_limits<>>
} {
// CHECK-LABEL: @float32_unary_scalar
func.func @float32_unary_scalar(%arg0: f32) {
@ -91,4 +93,56 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>,
return
}
// Scalar i32 case: math.ctlz is expected to lower to a constant 31, a
// FindUMsb on the input, and an ISub of the two, in that order.
// CHECK-LABEL: @ctlz_scalar
// CHECK-SAME: (%[[VAL:.+]]: i32)
func.func @ctlz_scalar(%val: i32) -> i32 {
// CHECK: %[[V31:.+]] = spv.Constant 31 : i32
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32
// CHECK: return %[[SUB]]
%0 = math.ctlz %val : i32
return %0 : i32
}
// Single-element vector case: only the presence of the FindUMsb and ISub ops
// is verified, not their operand or result types.
// NOTE(review): presumably the types are left unchecked because
// vector<1xi32> is converted to a scalar -- confirm against the type converter.
// CHECK-LABEL: @ctlz_vector1
func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> {
// CHECK: spv.GLSL.FindUMsb
// CHECK: spv.ISub
%0 = math.ctlz %val : vector<1xi32>
return %0 : vector<1xi32>
}
// Two-element vector case: same lowering as the scalar one, with the constant
// 31 expected as a dense splat of the vector type.
// CHECK-LABEL: @ctlz_vector2
// CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>)
func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32>
// CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32>
// CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32>
// CHECK: return %[[SUB]]
%0 = math.ctlz %val : vector<2xi32>
return %0 : vector<2xi32>
}
} // end module
// -----
// Negative tests: the target environment enables Int64 and Int16, and the
// ctlz inputs are 64-bit and 16-bit integers. math.ctlz is expected to remain
// in the output for both functions.
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader, Int64, Int16], []>, #spv.resource_limits<>>
} {
// 64-bit scalar input.
// CHECK-LABEL: @ctlz_scalar
func.func @ctlz_scalar(%val: i64) -> i64 {
// CHECK: math.ctlz
%0 = math.ctlz %val : i64
return %0 : i64
}
// Vector input with 16-bit elements.
// CHECK-LABEL: @ctlz_vector2
func.func @ctlz_vector2(%val: vector<2xi16>) -> vector<2xi16> {
// CHECK: math.ctlz
%0 = math.ctlz %val : vector<2xi16>
return %0 : vector<2xi16>
}
} // end module

View file

@ -494,10 +494,34 @@ func.func @fmix(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> () {
return
}
// -----
// Verifies that spv.GLSL.FMix with three vector<3xf32> operands appears in
// the output in the expected textual form.
func.func @fmix_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) -> () {
// CHECK: {{%.*}} = spv.GLSL.FMix {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32>, {{%.*}} : vector<3xf32> -> vector<3xf32>
%0 = spv.GLSL.FMix %arg0 : vector<3xf32>, %arg1 : vector<3xf32>, %arg2 : vector<3xf32> -> vector<3xf32>
return
}
// -----
//===----------------------------------------------------------------------===//
// spv.GLSL.FindUMsb
//===----------------------------------------------------------------------===//
// Verifies that spv.GLSL.FindUMsb with a scalar i32 operand appears in the
// output in the expected textual form.
func.func @findumsb(%arg0 : i32) -> () {
// CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
%2 = spv.GLSL.FindUMsb %arg0 : i32
return
}
// Same as the scalar test, but with a vector<3xi32> operand.
func.func @findumsb_vector(%arg0 : vector<3xi32>) -> () {
// CHECK: spv.GLSL.FindUMsb {{%.*}} : vector<3xi32>
%2 = spv.GLSL.FindUMsb %arg0 : vector<3xi32>
return
}
// -----
// Negative test: an i64 operand is expected to be rejected with an
// operand-type diagnostic on the op.
func.func @findumsb(%arg0 : i64) -> () {
// expected-error @+1 {{operand #0 must be Int32 or vector of Int32}}
%2 = spv.GLSL.FindUMsb %arg0 : i64
return
}

View file

@ -75,4 +75,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
spv.Return
}
// Verifies that spv.GLSL.FindUMsb on an i32 argument appears in the output
// when used inside an spv.func.
spv.func @findumsb(%arg0 : i32) "None" {
// CHECK: spv.GLSL.FindUMsb {{%.*}} : i32
%2 = spv.GLSL.FindUMsb %arg0 : i32
spv.Return
}
}