@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11461146 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
11471147 Location loc = packOp.getLoc ();
11481148
1149- Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
1150- DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
1151- packOp.getDimAndTileMapping ();
11521149 int64_t srcRank = packOp.getSourceRank ();
11531150 int64_t destRank = packOp.getDestRank ();
1154- int64_t numTiles = destRank - srcRank;
1151+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1152+ int64_t numberOfTiles = innerDimsPos.size ();
11551153
1156- // 1. Extract the inner tile sizes.
1157- // Where possible, values are replaced with constant attributes (to match the
1158- // behaviour of `getPackOpSourceOrPaddedSource`).
1159- SmallVector<OpFoldResult> tileSizes;
1160- for (auto i : llvm::seq<unsigned >(0 , srcRank)) {
1161- if (dimAndTileMapping.count (i)) {
1162- // Rather than taking the tile size as is, extact the actual constant
1163- // value Attribute where possible, e.g.:
1164- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1165- auto [_, tileSize] =
1166- getSimplifiedOfrAndStaticSizePair (dimAndTileMapping[i], rewriter);
1167- tileSizes.push_back (tileSize);
1168- }
1169- }
1154+ // 1. Get the input that is going to be packed. If the input requires padding,
1155+ // add a padding operation and return that as the input.
1156+ Value input = getPackOpSourceOrPaddedSource (rewriter, packOp);
11701157
11711158 // 2. Transpose the input to match the inner tile order:
11721159 // %init = tensor.empty()
11731160 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11741161 // outs(%init)
11751162 // Assumptions made:
1176- // 1. All outer dims are 1 - the corresponding transposition order doesn't
1163+ // - All outer dims are 1 - the corresponding transposition order doesn't
11771164 // matter, but requires all dim indices to be present.
1165+
1166+ // 2.1 Get the permutation for linalg.transpose
11781167 SmallVector<int64_t > srcPermForTranspose;
1179- ArrayRef<int64_t > innerDimPos (packOp.getInnerDimsPos ());
11801168 for (int64_t i = 0 ; i < srcRank; i++) {
11811169 // We assume the `k` dimensions of the inner dim position, where `k` is the
11821170 // rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11851173 // rank of the source tensor. For example if we have a source tensor with
11861174 // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
11871175 // indices are [1, 2], and the transpose will be [1, 2, 3, 0].
1188- if (llvm::is_contained (innerDimPos , i))
1176+ if (llvm::is_contained (innerDimsPos , i))
11891177 continue ;
11901178 srcPermForTranspose.push_back (i);
11911179 }
1192- srcPermForTranspose.append (innerDimPos.begin (), innerDimPos.end ());
1180+ srcPermForTranspose.append (innerDimsPos.begin (), innerDimsPos.end ());
1181+
1182+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
1183+ SmallVector<OpFoldResult> shapeForEmptyOp (srcRank - numberOfTiles,
1184+ oneIdxAttr);
1185+ shapeForEmptyOp.append (packOp.getMixedTiles ());
1186+
1187+ // getMixedTiles() may contain Values pointing to constant ops, not the
1188+ // constant attributes. Fold such Values to Attributes where possible.
1189+ llvm::transform (shapeForEmptyOp, shapeForEmptyOp.begin (),
1190+ [&](OpFoldResult ofr) {
1191+ if (auto val = llvm::dyn_cast<Value>(ofr))
1192+ return getAsOpFoldResult (val);
1193+ return ofr;
1194+ });
11931195
11941196 LDBG () << " Pack permutation: " << packOp;
11951197 LDBG () << " perm: " << llvm::interleaved (srcPermForTranspose);
1198+ LDBG () << " Shape of empty tensor: " << llvm::interleaved (shapeForEmptyOp);
11961199
1197- // 2.1 Create tensor.empty (init value for TransposeOp)
1198- SmallVector<OpFoldResult> transShapeForEmptyOp (srcRank - numTiles,
1199- oneIdxAttr);
1200- transShapeForEmptyOp.append (tileSizes);
1201-
1202- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1203- srcPermForTranspose);
1204- Value empty =
1205- tensor::EmptyOp::create (rewriter, loc, transShapeForEmptyOp,
1206- packOp.getSourceType ().getElementType ());
1200+ Value empty = tensor::EmptyOp::create (
1201+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType ().getElementType ());
12071202
1208- // 2.2 Create linalg.transpose
1203+ // 2.3 Create linalg.transpose
12091204 auto transposedOp = linalg::TransposeOp::create (rewriter, loc, input, empty,
12101205 srcPermForTranspose);
12111206
@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12141209 SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
12151210 SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
12161211 // Outer dims are all 1s!
1217- SmallVector<OpFoldResult> writeSizes (destRank - dimAndTileMapping.size (),
1218- oneIdxAttr);
1212+ SmallVector<OpFoldResult> writeSizes (destRank - numberOfTiles, oneIdxAttr);
12191213 SmallVector<int64_t > writeShape;
12201214
12211215 for (auto tileSize : packOp.getMixedTiles ()) {
0 commit comments