[mlir][spirv] Add pattern to lower math.copysign
This follows the logic: https://git.musl-libc.org/cgit/musl/tree/src/math/copysignf.c Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D122910
This commit is contained in:
parent
b8652fbcbb
commit
533ec929f6
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "math-to-spirv-pattern"
|
||||
|
@ -30,14 +31,74 @@ using namespace mlir;
|
|||
// normal RewritePattern.
|
||||
|
||||
namespace {
|
||||
/// Converts math.copysign to SPIR-V ops.
|
||||
class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto type = getTypeConverter()->convertType(copySignOp.getType());
|
||||
if (!type)
|
||||
return failure();
|
||||
|
||||
FloatType floatType;
|
||||
if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
|
||||
floatType = scalarType;
|
||||
} else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
|
||||
floatType = vectorType.getElementType().cast<FloatType>();
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = copySignOp.getLoc();
|
||||
int bitwidth = floatType.getWidth();
|
||||
Type intType = rewriter.getIntegerType(bitwidth);
|
||||
|
||||
Value signMask = rewriter.create<spirv::ConstantOp>(
|
||||
loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1))));
|
||||
Value valueMask = rewriter.create<spirv::ConstantOp>(
|
||||
loc, intType,
|
||||
rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)) - 1u));
|
||||
|
||||
if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
|
||||
assert(vectorType.getRank() == 1);
|
||||
int count = vectorType.getNumElements();
|
||||
intType = VectorType::get(count, intType);
|
||||
|
||||
SmallVector<Value> signSplat(count, signMask);
|
||||
signMask =
|
||||
rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
|
||||
|
||||
SmallVector<Value> valueSplat(count, valueMask);
|
||||
valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
|
||||
valueSplat);
|
||||
}
|
||||
|
||||
Value lhsCast =
|
||||
rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
|
||||
Value rhsCast =
|
||||
rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
|
||||
|
||||
Value value = rewriter.create<spirv::BitwiseAndOp>(
|
||||
loc, intType, ValueRange{lhsCast, valueMask});
|
||||
Value sign = rewriter.create<spirv::BitwiseAndOp>(
|
||||
loc, intType, ValueRange{rhsCast, signMask});
|
||||
|
||||
Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
|
||||
ValueRange{value, sign});
|
||||
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts math.expm1 to SPIR-V ops.
|
||||
///
|
||||
/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
|
||||
/// these operations.
|
||||
template <typename ExpOp>
|
||||
class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
|
||||
public:
|
||||
using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
|
||||
struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
|
||||
|
@ -57,9 +118,8 @@ public:
|
|||
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
|
||||
/// these operations.
|
||||
template <typename LogOp>
|
||||
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
|
||||
public:
|
||||
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
|
||||
struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
|
||||
|
@ -83,6 +143,8 @@ public:
|
|||
namespace mlir {
|
||||
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
// Core patterns
|
||||
patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
|
||||
|
||||
// GLSL patterns
|
||||
patterns
|
||||
|
|
43
mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
Normal file
43
mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
Normal file
|
@ -0,0 +1,43 @@
|
|||
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
func @copy_sign_scalar(%value: f32, %sign: f32) -> f32 {
|
||||
%0 = math.copysign %value, %sign : f32
|
||||
return %0: f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @copy_sign_scalar
|
||||
// CHECK-SAME: (%[[VALUE:.+]]: f32, %[[SIGN:.+]]: f32)
|
||||
// CHECK: %[[SMASK:.+]] = spv.Constant -2147483648 : i32
|
||||
// CHECK: %[[VMASK:.+]] = spv.Constant 2147483647 : i32
|
||||
// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : f32 to i32
|
||||
// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : f32 to i32
|
||||
// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i32
|
||||
// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i32
|
||||
// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : i32
|
||||
// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : i32 to f32
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16, Int16], []>, {}> } {
|
||||
|
||||
func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vector<3xf16> {
|
||||
%0 = math.copysign %value, %sign : vector<3xf16>
|
||||
return %0: vector<3xf16>
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @copy_sign_vector
|
||||
// CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
|
||||
// CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
|
||||
// CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
|
||||
// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
|
||||
// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
|
||||
// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
|
||||
// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
|
||||
// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
|
||||
// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SVMASK]] : vector<3xi16>
|
||||
// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
|
||||
// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
|
||||
// CHECK: return %[[RESULT]]
|
Loading…
Reference in a new issue