4747
4848#include < cassert>
4949#include < cstdint>
50+ #include < numeric>
5051
5152#include " mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
5253// Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
24122413 return success ();
24132414}
24142415
2416+ // / Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2417+ // /
2418+ // / Example:
2419+ // / %b = vector.broadcast %x : i32 to vector<3xf32>
2420+ // / %e:3 = vector.to_elements %b : vector<3xf32>
2421+ // / user_op %e#0, %e#1, %e#2
2422+ // / becomes:
2423+ // / user_op %x, %x, %x
2424+ // /
2425+ // / The vector source case is handled by a canonicalization pattern.
2426+ static LogicalResult
2427+ foldToElementsOfBroadcast (ToElementsOp toElementsOp,
2428+ SmallVectorImpl<OpFoldResult> &results) {
2429+ auto bcastOp = toElementsOp.getSource ().getDefiningOp <BroadcastOp>();
2430+ if (!bcastOp)
2431+ return failure ();
2432+ // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2433+ if (isa<VectorType>(bcastOp.getSource ().getType ()))
2434+ return failure ();
2435+
2436+ auto resultVecType = cast<VectorType>(toElementsOp.getSource ().getType ());
2437+
2438+ Value scalar = bcastOp.getSource ();
2439+ results.assign (resultVecType.getNumElements (), scalar);
2440+ return success ();
2441+ }
2442+
24152443LogicalResult ToElementsOp::fold (FoldAdaptor adaptor,
24162444 SmallVectorImpl<OpFoldResult> &results) {
2417- return foldToElementsFromElements (*this , results);
2445+ if (succeeded (foldToElementsFromElements (*this , results)))
2446+ return success ();
2447+ return foldToElementsOfBroadcast (*this , results);
24182448}
24192449
24202450LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
24272457 return success ();
24282458}
24292459
2460+ // / Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2461+ // / vector.
2462+ // / - Build `vector.to_elements %v` and remap each destination element to the
2463+ // / corresponding source element using broadcast rules (match or 1 →
2464+ // / replicate).
2465+ // /
2466+ // / Example:
2467+ // / %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2468+ // / %e:6 = vector.to_elements %v : vector<3x2xf32>
2469+ // / becomes:
2470+ // / %src_elems:2 = vector.to_elements %src : vector<2xf32>
2471+ // / // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2472+ // / // %src_elems#1, %src_elems#0, %src_elems#1
2473+ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2474+ using Base::Base;
2475+
2476+ LogicalResult matchAndRewrite (ToElementsOp toElementsOp,
2477+ PatternRewriter &rewriter) const override {
2478+ auto bcastOp = toElementsOp.getSource ().getDefiningOp <BroadcastOp>();
2479+ if (!bcastOp)
2480+ return failure ();
2481+
2482+ // Only handle broadcasts from a vector source here.
2483+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource ().getType ());
2484+ if (!srcType)
2485+ return failure ();
2486+
2487+ auto dstType = cast<VectorType>(toElementsOp.getSource ().getType ());
2488+
2489+ ArrayRef<int64_t > dstShape = dstType.getShape ();
2490+ ArrayRef<int64_t > srcShape = srcType.getShape ();
2491+
2492+ int64_t dstRank = dstShape.size ();
2493+ int64_t srcRank = srcShape.size ();
2494+
2495+ // Create elements for the broadcast source vector.
2496+ auto srcElems = vector::ToElementsOp::create (
2497+ rewriter, toElementsOp.getLoc (), bcastOp.getSource ());
2498+
2499+ int64_t dstCount = std::accumulate (dstShape.begin (), dstShape.end (), 1 ,
2500+ std::multiplies<int64_t >());
2501+
2502+ SmallVector<Value> replacements;
2503+ replacements.reserve (dstCount);
2504+
2505+ // For each element of the destination, determine which element of the
2506+ // source should be used. We walk all destination positions using a single
2507+ // counter, decode it into per-dimension indices, then build the matching
2508+ // source position: use the same index where sizes match, and use 0 where
2509+ // the source size is 1 (replication). This mapping is needed so we can
2510+ // replace each result of to_elements with the corresponding element from
2511+ // the broadcast source.
2512+ // Inner-dimension stretch example:
2513+ // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2514+ // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2515+ // becomes:
2516+ // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2517+ // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2518+ // // %src_elems#1, %src_elems#0, %src_elems#1,
2519+ // // %src_elems#2, %src_elems#3, %src_elems#2,
2520+ // // %src_elems#3, %src_elems#2, %src_elems#3
2521+
2522+ // Row-major strides for the destination shape.
2523+ SmallVector<int64_t > dstStrides = computeStrides (dstShape);
2524+ // Row-major strides for the source shape.
2525+ SmallVector<int64_t > srcStrides = computeStrides (srcShape);
2526+ SmallVector<int64_t > dstIdx (dstRank);
2527+ SmallVector<int64_t > srcIdx (srcRank);
2528+ for (int64_t lin = 0 ; lin < dstCount; ++lin) {
2529+ // Convert linear destination index to per-dimension indices.
2530+ dstIdx = delinearize (lin, dstStrides);
2531+ for (int64_t k = 0 ; k < srcRank; ++k)
2532+ srcIdx[k] = (srcShape[k] == 1 ) ? 0 : dstIdx[dstRank - srcRank + k];
2533+ // Convert per-dimension source indices back to a linear index.
2534+ int64_t srcLin = linearize (srcIdx, srcStrides);
2535+ replacements.push_back (srcElems.getResult (srcLin));
2536+ }
2537+
2538+ rewriter.replaceOp (toElementsOp, replacements);
2539+ return success ();
2540+ }
2541+ };
2542+
2543+ void ToElementsOp::getCanonicalizationPatterns (RewritePatternSet &results,
2544+ MLIRContext *context) {
2545+ results.add <ToElementsOfBroadcast>(context);
2546+ }
2547+
24302548// ===----------------------------------------------------------------------===//
24312549// FromElementsOp
24322550// ===----------------------------------------------------------------------===//
0 commit comments