Skip to content

Commit

Permalink
[InstCombine] Fold comparison of integers by parts
Browse files Browse the repository at this point in the history
Let's say you represent (i32, i32) as an i64 from which the parts
are extracted with lshr/trunc. Then, if you compare two tuples by
parts you get something like A[0] == B[0] && A[1] == B[1], just
that the part extraction happens by lshr/trunc and not a narrow
load or similar.

The fold implemented here reduces such equality comparisons by
converting them into a comparison on a larger part of the integer
(which might be the whole integer). It handles both the "and of eq"
and the conjugated "or of ne" case.

I'm being conservative with one-use for now, though this could be
relaxed if profitable (the base pattern converts 11 instructions
into 5 instructions, but there are quite a few variations on how it
can play out).

Differential Revision: https://reviews.llvm.org/D101232
  • Loading branch information
nikic committed May 10, 2021
1 parent 93a9a8a commit 463ea28
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 240 deletions.
87 changes: 87 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,87 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
return nullptr;
}

/// A contiguous run of bits extracted from a wider integer value,
/// described as From[StartBit .. StartBit + NumBits - 1].
struct IntPart {
  // The wider integer value the bits are taken from.
  Value *From;
  // Bit offset of the extracted part within From (0 = least significant).
  unsigned StartBit;
  // Width of the extracted part in bits.
  unsigned NumBits;
};

/// Match an extraction of bits from an integer.
///
/// Recognizes trunc(X) as the low bits of X, and trunc(lshr(X, C)) as
/// NumBits bits of X starting at bit C. Returns None if V is not a
/// one-use trunc.
static Optional<IntPart> matchIntPart(Value *V) {
  Value *Src;
  if (!match(V, m_OneUse(m_Trunc(m_Value(Src)))))
    return None;

  unsigned SrcBits = Src->getType()->getScalarSizeInBits();
  unsigned ExtractedBits = V->getType()->getScalarSizeInBits();
  Value *Base;
  const APInt *ShiftAmt;
  // For a trunc(lshr Base, ShiftAmt) pattern, make sure we're only extracting
  // bits from Base, not any shifted-in zeroes.
  if (match(Src, m_OneUse(m_LShr(m_Value(Base), m_APInt(ShiftAmt)))) &&
      ShiftAmt->ule(SrcBits - ExtractedBits))
    return {{Base, (unsigned)ShiftAmt->getZExtValue(), ExtractedBits}};
  // No (acceptable) shift: the part is the low ExtractedBits of Src.
  return {{Src, 0, ExtractedBits}};
}

/// Materialize an extraction of bits from an integer in IR.
static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
  Value *Result = P.From;
  // Shift the desired bits down to bit zero, unless they already start there.
  if (P.StartBit != 0)
    Result = Builder.CreateLShr(Result, P.StartBit);
  // Drop any bits above the requested width; skip the trunc when the part
  // already spans the whole value.
  Type *NarrowTy = Result->getType()->getWithNewBitWidth(P.NumBits);
  if (NarrowTy != Result->getType())
    Result = Builder.CreateTrunc(Result, NarrowTy);
  return Result;
}

/// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01
/// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01
/// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer.
static Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
                            InstCombiner::BuilderTy &Builder) {
  // Both compares must become dead after the fold, otherwise we only add
  // instructions instead of replacing them.
  if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
    return nullptr;

  // "and" combines eq parts; the De Morgan conjugate "or" combines ne parts.
  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
    return nullptr;

  // Every compare operand must be a bit-part extraction (trunc or
  // trunc-of-lshr) for the fold to apply.
  Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
  Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
  Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
  Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
  if (!L0 || !R0 || !L1 || !R1)
    return nullptr;

  // Make sure the LHS/RHS compare a part of the same value, possibly after
  // an operand swap.
  if (L0->From != L1->From || R0->From != R1->From) {
    if (L0->From != R1->From || R0->From != L1->From)
      return nullptr;
    // Cmp1's operands were written in the opposite order; normalize so that
    // L0/L1 refer to one value and R0/R1 to the other.
    std::swap(L1, R1);
  }

  // Make sure the extracted parts are adjacent, canonicalizing to L0/R0 being
  // the low part and L1/R1 being the high part.
  if (L0->StartBit + L0->NumBits != L1->StartBit ||
      R0->StartBit + R0->NumBits != R1->StartBit) {
    if (L1->StartBit + L1->NumBits != L0->StartBit ||
        R1->StartBit + R1->NumBits != R0->StartBit)
      return nullptr;
    std::swap(L0, L1);
    std::swap(R0, R1);
  }

  // We can simplify to a comparison of these larger parts of the integers.
  // The merged part starts at the low part's offset and covers both widths.
  IntPart L = {L0->From, L0->StartBit, L0->NumBits + L1->NumBits};
  IntPart R = {R0->From, R0->StartBit, R0->NumBits + R1->NumBits};
  Value *LValue = extractIntPart(L, Builder);
  Value *RValue = extractIntPart(R, Builder);
  return Builder.CreateICmp(Pred, LValue, RValue);
}

/// Reduce logic-of-compares with equality to a constant by substituting a
/// common operand with the constant. Callers are expected to call this with
/// Cmp0/Cmp1 switched to handle logic op commutativity.
Expand Down Expand Up @@ -1181,6 +1262,9 @@ Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder))
return X;

if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true, Builder))
return X;

// This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);

Expand Down Expand Up @@ -2411,6 +2495,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder))
return X;

if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false, Builder))
return X;

// (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
// TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors.
if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) &&
Expand Down
Loading

0 comments on commit 463ea28

Please sign in to comment.