diff --git a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h index cc0f2a45bb182..643f62fa235b1 100644 --- a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h +++ b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h @@ -93,6 +93,10 @@ class LiveIntervalUnion { // Remove a live virtual register's segments from this union. void extract(const LiveInterval &VirtReg, const LiveRange &Range); + // Remove all segments referencing VirtReg. This may be used if the register + // isn't used anymore. + void clear_all_segments_referencing(const LiveInterval &VirtReg); + // Remove all inserted virtual registers. void clear() { Segments.clear(); ++Tag; } diff --git a/llvm/include/llvm/CodeGen/LiveRegMatrix.h b/llvm/include/llvm/CodeGen/LiveRegMatrix.h index 0bc243271bb73..14c653244fe16 100644 --- a/llvm/include/llvm/CodeGen/LiveRegMatrix.h +++ b/llvm/include/llvm/CodeGen/LiveRegMatrix.h @@ -135,6 +135,8 @@ class LiveRegMatrix { /// the assignment and updates VirtRegMap accordingly. void unassign(const LiveInterval &VirtReg); + void unassign(Register VirtReg); + /// Returns true if the given \p PhysReg has any live intervals assigned. bool isPhysRegUsed(MCRegister PhysReg) const; @@ -168,6 +170,16 @@ class LiveRegMatrix { LiveIntervalUnion *getLiveUnions() { return &Matrix[0]; } Register getOneVReg(unsigned PhysReg) const; + + /// Verify that all LiveInterval pointers in the matrix are valid. + /// This checks that each LiveInterval referenced in LiveIntervalUnion + /// actually exists in LiveIntervals and is not a dangling pointer. + /// Returns true if the matrix is valid, false if dangling pointers are found. + /// This is primarily useful for debugging heap-use-after-free issues. + /// This method uses a lazy approach - it builds a set of valid LiveInterval + /// pointers on-demand and has zero runtime/memory overhead during normal + /// register allocation. + bool isValid() const; }; class LiveRegMatrixWrapperLegacy : public MachineFunctionPass { diff --git a/llvm/lib/CodeGen/InlineSpiller.cpp b/llvm/lib/CodeGen/InlineSpiller.cpp index c3e0964594bd5..269c17d3dfbd4 100644 --- a/llvm/lib/CodeGen/InlineSpiller.cpp +++ b/llvm/lib/CodeGen/InlineSpiller.cpp @@ -86,6 +86,7 @@ class HoistSpillHelper : private LiveRangeEdit::Delegate { const TargetInstrInfo &TII; const TargetRegisterInfo &TRI; const MachineBlockFrequencyInfo &MBFI; + LiveRegMatrix &Matrix; InsertPointAnalysis IPA; @@ -129,16 +130,17 @@ class HoistSpillHelper : private LiveRangeEdit::Delegate { public: HoistSpillHelper(const Spiller::RequiredAnalyses &Analyses, - MachineFunction &mf, VirtRegMap &vrm) + MachineFunction &mf, VirtRegMap &vrm, LiveRegMatrix &matrix) : MF(mf), LIS(Analyses.LIS), LSS(Analyses.LSS), MDT(Analyses.MDT), VRM(vrm), MRI(mf.getRegInfo()), TII(*mf.getSubtarget().getInstrInfo()), TRI(*mf.getSubtarget().getRegisterInfo()), MBFI(Analyses.MBFI), - IPA(LIS, mf.getNumBlockIDs()) {} + Matrix(matrix), IPA(LIS, mf.getNumBlockIDs()) {} void addToMergeableSpills(MachineInstr &Spill, int StackSlot, Register Original); bool rmFromMergeableSpills(MachineInstr &Spill, int StackSlot); void hoistAllSpills(); + bool LRE_CanEraseVirtReg(Register) override; void LRE_DidCloneVirtReg(Register, Register) override; }; @@ -191,7 +193,7 @@ class InlineSpiller : public Spiller { : MF(MF), LIS(Analyses.LIS), LSS(Analyses.LSS), VRM(VRM), MRI(MF.getRegInfo()), TII(*MF.getSubtarget().getInstrInfo()), TRI(*MF.getSubtarget().getRegisterInfo()), Matrix(Matrix), - HSpiller(Analyses, MF, VRM), VRAI(VRAI) {} + HSpiller(Analyses, MF, VRM, *Matrix), VRAI(VRAI) {} void spill(LiveRangeEdit &, AllocationOrder *Order = nullptr) override; ArrayRef getSpilledRegs() override { return RegsToSpill; } @@ -1750,6 +1752,20 @@ void HoistSpillHelper::hoistAllSpills() { } } +/// Called before a virtual register is erased from LiveIntervals. +/// Forcibly remove the register from LiveRegMatrix before it's deleted, +/// preventing dangling pointers. +bool HoistSpillHelper::LRE_CanEraseVirtReg(Register VirtReg) { + // If this virtual register is assigned to a physical register, + // unassign it from LiveRegMatrix before the interval is deleted. + // Use unassign_and_clear_all_refs() instead of unassign() because the + // LiveInterval may already be empty or in an inconsistent state. + if (VRM.hasPhys(VirtReg)) { + Matrix.unassign(VirtReg); + } + return true; // Allow deletion to proceed +} + /// For VirtReg clone, the \p New register should have the same physreg or /// stackslot as the \p old register. void HoistSpillHelper::LRE_DidCloneVirtReg(Register New, Register Old) { diff --git a/llvm/lib/CodeGen/LiveIntervalUnion.cpp b/llvm/lib/CodeGen/LiveIntervalUnion.cpp index eb547c5238432..f5643b9d2ca83 100644 --- a/llvm/lib/CodeGen/LiveIntervalUnion.cpp +++ b/llvm/lib/CodeGen/LiveIntervalUnion.cpp @@ -79,6 +79,19 @@ void LiveIntervalUnion::extract(const LiveInterval &VirtReg, } } +void LiveIntervalUnion::clear_all_segments_referencing( + const LiveInterval &VirtReg) { + ++Tag; + + // Remove all segments referencing VirtReg. + for (SegmentIter SegPos = Segments.begin(); SegPos.valid();) { + if (SegPos.value() == &VirtReg) + SegPos.erase(); + else + ++SegPos; + } +} + void LiveIntervalUnion::print(raw_ostream &OS, const TargetRegisterInfo *TRI) const { if (empty()) { diff --git a/llvm/lib/CodeGen/LiveRegMatrix.cpp b/llvm/lib/CodeGen/LiveRegMatrix.cpp index cfda262aac82d..a3d1d4561bef2 100644 --- a/llvm/lib/CodeGen/LiveRegMatrix.cpp +++ b/llvm/lib/CodeGen/LiveRegMatrix.cpp @@ -12,11 +12,13 @@ #include "llvm/CodeGen/LiveRegMatrix.h" #include "RegisterCoalescer.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/CodeGen/LiveInterval.h" #include "llvm/CodeGen/LiveIntervalUnion.h" #include "llvm/CodeGen/LiveIntervals.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/CodeGen/VirtRegMap.h" @@ -142,6 +144,21 @@ void LiveRegMatrix::unassign(const LiveInterval &VirtReg) { LLVM_DEBUG(dbgs() << '\n'); } +void LiveRegMatrix::unassign(Register VirtReg) { + Register PhysReg = VRM->getPhys(VirtReg); + LLVM_DEBUG(dbgs() << "unassigning " << printReg(VirtReg, TRI) + << " from " << printReg(PhysReg, TRI) << ':'); + VRM->clearVirt(VirtReg); + + assert(LIS->hasInterval(VirtReg)); + const LiveInterval &LI = LIS->getInterval(VirtReg); + for (MCRegUnit Unit : TRI->regunits(PhysReg)) { + Matrix[Unit].clear_all_segments_referencing(LI); + } + ++NumUnassigned; + LLVM_DEBUG(dbgs() << '\n'); +} + bool LiveRegMatrix::isPhysRegUsed(MCRegister PhysReg) const { for (MCRegUnit Unit : TRI->regunits(PhysReg)) { if (!Matrix[Unit].empty()) @@ -290,6 +307,32 @@ Register LiveRegMatrix::getOneVReg(unsigned PhysReg) const { return MCRegister::NoRegister; } +bool LiveRegMatrix::isValid() const { + // Build set of all valid LiveInterval pointers from LiveIntervals. + DenseSet ValidIntervals; + for (unsigned RegIdx = 0, NumRegs = VRM->getRegInfo().getNumVirtRegs(); + RegIdx < NumRegs; ++RegIdx) { + Register VReg = Register::index2VirtReg(RegIdx); + // Only track assigned registers since unassigned ones won't be in Matrix + if (VRM->hasPhys(VReg) && LIS->hasInterval(VReg)) + ValidIntervals.insert(&LIS->getInterval(VReg)); + } + + // Now scan all LiveIntervalUnions in the matrix and verify each pointer + unsigned NumDanglingPointers = 0; + for (unsigned Unit = 0, NumUnits = Matrix.size(); Unit != NumUnits; ++Unit) { + for (const LiveInterval *LI : Matrix[Unit]) { + if (!ValidIntervals.contains(LI)) { + ++NumDanglingPointers; + dbgs() << "ERROR: LiveInterval pointer is not found in LiveIntervals:\n" + << " Register Unit: " << printRegUnit(Unit, TRI) << "\n" + << " LiveInterval pointer: " << LI << "\n"; + } + } + } + return NumDanglingPointers == 0; +} + AnalysisKey LiveRegMatrixAnalysis::Key; LiveRegMatrix LiveRegMatrixAnalysis::run(MachineFunction &MF, diff --git a/llvm/lib/CodeGen/RegAllocBase.cpp b/llvm/lib/CodeGen/RegAllocBase.cpp index 2400a1feea26e..f8e2daea8a340 100644 --- a/llvm/lib/CodeGen/RegAllocBase.cpp +++ b/llvm/lib/CodeGen/RegAllocBase.cpp @@ -155,6 +155,13 @@ void RegAllocBase::allocatePhysRegs() { void RegAllocBase::postOptimization() { spiller().postOptimization(); + + // Verify that LiveRegMatrix has no dangling pointers after spilling. + // This catches bugs where LiveIntervals are deleted but not removed from + // the LiveRegMatrix (e.g., LLVM bug #48911). + assert(Matrix->isValid() && + "LiveRegMatrix contains dangling pointers after postOptimization"); + for (auto *DeadInst : DeadRemats) { LIS->RemoveMachineInstrFromMaps(*DeadInst); DeadInst->eraseFromParent(); diff --git a/llvm/lib/Target/AMDGPU/SIPreAllocateWWMRegs.cpp b/llvm/lib/Target/AMDGPU/SIPreAllocateWWMRegs.cpp index ecfaa5c70e9d3..e402068b93c3f 100644 --- a/llvm/lib/Target/AMDGPU/SIPreAllocateWWMRegs.cpp +++ b/llvm/lib/Target/AMDGPU/SIPreAllocateWWMRegs.cpp @@ -153,10 +153,11 @@ void SIPreAllocateWWMRegs::rewriteRegs(MachineFunction &MF) { SIMachineFunctionInfo *MFI = MF.getInfo(); for (unsigned Reg : RegsToRewrite) { - LIS->removeInterval(Reg); - const Register PhysReg = VRM->getPhys(Reg); assert(PhysReg != 0); + + Matrix->unassign(Reg); + LIS->removeInterval(Reg); MFI->reserveWWMRegister(PhysReg); }