@@ -21952,74 +21952,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2195221952
2195321953 SDLoc DL(N);
2195421954
21955- // The narrower of the two operands. Used as the accumulator
21956- auto NarrowOp = N->getOperand(1);
21957- auto MulOp = N ->getOperand(2);
21958- if (MulOp-> getOpcode() != ISD::MUL )
21955+ SDValue Op2 = N->getOperand(2);
21956+ if (Op2->getOpcode() != ISD::MUL ||
21957+ !ISD::isExtOpcode(Op2 ->getOperand(0)->getOpcode()) ||
21958+ !ISD::isExtOpcode(Op2->getOperand(1)-> getOpcode()) )
2195921959 return SDValue();
2196021960
21961- auto ExtA = MulOp->getOperand(0);
21962- auto ExtB = MulOp->getOperand(1);
21963-
21964- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21965- !ISD::isExtOpcode(ExtB->getOpcode()))
21966- return SDValue();
21967- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21968- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
21961+ SDValue Acc = N->getOperand(1);
21962+ SDValue Mul = N->getOperand(2);
21963+ SDValue ExtMulOpLHS = Mul->getOperand(0);
21964+ SDValue ExtMulOpRHS = Mul->getOperand(1);
2196921965
21970- auto A = ExtA ->getOperand(0);
21971- auto B = ExtB ->getOperand(0);
21972- if (A .getValueType() != B .getValueType())
21966+ SDValue MulOpLHS = ExtMulOpLHS ->getOperand(0);
21967+ SDValue MulOpRHS = ExtMulOpRHS ->getOperand(0);
21968+ if (MulOpLHS .getValueType() != MulOpRHS .getValueType())
2197321969 return SDValue();
2197421970
21975- EVT ReducedType = N->getValueType(0);
21976- EVT MulSrcType = A .getValueType();
21971+ EVT ReducedVT = N->getValueType(0);
21972+ EVT MulSrcVT = MulOpLHS .getValueType();
2197721973
2197821974 // Dot products operate on chunks of four elements so there must be four times
2197921975 // as many elements in the wide type
21980- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
21981- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21982- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21983- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
21984- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
21985- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21976+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21977+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21978+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21979+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21980+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21981+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
2198621982 return SDValue();
2198721983
21984+ bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
21985+ bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
2198821986 // If the extensions are mixed, we should lower it to a usdot instead
2198921987 unsigned Opcode = 0;
21990- if (AIsSigned != BIsSigned ) {
21988+ if (MulOpLHSIsSigned != MulOpRHSIsSigned ) {
2199121989 if (!Subtarget->hasMatMulInt8())
2199221990 return SDValue();
2199321991
2199421992 bool Scalable = N->getValueType(0).isScalableVT();
2199521993 // There's no nxv2i64 version of usdot
21996- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
21994+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
2199721995 return SDValue();
2199821996
2199921997 Opcode = AArch64ISD::USDOT;
2200021998 // USDOT expects the signed operand to be last
22001- if (!BIsSigned )
22002- std::swap(A, B );
22003- } else if (AIsSigned )
21999+ if (!MulOpRHSIsSigned )
22000+ std::swap(MulOpLHS, MulOpRHS );
22001+ } else if (MulOpLHSIsSigned )
2200422002 Opcode = AArch64ISD::SDOT;
2200522003 else
2200622004 Opcode = AArch64ISD::UDOT;
2200722005
2200822006 // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
2200922007 // product followed by a zero / sign extension
22010- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
22011- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
22012- EVT ReducedTypeI32 =
22013- (ReducedType .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
22008+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
22009+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
22010+ EVT ReducedVTI32 =
22011+ (ReducedVT .isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
2201422012
22015- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
22016- DAG.getConstant(0 , DL, ReducedTypeI32), A, B);
22017- auto Extended = DAG.getSExtOrTrunc(DotI32 , DL, ReducedType );
22018- return DAG.getNode(ISD::ADD , DL, NarrowOp.getValueType(), NarrowOp,
22019- Extended);
22013+ SDValue DotI32 =
22014+ DAG.getNode(Opcode , DL, ReducedVTI32,
22015+ DAG.getConstant(0 , DL, ReducedVTI32), MulOpLHS, MulOpRHS );
22016+ SDValue Extended = DAG.getSExtOrTrunc(DotI32 , DL, ReducedVT);
22017+ return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
2202022018 }
2202122019
22022- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B );
22020+ return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS );
2202322021}
2202422022
2202522023SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -22036,32 +22034,29 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
2203622034
2203722035 SDLoc DL(N);
2203822036
22039- auto Acc = N->getOperand(1);
22040- auto ExtInput = N->getOperand(2);
22041-
22042- EVT AccVT = Acc.getValueType();
22043- EVT AccElemVT = AccVT.getVectorElementType();
22044-
22045- if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
22037+ if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
2204622038 return SDValue();
22047-
22048- unsigned ExtInputOpcode = ExtInput->getOpcode();
22049- if (!ISD::isExtOpcode(ExtInputOpcode))
22039+ SDValue Acc = N->getOperand(1);
22040+ SDValue Ext = N->getOperand(2);
22041+ EVT AccVT = Acc.getValueType();
22042+ EVT ExtVT = Ext.getValueType();
22043+ if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
2205022044 return SDValue();
2205122045
22052- auto Input = ExtInput ->getOperand(0);
22053- EVT InputVT = Input .getValueType();
22046+ SDValue ExtOp = Ext ->getOperand(0);
22047+ EVT ExtOpVT = ExtOp .getValueType();
2205422048
22055- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22056- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22057- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
22049+ if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
22050+ !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
22051+ !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
2205822052 return SDValue();
2205922053
22060- bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
22061- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22062- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22063- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
22064- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
22054+ bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
22055+ unsigned BottomOpcode =
22056+ ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
22057+ unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
22058+ SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
22059+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
2206522060}
2206622061
2206722062static SDValue performIntrinsicCombine(SDNode *N,
@@ -22073,9 +22068,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
2207322068 default:
2207422069 break;
2207522070 case Intrinsic::experimental_vector_partial_reduce_add: {
22076- if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
22071+ if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
2207722072 return Dot;
22078- if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
22073+ if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
2207922074 return WideAdd;
2208022075 return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
2208122076 N->getOperand(1), N->getOperand(2));
0 commit comments