Commit 506f858

[AMDGPU] Use fmac_f64 in "if (cond) a -= c"
1 parent a2fcef0 commit 506f858
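
The title describes the source shape this patch targets: a double-precision subtraction guarded by a condition. Once the optimizer folds the branch into a select, the backend sees a v_fmac_f64 whose multiplicand is cond ? -1.0 : 0.0, which the new pattern below rewrites. A minimal sketch of that shape (hypothetical function and parameter names, not taken from the commit's tests; whether the FMA form is produced depends on how the middle end folds the select):

  // Hypothetical C++ example of the "if (cond) a -= c" shape. Depending on
  // how the optimizer folds the branch, the subtraction can be emitted as
  //   a = fma(cond ? -1.0 : 0.0, c, a)
  // which is the v_fmac_f64 pattern matched by this commit.
  extern "C" double conditional_sub(double a, double c, bool cond) {
    if (cond)
      a -= c;
    return a;
  }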

3 files changed: +1225 -9 lines changed


llvm/lib/Target/AMDGPU/GCNPreRAOptimizations.cpp

Lines changed: 198 additions & 3 deletions
@@ -36,6 +36,7 @@
 #include "MCTargetDesc/AMDGPUMCTargetDesc.h"
 #include "SIRegisterInfo.h"
 #include "llvm/CodeGen/LiveIntervals.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/InitializePasses.h"

@@ -45,17 +46,33 @@ using namespace llvm;
 
 namespace {
 
+static bool isImmConstant(const MachineOperand &Op, int64_t Val) {
+  return Op.isImm() && Op.getImm() == Val;
+}
+
 class GCNPreRAOptimizationsImpl {
 private:
   const SIInstrInfo *TII;
   const SIRegisterInfo *TRI;
   MachineRegisterInfo *MRI;
   LiveIntervals *LIS;
+  MachineLoopInfo *MLI;
 
   bool processReg(Register Reg);
 
+  bool isSingleUseVReg(Register Reg) const {
+    return Reg.isVirtual() && MRI->hasOneUse(Reg);
+  }
+
+  bool isConstMove(MachineInstr &MI, int64_t C) const {
+    return TII->isFoldableCopy(MI) && isImmConstant(MI.getOperand(1), C);
+  }
+
+  bool optimizeConditionalFMAPattern(MachineInstr &FMAInstr);
+
 public:
-  GCNPreRAOptimizationsImpl(LiveIntervals *LS) : LIS(LS) {}
+  GCNPreRAOptimizationsImpl(LiveIntervals *LS, MachineLoopInfo *MLI)
+      : LIS(LS), MLI(MLI) {}
   bool run(MachineFunction &MF);
 };
 
@@ -75,6 +92,7 @@ class GCNPreRAOptimizationsLegacy : public MachineFunctionPass {
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<LiveIntervalsWrapperPass>();
+    AU.addRequired<MachineLoopInfoWrapperPass>();
     AU.setPreservesAll();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -84,6 +102,7 @@ class GCNPreRAOptimizationsLegacy : public MachineFunctionPass {
 INITIALIZE_PASS_BEGIN(GCNPreRAOptimizationsLegacy, DEBUG_TYPE,
                       "AMDGPU Pre-RA optimizations", false, false)
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
 INITIALIZE_PASS_END(GCNPreRAOptimizationsLegacy, DEBUG_TYPE,
                     "Pre-RA optimizations", false, false)
 
@@ -229,14 +248,17 @@ bool GCNPreRAOptimizationsLegacy::runOnMachineFunction(MachineFunction &MF) {
   if (skipFunction(MF.getFunction()))
     return false;
   LiveIntervals *LIS = &getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-  return GCNPreRAOptimizationsImpl(LIS).run(MF);
+  MachineLoopInfo *MLI =
+      &getAnalysis<MachineLoopInfoWrapperPass>().getLI();
+  return GCNPreRAOptimizationsImpl(LIS, MLI).run(MF);
 }
 
 PreservedAnalyses
 GCNPreRAOptimizationsPass::run(MachineFunction &MF,
                                MachineFunctionAnalysisManager &MFAM) {
   LiveIntervals *LIS = &MFAM.getResult<LiveIntervalsAnalysis>(MF);
-  GCNPreRAOptimizationsImpl(LIS).run(MF);
+  MachineLoopInfo *MLI = &MFAM.getResult<MachineLoopAnalysis>(MF);
+  GCNPreRAOptimizationsImpl(LIS, MLI).run(MF);
   return PreservedAnalyses::all();
 }
 
@@ -260,6 +282,13 @@ bool GCNPreRAOptimizationsImpl::run(MachineFunction &MF) {
     Changed |= processReg(Reg);
   }
 
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : make_early_inc_range(MBB)) {
+      if (MI.getOpcode() == AMDGPU::V_FMAC_F64_e32)
+        Changed |= optimizeConditionalFMAPattern(MI);
+    }
+  }
+
   if (!ST.useRealTrue16Insts())
     return Changed;
 
@@ -295,3 +324,169 @@ bool GCNPreRAOptimizationsImpl::run(MachineFunction &MF) {
 
   return Changed;
 }
+
+/// Conditional FMA to Conditional Subtraction:
+///
+/// Detects a pattern where an FMA is used to conditionally subtract a value:
+///   FMA(dst, cond ? -1.0 : 0.0, value, accum) -> accum - (cond ? value : 0)
+///
+/// Pattern detected:
+///   v_mov_b32_e32 vNegOneHi, 0xbff00000               ; -1.0 high bits (single use)
+///   v_mov_b32_e32 vMul.lo, 0                          ; (single use)
+///   v_cndmask_b32_e64 vMul.hi, 0, vNegOneHi, vCondReg ; (single use)
+///   v_fmac_f64_e32 vDst[0:1], vMul[0:1], vValue[0:1]  ; vDst is tied to vAccum
+///
+/// Transformed to (3 instructions instead of 4, lower register pressure):
+///   v_cndmask_b32_e64 vCondValue.lo, 0, vValue.lo, vCondReg
+///   v_cndmask_b32_e64 vCondValue.hi, 0, vValue.hi, vCondReg
+///   v_add_f64_e64 vDst[0:1], vAccum[0:1], -vCondValue[0:1]
+///
+/// Benefits: Reduces instruction count from 4 to 3, and register pressure by
+/// eliminating the need for -1.0 constant and zero/conditional intermediate
+/// values.
+bool GCNPreRAOptimizationsImpl::optimizeConditionalFMAPattern(
+    MachineInstr &FMAInstr) {
+  assert(FMAInstr.getOpcode() == AMDGPU::V_FMAC_F64_e32);
+
+  MachineOperand *MulOp = TII->getNamedOperand(FMAInstr, AMDGPU::OpName::src0);
+  assert(MulOp);
+  if (!MulOp->isReg() || !isSingleUseVReg(MulOp->getReg()))
+    return false;
+
+  // Find subregister definitions for the 64-bit multiplicand register
+  MachineInstr *MulLoDefMI = nullptr;
+  MachineInstr *MulHiDefMI = nullptr;
+
+  for (auto &DefMI : MRI->def_instructions(MulOp->getReg())) {
+    if (DefMI.getOperand(0).getSubReg() == AMDGPU::sub0) {
+      MulLoDefMI = &DefMI;
+    } else if (DefMI.getOperand(0).getSubReg() == AMDGPU::sub1) {
+      MulHiDefMI = &DefMI;
+    }
+  }
+
+  // Check sub0 is zero constant (representing low 32 bits of 0.0 or -1.0)
+  if (!MulLoDefMI || !isConstMove(*MulLoDefMI, 0))
+    return false;
+
+  // Check sub1 is a conditional mask: condition ? 0xbff00000 : 0
+  if (!MulHiDefMI || MulHiDefMI->getOpcode() != AMDGPU::V_CNDMASK_B32_e64)
+    return false;
+
+  MachineInstr *CndMaskMI = MulHiDefMI;
+  MachineOperand *CndMaskFalseOp =
+      TII->getNamedOperand(*CndMaskMI, AMDGPU::OpName::src0);
+  assert(CndMaskFalseOp);
+  if (!isImmConstant(*CndMaskFalseOp, 0))
+    return false;
+
+  MachineOperand *CndMaskTrueOp =
+      TII->getNamedOperand(*CndMaskMI, AMDGPU::OpName::src1);
+  assert(CndMaskTrueOp);
+  if (!isSingleUseVReg(CndMaskTrueOp->getReg()))
+    return false;
+
+  // Check that the true operand is -1.0's high 32 bits (0xbff00000)
+  MachineOperand *NegOneHiDef = MRI->getOneDef(CndMaskTrueOp->getReg());
+  if (!NegOneHiDef ||
+      !isConstMove(*NegOneHiDef->getParent(), -1074790400 /* 0xbff00000 */))
+    return false;
+
+  MachineInstr *NegOneHiMovMI = NegOneHiDef->getParent();
+
+  MachineInstr *OldMI[] = {&FMAInstr, MulLoDefMI, MulHiDefMI, NegOneHiMovMI};
+
+  // Don't transform if FMAInstr is in a loop: it only makes sense if both
+  // cndmasks in the target pattern could be hoisted out of the loop, let's not
+  // overcomplicate this. Exception: all the instructions are in the same loop.
+  if (MachineLoop *L = MLI->getLoopFor(FMAInstr.getParent())) {
+    for (MachineInstr *MI : drop_begin(OldMI)) {
+      if (MLI->getLoopFor(MI->getParent()) != L)
+        return false;
+    }
+  }
+
+  // Perform the transformation
+  // Extract operands from FMA: vDst = vAccum + vMul * vValue
+  auto *DstOpnd = TII->getNamedOperand(FMAInstr, AMDGPU::OpName::vdst);
+  auto *ValueOpnd = TII->getNamedOperand(FMAInstr, AMDGPU::OpName::src1);
+  auto *AccumOpnd = TII->getNamedOperand(FMAInstr, AMDGPU::OpName::src2);
+  auto *CondOpnd = TII->getNamedOperand(*CndMaskMI, AMDGPU::OpName::src2);
+  assert(DstOpnd && ValueOpnd && AccumOpnd && CondOpnd);
+
+  Register DstReg = DstOpnd->getReg();
+  Register ValueReg = ValueOpnd->getReg();
+  Register AccumReg = AccumOpnd->getReg();
+  Register CondReg = CondOpnd->getReg();
+
+  // Create a new 64-bit register for the conditional value
+  Register CondValueReg =
+      MRI->createVirtualRegister(MRI->getRegClass(ValueReg));
+
+  MachineBasicBlock::iterator InsertPt = FMAInstr.getIterator();
+  DebugLoc DL = FMAInstr.getDebugLoc();
+
+  // Build: vCondValue.lo = condition ? vValue.lo : 0
+  MachineBasicBlock *MBB = FMAInstr.getParent();
+  MachineInstr *SelLo =
+      BuildMI(*MBB, InsertPt, DL, TII->get(AMDGPU::V_CNDMASK_B32_e64))
+          .addReg(CondValueReg, RegState::DefineNoRead, AMDGPU::sub0)
+          .addImm(0)                         // src0_modifiers
+          .addImm(0)                         // src0 (false value = 0)
+          .addImm(0)                         // src1_modifiers
+          .addReg(ValueReg, 0, AMDGPU::sub0) // src1 (true value = vValue.lo)
+          .addReg(CondReg)                   // condition
+          .getInstr();
+
+  // Build: vCondValue.hi = condition ? vValue.hi : 0
+  MachineInstr *SelHi =
+      BuildMI(*MBB, InsertPt, DL, TII->get(AMDGPU::V_CNDMASK_B32_e64))
+          .addReg(CondValueReg, RegState::Define, AMDGPU::sub1)
+          .addImm(0)                         // src0_modifiers
+          .addImm(0)                         // src0 (false value = 0)
+          .addImm(0)                         // src1_modifiers
+          .addReg(ValueReg, 0, AMDGPU::sub1) // src1 (true value = vValue.hi)
+          .addReg(CondReg)                   // condition
+          .getInstr();
+
+  // Build: vDst = vAccum - vCondValue (negation via src1_modifiers bit)
+  MachineInstr *Sub =
+      BuildMI(*MBB, InsertPt, DL, TII->get(AMDGPU::V_ADD_F64_e64))
+          .addReg(DstReg, RegState::Define)
+          .addImm(0)            // src0_modifiers
+          .addReg(AccumReg)     // src0 (accumulator)
+          .addImm(1)            // src1_modifiers (negation bit)
+          .addReg(CondValueReg) // src1 (negated conditional value)
+          .addImm(0)            // clamp
+          .addImm(0)            // omod
+          .getInstr();
+
+  // Delete the old instructions
+  for (MachineInstr *MI : OldMI) {
+    LIS->RemoveMachineInstrFromMaps(*MI);
+    MI->eraseFromParent();
+  }
+
+  LIS->InsertMachineInstrInMaps(*SelLo);
+  LIS->InsertMachineInstrInMaps(*SelHi);
+  LIS->InsertMachineInstrInMaps(*Sub);
+
+  // Removed registers.
+  LIS->removeInterval(MulOp->getReg());
+  LIS->removeInterval(CndMaskTrueOp->getReg());
+
+  // Reused registers.
+  LIS->removeInterval(CondReg);
+  LIS->createAndComputeVirtRegInterval(CondReg);
+
+  LIS->removeInterval(DstReg);
+  LIS->createAndComputeVirtRegInterval(DstReg);
+
+  LIS->removeInterval(ValueReg);
+  LIS->createAndComputeVirtRegInterval(ValueReg);
+
+  // New register.
+  LIS->createAndComputeVirtRegInterval(CondValueReg);
+
+  return true;
+}
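
A note on the constants checked above (a worked example of my own, not part of the commit): IEEE-754 double -1.0 has bit pattern 0xBFF0000000000000, so its low 32 bits are zero whether the multiplicand is -1.0 or 0.0; only the high half differs, which is why sub0 must be a move of 0 and only sub1 needs the v_cndmask between 0xbff00000 and 0. The -1074790400 immediate in isConstMove is 0xbff00000 reinterpreted as a signed 32-bit value. A small standalone check (C++20 for std::bit_cast):

  #include <bit>
  #include <cassert>
  #include <cstdint>

  int main() {
    // -1.0 as a double: sign = 1, biased exponent = 0x3FF, mantissa = 0.
    uint64_t Bits = std::bit_cast<uint64_t>(-1.0);
    assert(Bits == 0xBFF0000000000000ULL);
    assert(static_cast<uint32_t>(Bits) == 0u); // low half is always zero
    assert((Bits >> 32) == 0xBFF00000ULL);     // high half selected by the cndmask
    // 0xbff00000 as a signed 32-bit value is the immediate used in the pass.
    assert(static_cast<int32_t>(0xBFF00000u) == -1074790400);
    return 0;
  }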
