@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
9797 return cast<DILocalScope>(Scope)->getSubprogram ();
9898}
9999
100- // / Erase \p V from \p BB and move \II forward to avoid invalidating
101- // / iterators.
102- static void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
103- BasicBlock &BB) {
104- auto *Inst = cast<Instruction>(V);
105- // Still used, don't erase.
106- if (!Inst->use_empty ())
107- return ;
108- if (II != BB.rend () && Inst == &*II)
109- ++II;
110- Inst->eraseFromParent ();
111- }
112-
113100// / Return true if V is a splat of a value (which is used when multiplying a
114101// / matrix with a scalar).
115102static bool isSplat (Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259246// / Return the ShapeInfo for the result of \p I, it it can be determined.
260247static std::optional<ShapeInfo>
261248computeShapeInfoForInst (Instruction *I,
262- const ValueMap <Value *, ShapeInfo> &ShapeMap) {
249+ const DenseMap <Value *, ShapeInfo> &ShapeMap) {
263250 Value *M;
264251 Value *N;
265252 Value *K;
@@ -493,10 +480,16 @@ class LowerMatrixIntrinsics {
493480 // / the result value of the instruction, with the only exceptions being store
494481 // / instructions and the matrix_column_major_store intrinsics. For those, the
495482 // / shape information indicates that those instructions should be lowered
496- // / using shape information as well. A ValueMap is used so that when
497- // / sub-passes like optimizeTransposes performs RAUW the map stays
498- // / up-to-date.
499- ValueMap<Value *, ShapeInfo> ShapeMap;
483+ // / using shape information as well. Note that extra care is needed when
484+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
485+ // / replacement is also a matrix operation, use
486+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
487+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
488+ // / want to add shape information for a replacement instruction. When directly
489+ // / erasing a value with an entry in ShapeMap, use
490+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
491+ // / accordingly.
492+ DenseMap<Value *, ShapeInfo> ShapeMap;
500493
501494 // / List of instructions to remove. While lowering, we are not replacing all
502495 // / users of a lowered instruction, if shape information is available and
@@ -758,6 +751,30 @@ class LowerMatrixIntrinsics {
758751 return Operation (T0, Shape0.t (), T1, Shape1.t ());
759752 }
760753
754+ // / Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
755+ // / itself.
756+ void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
757+ auto Iter = ShapeMap.find (Inst);
758+ if (Iter != ShapeMap.end ())
759+ ShapeMap.erase (Iter);
760+ Inst->eraseFromParent ();
761+ }
762+
763+ // / Erase \p V from \p BB and move \II forward to avoid invalidating
764+ // / iterators.
765+ void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
766+ BasicBlock &BB) {
767+ auto *Inst = cast<Instruction>(V);
768+ // Still used, don't erase.
769+ if (!Inst->use_empty ())
770+ return ;
771+ if (II != BB.rend () && Inst == &*II)
772+ ++II;
773+ eraseFromParentAndRemoveFromShapeMap (Inst);
774+ }
775+
776+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
777+ // / entry for \p Old and replace all uses of \p Old with \p New.
761778 void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
762779 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
763780 // with New. We should only add New it it supportsShapeInfo so we insert
@@ -871,13 +888,13 @@ class LowerMatrixIntrinsics {
871888
872889 void liftTranspose (Instruction &I) {
873890 // Erase dead Instructions after lifting transposes from binops.
874- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
891+ auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
875892 if (T.use_empty ())
876- T. eraseFromParent ( );
893+ eraseFromParentAndRemoveFromShapeMap (&T );
877894 if (A->use_empty ())
878- cast<Instruction>(A)-> eraseFromParent ( );
895+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(A));
879896 if (A != B && B->use_empty ())
880- cast<Instruction>(B)-> eraseFromParent ( );
897+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(B));
881898 };
882899
883900 Value *A, *B, *AT, *BT;
@@ -1484,7 +1501,7 @@ class LowerMatrixIntrinsics {
14841501 m_Value (Arg)))) {
14851502 auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg);
14861503 Op->replaceAllUsesWith (NewLoad);
1487- cast<Instruction>(Op)-> eraseFromParent ( );
1504+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(Op));
14881505 return ;
14891506 } else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14901507 m_Value (Arg)))) {
@@ -1853,15 +1870,15 @@ class LowerMatrixIntrinsics {
18531870 // Mark eliminated instructions as fused and remove them.
18541871 FusedInsts.insert (Store);
18551872 FusedInsts.insert (MatMul);
1856- Store-> eraseFromParent ( );
1857- MatMul-> eraseFromParent ( );
1873+ eraseFromParentAndRemoveFromShapeMap (Store );
1874+ eraseFromParentAndRemoveFromShapeMap (MatMul );
18581875 if (LoadOp0->hasNUses (0 )) {
18591876 FusedInsts.insert (LoadOp0);
1860- LoadOp0-> eraseFromParent ( );
1877+ eraseFromParentAndRemoveFromShapeMap (LoadOp0 );
18611878 }
18621879 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses (0 )) {
18631880 FusedInsts.insert (LoadOp1);
1864- LoadOp1-> eraseFromParent ( );
1881+ eraseFromParentAndRemoveFromShapeMap (LoadOp1 );
18651882 }
18661883 }
18671884
0 commit comments