@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11601160 auto elementTy = resultTy.getElementType ();
11611161 Value input = op->getOperand (0 );
11621162
1163+ // Figure out the accType if needed
1164+ bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1165+ isa<FloatType>(elementTy) &&
1166+ cast<FloatType>(elementTy).isBF16 ();
1167+ Type accTy = widenAccTy ? rewriter.getF32Type () : elementTy;
1168+
11631169 SmallVector<int64_t > reduceShape;
11641170 SmallVector<Value> dynDims;
11651171 for (unsigned i = 0 ; i < inputTy.getRank (); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
11741180 inputs.push_back (input);
11751181
11761182 // First fill the output buffer with the init value.
1177- auto emptyTensor = tensor::EmptyOp::create (rewriter, loc, reduceShape,
1178- resultTy. getElementType () , dynDims)
1179- .getResult ();
1183+ auto emptyTensor =
1184+ tensor::EmptyOp::create (rewriter, loc, reduceShape, accTy , dynDims)
1185+ .getResult ();
11801186
1181- auto fillValueAttr = createInitialValueForReduceOp (op, elementTy , rewriter);
1187+ auto fillValueAttr = createInitialValueForReduceOp (op, accTy , rewriter);
11821188 if (!fillValueAttr)
11831189 return rewriter.notifyMatchFailure (
11841190 op, " No initial value found for reduction operation" );
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12311237 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
12321238 std::array<Value, 2 > binaryArgs{
12331239 blockArgs[0 ], isNanIgnoreMode ? blockArgs[2 ] : blockArgs[1 ]};
1234- auto result = createLinalgBodyCalculationForReduceOp (
1235- op, binaryArgs, elementTy, rewriter);
1240+
1241+ // If reduction type differs then extend (applicable to reduce_sum)
1242+ if (binaryArgs[0 ].getType () != accTy)
1243+ binaryArgs[0 ] = arith::ExtFOp::create (nestedBuilder, nestedLoc, accTy,
1244+ binaryArgs[0 ]);
1245+
1246+ auto result = createLinalgBodyCalculationForReduceOp (op, binaryArgs,
1247+ accTy, rewriter);
12361248 if (result)
12371249 didEncounterError = true ;
12381250
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12731285
12741286 // Create a tensor full of NaNs.
12751287 auto nanValueAttr = rewriter.getFloatAttr (
1276- elementTy ,
1288+ accTy ,
12771289 APFloat::getNaN (cast<FloatType>(elementTy).getFloatSemantics (), false ));
12781290 auto nanValue = arith::ConstantOp::create (rewriter, loc, nanValueAttr);
12791291 auto emptyNanTensor =
1280- tensor::EmptyOp::create (rewriter, loc, reduceShape,
1281- resultTy.getElementType (), dynDims)
1292+ tensor::EmptyOp::create (rewriter, loc, reduceShape, accTy, dynDims)
12821293 .getResult ();
12831294 auto nanFilledTensor =
12841295 linalg::FillOp::create (rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
12881299 // Create an empty tensor, non need to fill this since it will be
12891300 // overwritten by the select.
12901301 auto finalEmptyTensor =
1291- tensor::EmptyOp::create (rewriter, loc, reduceShape,
1292- resultTy.getElementType (), dynDims)
1302+ tensor::EmptyOp::create (rewriter, loc, reduceShape, accTy, dynDims)
12931303 .getResult ();
12941304
12951305 // Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
13041314 linalgOp = linalgSelect;
13051315 }
13061316
1317+ // Truncate back to resultTy if needed
1318+ Value reducedRes = linalgOp->getResult (0 );
1319+ if (widenAccTy) {
1320+ auto resEmptyOp =
1321+ tensor::EmptyOp::create (rewriter, loc, reduceShape, elementTy, dynDims)
1322+ .getResult ();
1323+
1324+ const unsigned reducedRank =
1325+ cast<ShapedType>(reducedRes.getType ()).getRank ();
1326+ auto identityMap = rewriter.getMultiDimIdentityMap (reducedRank);
1327+ reducedRes =
1328+ linalg::GenericOp::create (
1329+ rewriter, loc, resEmptyOp.getType (), ValueRange{reducedRes},
1330+ ValueRange{resEmptyOp},
1331+ ArrayRef<AffineMap>{identityMap, identityMap},
1332+ getNParallelLoopsAttrs (reducedRank),
1333+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1334+ Value truncf = arith::TruncFOp::create (nestedBuilder, nestedLoc,
1335+ elementTy, args[0 ]);
1336+ linalg::YieldOp::create (nestedBuilder, nestedLoc, truncf);
1337+ })
1338+ .getResults ()[0 ];
1339+ }
1340+
13071341 SmallVector<ReassociationExprs, 4 > reassociationMap;
1308- uint64_t expandInputRank =
1309- cast<ShapedType>(linalgOp->getResults ()[0 ].getType ()).getRank ();
1342+ uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType ()).getRank ();
13101343 reassociationMap.resize (expandInputRank);
13111344
13121345 for (uint64_t i = 0 ; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
13241357 // since here we know which dimension to expand, and `tosa::ReshapeOp` would
13251358 // not have access to such information. This matters when handling dynamically
13261359 // sized tensors.
1327- rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
1328- op, resultTy, linalgOp-> getResults ()[ 0 ], reassociationMap);
1360+ rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(op, resultTy, reducedRes,
1361+ reassociationMap);
13291362 return success ();
13301363}
13311364
0 commit comments