Commit e61b71a

[TOSA] Add transposed conv support (#4360)
1 parent c180509 · commit e61b71a

5 files changed (+186, -62 lines)

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 140 additions & 26 deletions
@@ -2306,9 +2306,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");

   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
@@ -2340,12 +2337,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();

   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
+    SmallVector<int64_t, 4> biasWeightShape =
+        transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
+                                             weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
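A note on the bias hunk above: PyTorch stores ConvTranspose2d weights as IOHW, whereas regular Conv2d weights are OIHW, which is the layout the bias helper expects (it presumably reads the output-channel count from dim 0). A minimal Python sketch of the dim-0/1 swap, with made-up shapes:

    # ConvTranspose2d weight layout: (in_ch, out_ch, kH, kW)
    iohw = (4, 8, 3, 3)
    # Swap dims 0/1 so the helper sees OIHW: (out_ch, in_ch, kH, kW)
    oihw = (iohw[1], iohw[0], iohw[2], iohw[3])
    assert oihw == (8, 4, 3, 3)  # bias length follows dim 0 = out_ch = 8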
@@ -2372,8 +2376,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                             m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
   SmallVector<int64_t> padding(
@@ -2390,19 +2394,126 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");

+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM
+  //                (PyTorch depthwise uses out_ch=in_ch*depth_multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
-  SmallVector<int64_t> transposedInputShape(
-      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
+  SmallVector<int64_t, 4> transposedInputShape;
+  for (int32_t dim : nchwToNhwcDims)
+    transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      tosa::TransposeOp::create(
-          rewriter, op->getLoc(),
-          getTypeConverter()->convertType(transposedInputType), input,
-          rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
+  auto createTransposedInput = [&]() {
+    return tosa::TransposeOp::create(
+               rewriter, op->getLoc(),
+               getTypeConverter()->convertType(transposedInputType), input,
+               rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
+
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
+
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+
+    // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
+    SmallVector<int64_t, 2> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t, 4> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
+
+    Value nhwcInput = createTransposedInput();
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(ohwiWeightType), weight,
+            rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t, 4> outNHWC;
+    for (int32_t dim : nchwToNhwcDims)
+      outNHWC.push_back(outNCHW[dim]);
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut = tosa::TransposeConv2DOp::create(
+                         rewriter, op->getLoc(),
+                         getTypeConverter()->convertType(transConvOpTy),
+                         nhwcInput, transformedWeight, bias, inputZp, weightZp,
+                         rewriter.getDenseI64ArrayAttr(outPad),
+                         rewriter.getDenseI64ArrayAttr(stride), accType)
+                         .getResult();
+
+    SmallVector<int64_t, 4> transposedOutputShape;
+    for (int32_t dim : nhwcToNchwDims)
+      transposedOutputShape.push_back(outNHWC[dim]);
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(transposedOutputType), convTOut,
+            rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // Quantized rescale.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
+
+    // Final cast to requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }

   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
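The out_pad mapping in the hunk above is worth sanity-checking in isolation. A minimal Python sketch (the function name is made up; the formula and the two-sided split mirror the patch):

    def tosa_out_pad(padding, output_padding):
        """Map PyTorch (padding, output_padding) to TOSA {top, bottom, left, right}."""
        result = []
        for pad, opad in zip(padding, output_padding):
            total = opad - 2 * pad  # may be negative; TOSA's decomposition handles it
            first = int(total / 2)  # truncate toward zero, matching C++ division
            result += [first, total - first]
        return result

    # ConvTranspose2d with padding=1, output_padding=1 on both axes:
    assert tosa_out_pad([1, 1], [1, 1]) == [0, -1, 0, -1]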
@@ -2427,6 +2538,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2436,13 +2556,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();

-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2463,6 +2576,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }

+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];
@@ -2485,7 +2600,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (remainderHDim != 0) {
     if (remainderHDim > padding[1]) {
       SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
-      SmallVector<int64_t> sizeHSlice(transposedInputShape);
+      SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
       transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2579,7 +2694,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }

-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(
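Taken together, the transposed-conv path round-trips layouts: NCHW -> NHWC for the input, IOHW -> OHWI for the weight, then NHWC -> NCHW for the result. A small Python sketch using the patch's permutation vectors (the shapes are illustrative, e.g. a ConvTranspose2d with in_ch=4, out_ch=8, 3x3 kernel, stride=2, padding=1, output_padding=1 on a 16x16 input):

    def permute(shape, perm):
        return tuple(shape[d] for d in perm)

    nchw_to_nhwc = (0, 2, 3, 1)
    iohw_to_ohwi = (1, 2, 3, 0)
    nhwc_to_nchw = (0, 3, 1, 2)

    assert permute((1, 4, 16, 16), nchw_to_nhwc) == (1, 16, 16, 4)  # input
    assert permute((4, 8, 3, 3), iohw_to_ohwi) == (8, 3, 3, 4)      # weight
    assert permute((1, 32, 32, 8), nhwc_to_nchw) == (1, 8, 32, 32)  # output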

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 6 deletions
@@ -3582,7 +3582,6 @@
     "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
-    "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
@@ -3707,16 +3706,11 @@
     "Conv3dWithValidPaddingModule_basic",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
-    "Conv_Transpose2dModule_basic",
     "ConvolutionBackwardModule2DPadded_basic",
-    "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "ConvolutionModule2DGroups_basic",
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
-    "ConvolutionModule2DTransposeStridedStatic_basic",
-    "ConvolutionModule2DTransposeStrided_basic",
-    "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
     "ConvolutionModule3DGroups_basic",
     "ConvolutionModule3DGroupsStrided_basic",

projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
     # that depend on TOSA as well as TOSA-to-Standard.
     "tosa-to-arith",
     "tosa-to-scf",
+    # Required for transposed convolution support (decomposes to conv ops).
+    "tosa-optional-decompositions",
     # Named ops must be legalized prior to general tosa-to-linalg
     "tosa-to-linalg-named",
     # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
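As the new comment notes, tosa.transpose_conv2d is not lowered by tosa-to-linalg-named directly; tosa-optional-decompositions rewrites it into an ordinary tosa.conv2d first (roughly, by zero-padding the input and reversing the kernel). A hedged sketch of how such a pass list becomes a pipeline string (variable names hypothetical; pass names are the real ones above):

    passes = [
        "tosa-to-arith",
        "tosa-to-scf",
        "tosa-optional-decompositions",  # transpose_conv2d -> conv2d
        "tosa-to-linalg-named",
        "tosa-to-linalg",
    ]
    pipeline = "builtin.module(func.func({}))".format(",".join(passes))
    print(pipeline)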
