[InstCombine] Optimize lshr+shl+and conversion pattern

if `C1` and `C3` are pow2 and `Log2(C3) >= C2`:
    ((C1 >> X) << C2) & C3 -> X == (Log2(C1)+C2-Log2(C3)) ? C3 : 0
https://alive2.llvm.org/ce/z/zvrkKF

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D127469
This commit is contained in:
chenglin.bi 2022-06-14 11:06:10 +08:00
parent e99c07a30e
commit 286198ff04
2 changed files with 40 additions and 29 deletions

View file

@ -1917,25 +1917,40 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
Constant *C1, *C2;
const APInt *C3 = C;
Value *X;
if (C3->isPowerOf2() &&
match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)),
m_ImmConstant(C2)))) &&
match(C1, m_Power2())) {
Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
if (C3->isPowerOf2()) {
Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros());
Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3);
KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr);
if (KnownLShrc.getMaxValue().ult(Width)) {
// iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth:
// ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? C3 : 0
Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1);
Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
ConstantInt::getNullValue(Ty));
if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)),
m_ImmConstant(C2)))) &&
match(C1, m_Power2())) {
Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3);
KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr);
if (KnownLShrc.getMaxValue().ult(Width)) {
// iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth:
// ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? C3 : 0
Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1);
Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
ConstantInt::getNullValue(Ty));
}
}
if (match(Op0, m_OneUse(m_Shl(m_LShr(m_ImmConstant(C1), m_Value(X)),
m_ImmConstant(C2)))) &&
match(C1, m_Power2())) {
Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1);
Constant *Cmp =
ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2);
if (Cmp->isZeroValue()) {
// iff C1,C3 is pow2 and Log2(C3) >= C2:
// ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1);
Constant *CmpC = ConstantExpr::getSub(ShlC, Log2C3);
Value *Cmp = Builder.CreateICmpEQ(X, CmpC);
return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3),
ConstantInt::getNullValue(Ty));
}
}
// TODO: Symmetrical case
// iff C1,C3 is pow2 and Log2(C3) >= C2:
// ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0
}
}

View file

@ -2062,9 +2062,8 @@ define i16 @lshr_shl_pow2_const_xor(i16 %x) {
define i16 @lshr_shl_pow2_const_case2(i16 %x) {
; CHECK-LABEL: @lshr_shl_pow2_const_case2(
; CHECK-NEXT: [[LSHR1:%.*]] = lshr i16 8192, [[X:%.*]]
; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[LSHR1]], 4
; CHECK-NEXT: [[R:%.*]] = and i16 [[SHL]], 32
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i16 [[X:%.*]], 12
; CHECK-NEXT: [[R:%.*]] = select i1 [[TMP1]], i16 32, i16 0
; CHECK-NEXT: ret i16 [[R]]
;
%lshr1 = lshr i16 8192, %x
@ -2102,9 +2101,8 @@ define i16 @lshr_shl_pow2_const_negative_oneuse(i16 %x) {
define <3 x i16> @lshr_shl_pow2_const_case1_uniform_vec(<3 x i16> %x) {
; CHECK-LABEL: @lshr_shl_pow2_const_case1_uniform_vec(
; CHECK-NEXT: [[LSHR:%.*]] = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, [[X:%.*]]
; CHECK-NEXT: [[SHL:%.*]] = shl <3 x i16> [[LSHR]], <i16 6, i16 6, i16 6>
; CHECK-NEXT: [[R:%.*]] = and <3 x i16> [[SHL]], <i16 128, i16 128, i16 128>
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 12, i16 12, i16 12>
; CHECK-NEXT: [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 128, i16 128, i16 128>, <3 x i16> zeroinitializer
; CHECK-NEXT: ret <3 x i16> [[R]]
;
%lshr = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, %x
@ -2141,9 +2139,8 @@ define <3 x i16> @lshr_shl_pow2_const_case1_non_uniform_vec_negative(<3 x i16> %
define <3 x i16> @lshr_shl_pow2_const_case1_undef1_vec(<3 x i16> %x) {
; CHECK-LABEL: @lshr_shl_pow2_const_case1_undef1_vec(
; CHECK-NEXT: [[LSHR:%.*]] = lshr <3 x i16> <i16 undef, i16 8192, i16 8192>, [[X:%.*]]
; CHECK-NEXT: [[SHL:%.*]] = shl <3 x i16> [[LSHR]], <i16 6, i16 6, i16 6>
; CHECK-NEXT: [[R:%.*]] = and <3 x i16> [[SHL]], <i16 128, i16 128, i16 128>
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 -1, i16 12, i16 12>
; CHECK-NEXT: [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 128, i16 128, i16 128>, <3 x i16> zeroinitializer
; CHECK-NEXT: ret <3 x i16> [[R]]
;
%lshr = lshr <3 x i16> <i16 undef, i16 8192, i16 8192>, %x
@ -2154,9 +2151,8 @@ define <3 x i16> @lshr_shl_pow2_const_case1_undef1_vec(<3 x i16> %x) {
define <3 x i16> @lshr_shl_pow2_const_case1_undef2_vec(<3 x i16> %x) {
; CHECK-LABEL: @lshr_shl_pow2_const_case1_undef2_vec(
; CHECK-NEXT: [[LSHR:%.*]] = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, [[X:%.*]]
; CHECK-NEXT: [[SHL:%.*]] = shl <3 x i16> [[LSHR]], <i16 undef, i16 6, i16 6>
; CHECK-NEXT: [[R:%.*]] = and <3 x i16> [[SHL]], <i16 128, i16 128, i16 128>
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq <3 x i16> [[X:%.*]], <i16 undef, i16 12, i16 12>
; CHECK-NEXT: [[R:%.*]] = select <3 x i1> [[TMP1]], <3 x i16> <i16 128, i16 128, i16 128>, <3 x i16> zeroinitializer
; CHECK-NEXT: ret <3 x i16> [[R]]
;
%lshr = lshr <3 x i16> <i16 8192, i16 8192, i16 8192>, %x