@@ -2306,9 +2306,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");
 
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
@@ -2340,12 +2337,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
+    SmallVector<int64_t, 4> biasWeightShape =
+        transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
+                                             weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
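For context on the swap above: PyTorch stores ConvTranspose2d weights as (in_channels, out_channels/groups, kH, kW), whereas Conv2d weights are (out_channels, in_channels/groups, kH, kW), so a bias helper that reads the output-channel count from dim 0 needs OIHW order. A minimal standalone sketch of the same shape fix-up (illustrative only, not the torch-mlir helper itself):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Illustrative sketch: normalize a conv weight shape to OIHW so that
// shape[0] is always the output-channel count a bias synthesizer expects.
std::array<int64_t, 4> toOIHW(const std::array<int64_t, 4> &weightShape,
                              bool transposed) {
  if (!transposed)
    return weightShape; // Conv2d weights are already OIHW.
  // ConvTranspose2d weights are IOHW: swap dims 0 and 1.
  return {weightShape[1], weightShape[0], weightShape[2], weightShape[3]};
}

int main() {
  // ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3):
  // the weight tensor is 16x8x3x3 (IOHW).
  auto oihw = toOIHW({16, 8, 3, 3}, /*transposed=*/true);
  std::cout << "bias length = " << oihw[0] << "\n"; // prints 8
}
```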
@@ -2372,8 +2376,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                     m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
   SmallVector<int64_t> padding(
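To make the comment above concrete: a PyTorch pad of {padH, padW} expands to TOSA's {top, bottom, left, right} with top = bottom = padH and left = right = padW. A tiny standalone sketch of that mapping (not the torch-mlir code):

```cpp
#include <array>
#include <cstdint>

// Expand PyTorch's symmetric 2D padding {padH, padW} into TOSA's 4D
// {top, bottom, left, right} form described in the comment above.
constexpr std::array<int64_t, 4> torchPadToTosaPad(int64_t padH,
                                                   int64_t padW) {
  return {padH, padH, padW, padW}; // top == bottom, left == right
}

// padding {1, 2} becomes {1, 1, 2, 2}.
static_assert(torchPadToTosaPad(1, 2)[0] == 1 &&
              torchPadToTosaPad(1, 2)[3] == 2);
```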
@@ -2390,19 +2394,126 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM
+  //                (PyTorch depthwise uses out_ch = in_ch * depth_multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
-  SmallVector<int64_t> transposedInputShape(
-      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
+  SmallVector<int64_t, 4> transposedInputShape;
+  for (int32_t dim : nchwToNhwcDims)
+    transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      tosa::TransposeOp::create(
-          rewriter, op->getLoc(),
-          getTypeConverter()->convertType(transposedInputType), input,
-          rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
+  auto createTransposedInput = [&]() {
+    return tosa::TransposeOp::create(
+               rewriter, op->getLoc(),
+               getTypeConverter()->convertType(transposedInputType), input,
+               rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
+
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
+
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+
+    // TOSA 'out_pad' is a 4D array {top, bottom, left, right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2 * padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
+    SmallVector<int64_t, 2> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t, 4> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
+
+    Value nhwcInput = createTransposedInput();
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(ohwiWeightType), weight,
+            rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t, 4> outNHWC;
+    for (int32_t dim : nchwToNhwcDims)
+      outNHWC.push_back(outNCHW[dim]);
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut = tosa::TransposeConv2DOp::create(
+                         rewriter, op->getLoc(),
+                         getTypeConverter()->convertType(transConvOpTy),
+                         nhwcInput, transformedWeight, bias, inputZp, weightZp,
+                         rewriter.getDenseI64ArrayAttr(outPad),
+                         rewriter.getDenseI64ArrayAttr(stride), accType)
+                         .getResult();
+
+    SmallVector<int64_t, 4> transposedOutputShape;
+    for (int32_t dim : nhwcToNchwDims)
+      transposedOutputShape.push_back(outNHWC[dim]);
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(transposedOutputType), convTOut,
+            rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // Quantized rescale.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
+
+    // Final cast to requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }
 
   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
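One subtlety in the out_pad split above: C++ integer division truncates toward zero, so when out_pad_total is negative the larger-magnitude share lands on the bottom/right edge. A standalone sketch of the same arithmetic with a worked case (illustrative, mirroring the formulas in the hunk above):

```cpp
#include <cstdint>
#include <iostream>

// Mirrors the out_pad computation above:
//   total = output_padding - 2 * padding
// split between the leading (top/left) and trailing (bottom/right) edges.
// Truncating division puts the larger-magnitude half on the trailing edge
// when total is negative.
void splitOutPad(int64_t outputPadding, int64_t padding, int64_t &lead,
                 int64_t &trail) {
  int64_t total = outputPadding - 2 * padding;
  lead = total / 2; // truncates toward zero in C++
  trail = total - lead;
}

int main() {
  int64_t top = 0, bottom = 0;
  // ConvTranspose2d with padding=1, output_padding=1:
  // total = 1 - 2*1 = -1  ->  top = 0, bottom = -1.
  splitOutPad(/*outputPadding=*/1, /*padding=*/1, top, bottom);
  std::cout << "top=" << top << " bottom=" << bottom << "\n";
  // A negative out_pad entry trims rows, which the comment above notes
  // the TOSA decomposition is expected to handle.
}
```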
@@ -2427,6 +2538,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2436,13 +2556,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2463,6 +2576,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];
@@ -2485,7 +2600,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (remainderHDim != 0) {
     if (remainderHDim > padding[1]) {
       SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
-      SmallVector<int64_t> sizeHSlice(transposedInputShape);
+      SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
       transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2579,7 +2694,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(
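As a final sanity check on the two permutations used throughout the lowering, nchwToNhwcDims = {0, 2, 3, 1} and nhwcToNchwDims = {0, 3, 1, 2} are inverses of each other, so transposing the TOSA result back always restores the original NCHW layout. A quick standalone verification:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // The same permutation vectors as in the diff above.
  std::array<int32_t, 4> nchwToNhwc{0, 2, 3, 1};
  std::array<int32_t, 4> nhwcToNchw{0, 3, 1, 2};

  // Composing one permutation with the other must yield the identity,
  // which is what lets the lowering round-trip NCHW -> NHWC -> NCHW.
  for (int32_t i = 0; i < 4; ++i)
    assert(nchwToNhwc[nhwcToNchw[i]] == i && nhwcToNchw[nchwToNhwc[i]] == i);
  return 0;
}
```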