[mlir][spirv] Add pattern to lower math.copysign

This follows the logic:
https://git.musl-libc.org/cgit/musl/tree/src/math/copysignf.c

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D122910
This commit is contained in:
Lei Zhang 2022-04-01 12:02:43 -04:00
parent b8652fbcbb
commit 533ec929f6
2 changed files with 111 additions and 6 deletions

View file

@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
@ -30,14 +31,74 @@ using namespace mlir;
// normal RewritePattern.
namespace {
/// Converts math.copysign to SPIR-V ops.
class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = getTypeConverter()->convertType(copySignOp.getType());
if (!type)
return failure();
FloatType floatType;
if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
floatType = scalarType;
} else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
floatType = vectorType.getElementType().cast<FloatType>();
} else {
return failure();
}
Location loc = copySignOp.getLoc();
int bitwidth = floatType.getWidth();
Type intType = rewriter.getIntegerType(bitwidth);
Value signMask = rewriter.create<spirv::ConstantOp>(
loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1))));
Value valueMask = rewriter.create<spirv::ConstantOp>(
loc, intType,
rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)) - 1u));
if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
intType = VectorType::get(count, intType);
SmallVector<Value> signSplat(count, signMask);
signMask =
rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
SmallVector<Value> valueSplat(count, valueMask);
valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
valueSplat);
}
Value lhsCast =
rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
Value rhsCast =
rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
Value value = rewriter.create<spirv::BitwiseAndOp>(
loc, intType, ValueRange{lhsCast, valueMask});
Value sign = rewriter.create<spirv::BitwiseAndOp>(
loc, intType, ValueRange{rhsCast, signMask});
Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
ValueRange{value, sign});
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
return success();
}
};
/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
/// these operations.
template <typename ExpOp>
class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
public:
using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
@ -57,9 +118,8 @@ public:
/// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
/// these operations.
template <typename LogOp>
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
public:
using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
@ -83,6 +143,8 @@ public:
namespace mlir {
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
// GLSL patterns
patterns

View file

@ -0,0 +1,43 @@
// RUN: mlir-opt -split-input-file -convert-math-to-spirv -verify-diagnostics %s -o - | FileCheck %s
func @copy_sign_scalar(%value: f32, %sign: f32) -> f32 {
%0 = math.copysign %value, %sign : f32
return %0: f32
}
// CHECK-LABEL: func @copy_sign_scalar
// CHECK-SAME: (%[[VALUE:.+]]: f32, %[[SIGN:.+]]: f32)
// CHECK: %[[SMASK:.+]] = spv.Constant -2147483648 : i32
// CHECK: %[[VMASK:.+]] = spv.Constant 2147483647 : i32
// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : f32 to i32
// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : f32 to i32
// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VMASK]] : i32
// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SMASK]] : i32
// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : i32
// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : i32 to f32
// CHECK: return %[[RESULT]]
// -----
module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16, Int16], []>, {}> } {
func @copy_sign_vector(%value: vector<3xf16>, %sign: vector<3xf16>) -> vector<3xf16> {
%0 = math.copysign %value, %sign : vector<3xf16>
return %0: vector<3xf16>
}
}
// CHECK-LABEL: func @copy_sign_vector
// CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
// CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
// CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
// CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
// CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
// CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
// CHECK: %[[SAND:.+]] = spv.BitwiseAnd %[[SCAST]], %[[SVMASK]] : vector<3xi16>
// CHECK: %[[OR:.+]] = spv.BitwiseOr %[[VAND]], %[[SAND]] : vector<3xi16>
// CHECK: %[[RESULT:.+]] = spv.Bitcast %[[OR]] : vector<3xi16> to vector<3xf16>
// CHECK: return %[[RESULT]]