Skip to content

Commit 3ad5765

Browse files
authored
[LV] Check all users of partial reductions in chain have same scale. (#162822)
Check that all partial reductions in a chain are only used by other partial reductions with the same scale factor. Otherwise, we end up creating users of scaled reductions where the types of the other operands don't match. A similar issue was addressed in #158603, but that fix missed the chained cases. Fixes #162530. PR: #162822
1 parent c940bfd commit 3ad5765

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7933,6 +7933,26 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
79337933
(!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
79347934
ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
79357935
}
7936+
7937+
// Check that all partial reductions in a chain are only used by other
7938+
// partial reductions with the same scale factor. Otherwise we end up creating
7939+
// users of scaled reductions where the types of the other operands don't
7940+
// match.
7941+
for (const auto &[Chain, Scale] : PartialReductionChains) {
7942+
auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
7943+
auto *UI = cast<Instruction>(U);
7944+
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
7945+
return all_of(UI->users(), [ScaleVal, this](const User *U) {
7946+
auto *UI = cast<Instruction>(U);
7947+
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
7948+
});
7949+
}
7950+
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
7951+
!OrigLoop->contains(UI->getParent());
7952+
};
7953+
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx))
7954+
ScaledReductionMap.erase(Chain.Reduction);
7955+
}
79367956
}
79377957

79387958
bool VPRecipeBuilder::getScaledReductions(
@@ -8116,11 +8136,8 @@ VPRecipeBase *VPRecipeBuilder::tryToCreateWidenRecipe(VPSingleDefRecipe *R,
81168136
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
81178137
return tryToWidenMemory(Instr, Operands, Range);
81188138

8119-
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr)) {
8120-
if (auto PartialRed =
8121-
tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value()))
8122-
return PartialRed;
8123-
}
8139+
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
8140+
return tryToCreatePartialReduction(Instr, Operands, ScaleFactor.value());
81248141

81258142
if (!shouldWiden(Instr, Range))
81268143
return nullptr;
@@ -8154,9 +8171,9 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
81548171
isa<VPPartialReductionRecipe>(BinOpRecipe))
81558172
std::swap(BinOp, Accumulator);
81568173

8157-
if (ScaleFactor !=
8158-
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()))
8159-
return nullptr;
8174+
assert(ScaleFactor ==
8175+
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) &&
8176+
"all accumulators in chain must have same scale factor");
81608177

81618178
unsigned ReductionOpcode = Reduction->getOpcode();
81628179
if (ReductionOpcode == Instruction::Sub) {

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-incomplete-chains.ll

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,28 @@ loop:
7070
exit:
7171
ret i32 %red.next
7272
}
73+
74+
define i16 @test_incomplete_chain_without_mul(ptr noalias %dst, ptr %A, ptr %B) #0 {
75+
entry:
76+
br label %loop
77+
78+
loop:
79+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
80+
%red = phi i16 [ 0, %entry ], [ %red.next, %loop ]
81+
%l.a = load i8, ptr %A, align 1
82+
%a.ext = zext i8 %l.a to i16
83+
store i16 %a.ext, ptr %dst, align 2
84+
%l.b = load i8, ptr %B, align 1
85+
%b.ext = zext i8 %l.b to i16
86+
%add = add i16 %red, %b.ext
87+
%add.1 = add i16 %add, %a.ext
88+
%red.next = add i16 %add.1, %b.ext
89+
%iv.next = add i64 %iv, 1
90+
%ec = icmp ult i64 %iv, 1024
91+
br i1 %ec, label %loop, label %exit
92+
93+
exit:
94+
ret i16 %red.next
95+
}
96+
97+
attributes #0 = { "target-cpu"="grace" }

0 commit comments

Comments
 (0)