[AArch64] Lowering and legalization of strict FP16

For strict FP16 to work correctly needs some changes in lowering and
legalization:
 * SelectionDAGLegalize::PromoteNode was missing handling for some
   strict fp opcodes.
 * Some of the custom lowering of strict fp operations needed to be
   adjusted to work with FP16.
 * Custom lowering needed to be added for round-to-int operations.

With this, and the previous patches for the rest of the strict fp
isel, we can set IsStrictFPEnabled = true.

Differential Revision: https://reviews.llvm.org/D115620
This commit is contained in:
John Brawn 2021-12-09 12:39:29 +00:00
parent d43d9e1d5c
commit 12c1022679
4 changed files with 1289 additions and 65 deletions

View file

@ -4714,6 +4714,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Results.push_back(DAG.getNode(ISD::FP_ROUND, dl, OVT,
Tmp3, DAG.getIntPtrConstant(0, dl)));
break;
case ISD::STRICT_FADD:
case ISD::STRICT_FSUB:
case ISD::STRICT_FMUL:
case ISD::STRICT_FDIV:
case ISD::STRICT_FMINNUM:
case ISD::STRICT_FMAXNUM:
case ISD::STRICT_FREM:
case ISD::STRICT_FPOW:
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
@ -4738,6 +4744,22 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3),
DAG.getIntPtrConstant(0, dl)));
break;
case ISD::STRICT_FMA:
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
{Node->getOperand(0), Node->getOperand(1)});
Tmp2 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
{Node->getOperand(0), Node->getOperand(2)});
Tmp3 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
{Node->getOperand(0), Node->getOperand(3)});
Tmp4 = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Tmp1.getValue(1),
Tmp2.getValue(1), Tmp3.getValue(1));
Tmp4 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp4, Tmp1, Tmp2, Tmp3});
Tmp4 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp4.getValue(1), Tmp4, DAG.getIntPtrConstant(0, dl)});
Results.push_back(Tmp4);
Results.push_back(Tmp4.getValue(1));
break;
case ISD::FCOPYSIGN:
case ISD::FPOWI: {
Tmp1 = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Node->getOperand(0));
@ -4754,6 +4776,16 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Tmp3, DAG.getIntPtrConstant(isTrunc, dl)));
break;
}
case ISD::STRICT_FPOWI:
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
{Node->getOperand(0), Node->getOperand(1)});
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},
{Tmp1.getValue(1), Tmp1, Node->getOperand(2)});
Tmp3 = DAG.getNode(ISD::STRICT_FP_ROUND, dl, {OVT, MVT::Other},
{Tmp2.getValue(1), Tmp2, DAG.getIntPtrConstant(0, dl)});
Results.push_back(Tmp3);
Results.push_back(Tmp3.getValue(1));
break;
case ISD::FFLOOR:
case ISD::FCEIL:
case ISD::FRINT:
@ -4778,12 +4810,19 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
break;
case ISD::STRICT_FFLOOR:
case ISD::STRICT_FCEIL:
case ISD::STRICT_FRINT:
case ISD::STRICT_FNEARBYINT:
case ISD::STRICT_FROUND:
case ISD::STRICT_FROUNDEVEN:
case ISD::STRICT_FTRUNC:
case ISD::STRICT_FSQRT:
case ISD::STRICT_FSIN:
case ISD::STRICT_FCOS:
case ISD::STRICT_FLOG:
case ISD::STRICT_FLOG2:
case ISD::STRICT_FLOG10:
case ISD::STRICT_FEXP:
case ISD::STRICT_FEXP2:
Tmp1 = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NVT, MVT::Other},
{Node->getOperand(0), Node->getOperand(1)});
Tmp2 = DAG.getNode(Node->getOpcode(), dl, {NVT, MVT::Other},

View file

@ -539,64 +539,41 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
else
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
setOperationAction(ISD::FREM, MVT::f16, Promote);
setOperationAction(ISD::FREM, MVT::v4f16, Expand);
setOperationAction(ISD::FREM, MVT::v8f16, Expand);
setOperationAction(ISD::FPOW, MVT::f16, Promote);
setOperationAction(ISD::FPOW, MVT::v4f16, Expand);
setOperationAction(ISD::FPOW, MVT::v8f16, Expand);
setOperationAction(ISD::FPOWI, MVT::f16, Promote);
setOperationAction(ISD::FPOWI, MVT::v4f16, Expand);
setOperationAction(ISD::FPOWI, MVT::v8f16, Expand);
setOperationAction(ISD::FCOS, MVT::f16, Promote);
setOperationAction(ISD::FCOS, MVT::v4f16, Expand);
setOperationAction(ISD::FCOS, MVT::v8f16, Expand);
setOperationAction(ISD::FSIN, MVT::f16, Promote);
setOperationAction(ISD::FSIN, MVT::v4f16, Expand);
setOperationAction(ISD::FSIN, MVT::v8f16, Expand);
setOperationAction(ISD::FSINCOS, MVT::f16, Promote);
setOperationAction(ISD::FSINCOS, MVT::v4f16, Expand);
setOperationAction(ISD::FSINCOS, MVT::v8f16, Expand);
setOperationAction(ISD::FEXP, MVT::f16, Promote);
setOperationAction(ISD::FEXP, MVT::v4f16, Expand);
setOperationAction(ISD::FEXP, MVT::v8f16, Expand);
setOperationAction(ISD::FEXP2, MVT::f16, Promote);
setOperationAction(ISD::FEXP2, MVT::v4f16, Expand);
setOperationAction(ISD::FEXP2, MVT::v8f16, Expand);
setOperationAction(ISD::FLOG, MVT::f16, Promote);
setOperationAction(ISD::FLOG, MVT::v4f16, Expand);
setOperationAction(ISD::FLOG, MVT::v8f16, Expand);
setOperationAction(ISD::FLOG2, MVT::f16, Promote);
setOperationAction(ISD::FLOG2, MVT::v4f16, Expand);
setOperationAction(ISD::FLOG2, MVT::v8f16, Expand);
setOperationAction(ISD::FLOG10, MVT::f16, Promote);
setOperationAction(ISD::FLOG10, MVT::v4f16, Expand);
setOperationAction(ISD::FLOG10, MVT::v8f16, Expand);
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
ISD::FEXP, ISD::FEXP2, ISD::FLOG,
ISD::FLOG2, ISD::FLOG10, ISD::STRICT_FREM,
ISD::STRICT_FPOW, ISD::STRICT_FPOWI, ISD::STRICT_FCOS,
ISD::STRICT_FSIN, ISD::STRICT_FEXP, ISD::STRICT_FEXP2,
ISD::STRICT_FLOG, ISD::STRICT_FLOG2, ISD::STRICT_FLOG10}) {
setOperationAction(Op, MVT::f16, Promote);
setOperationAction(Op, MVT::v4f16, Expand);
setOperationAction(Op, MVT::v8f16, Expand);
}
if (!Subtarget->hasFullFP16()) {
setOperationAction(ISD::SELECT, MVT::f16, Promote);
setOperationAction(ISD::SELECT_CC, MVT::f16, Promote);
setOperationAction(ISD::SETCC, MVT::f16, Promote);
setOperationAction(ISD::BR_CC, MVT::f16, Promote);
setOperationAction(ISD::FADD, MVT::f16, Promote);
setOperationAction(ISD::FSUB, MVT::f16, Promote);
setOperationAction(ISD::FMUL, MVT::f16, Promote);
setOperationAction(ISD::FDIV, MVT::f16, Promote);
setOperationAction(ISD::FMA, MVT::f16, Promote);
setOperationAction(ISD::FNEG, MVT::f16, Promote);
setOperationAction(ISD::FABS, MVT::f16, Promote);
setOperationAction(ISD::FCEIL, MVT::f16, Promote);
setOperationAction(ISD::FSQRT, MVT::f16, Promote);
setOperationAction(ISD::FFLOOR, MVT::f16, Promote);
setOperationAction(ISD::FNEARBYINT, MVT::f16, Promote);
setOperationAction(ISD::FRINT, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::f16, Promote);
setOperationAction(ISD::FROUNDEVEN, MVT::f16, Promote);
setOperationAction(ISD::FTRUNC, MVT::f16, Promote);
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
for (auto Op :
{ISD::SELECT, ISD::SELECT_CC, ISD::SETCC,
ISD::BR_CC, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FDIV, ISD::FMA,
ISD::FNEG, ISD::FABS, ISD::FCEIL,
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
ISD::FTRUNC, ISD::FMINNUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMAXIMUM, ISD::STRICT_FADD,
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
ISD::STRICT_FMA, ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
ISD::STRICT_FSQRT, ISD::STRICT_FRINT, ISD::STRICT_FNEARBYINT,
ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
ISD::STRICT_FMAXIMUM})
setOperationAction(Op, MVT::f16, Promote);
// Round-to-integer need custom lowering for fp16, as Promote doesn't work
// because the result type is integer.
for (auto Op : {ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
ISD::STRICT_LLRINT})
setOperationAction(Op, MVT::f16, Custom);
// promote v4f16 to v4f32 when that is known to be safe.
setOperationAction(ISD::FADD, MVT::v4f16, Promote);
@ -1402,6 +1379,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
IsStrictFPEnabled = true;
}
void AArch64TargetLowering::addTypeForNEON(MVT VT) {
@ -2592,7 +2571,18 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
bool IsSignaling) {
EVT VT = LHS.getValueType();
assert(VT != MVT::f128);
assert(VT != MVT::f16 && "Lowering of strict fp16 not yet implemented");
const bool FullFP16 =
static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
if (VT == MVT::f16 && !FullFP16) {
LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
{Chain, LHS});
RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
{LHS.getValue(1), RHS});
Chain = RHS.getValue(1);
VT = MVT::f32;
}
unsigned Opcode =
IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
@ -3468,8 +3458,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
VT.getVectorNumElements());
if (IsStrict) {
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
{ExtVT, MVT::Other},
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {ExtVT, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
@ -3506,8 +3495,14 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
// f16 conversions are promoted to f32 when full fp16 is not supported.
if (SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
SDLoc dl(Op);
if (IsStrict) {
SDValue Ext =
DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
{Op.getOperand(0), SrcVal});
return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
return DAG.getNode(
Op.getOpcode(), dl, Op.getValueType(),
DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, SrcVal));
@ -3730,10 +3725,15 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
// f16 conversions are promoted to f32 when full fp16 is not supported.
if (Op.getValueType() == MVT::f16 &&
!Subtarget->hasFullFP16()) {
assert(!IsStrict && "Lowering of strict fp16 not yet implemented");
if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
SDLoc dl(Op);
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), dl, {MVT::f32, MVT::Other},
{Op.getOperand(0), SrcVal});
return DAG.getNode(
ISD::STRICT_FP_ROUND, dl, {MVT::f16, MVT::Other},
{Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
}
return DAG.getNode(
ISD::FP_ROUND, dl, MVT::f16,
DAG.getNode(Op.getOpcode(), dl, MVT::f32, SrcVal),
@ -5367,6 +5367,18 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerCTTZ(Op, DAG);
case ISD::VECTOR_SPLICE:
return LowerVECTOR_SPLICE(Op, DAG);
case ISD::STRICT_LROUND:
case ISD::STRICT_LLROUND:
case ISD::STRICT_LRINT:
case ISD::STRICT_LLRINT: {
assert(Op.getOperand(1).getValueType() == MVT::f16 &&
"Expected custom lowering of rounding operations only for f16");
SDLoc DL(Op);
SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other},
{Ext.getValue(1), Ext.getValue(0)});
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,5 @@
; RUN: llc -mtriple=aarch64-none-eabi %s -disable-strictnode-mutation -o - | FileCheck %s
; RUN: llc -mtriple=aarch64-none-eabi -global-isel=true -global-isel-abort=2 -disable-strictnode-mutation %s -o - | FileCheck %s
; RUN: llc -mtriple=aarch64-none-eabi %s -o - | FileCheck %s
; RUN: llc -mtriple=aarch64-none-eabi -global-isel=true -global-isel-abort=2 %s -o - | FileCheck %s
; Check that constrained fp intrinsics are correctly lowered.