7 changes: 7 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -107,6 +107,13 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
Type inputElemTy, Type outputElemTy,
ArrayRef<int64_t> weightShape);

// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as
// {top, bottom, left, right}. Returns the padded tensor value.
Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
Operation *op, Value inputNHWC,
ArrayRef<int64_t> padExtents);

} // namespace tosa
} // namespace mlir

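For orientation, here is a minimal sketch of the call pattern the new helper is meant for; the variable names (`inputNHWC`, `top`, `bottom`, `left`, `right`) are placeholders, and the real call sites appear in the TorchToTosa.cpp changes below.

```cpp
// Sketch only: assumes we are inside a conversion pattern with `rewriter` and
// `op` in scope and an NHWC-transposed value already computed.
SmallVector<int64_t, 4> padExtents = {top, bottom, left, right};
Value padded = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
                                             inputNHWC, padExtents);
// If every extent is zero the helper returns `inputNHWC` unchanged; otherwise
// it emits a zero-valued tosa.pad so the following tosa.avg_pool2d can be
// built with pad = {0, 0, 0, 0}.
```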
74 changes: 56 additions & 18 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6123,7 +6123,8 @@ static LogicalResult getOutputTypeAndPoolingParameters(
AtenOpT op, ConversionPatternRewriter &rewriter, Value &inputXchw,
SmallVectorImpl<int64_t> &dilationArray, Type &outputTy,
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
DenseI64ArrayAttr &pad) {
DenseI64ArrayAttr &pad,
SmallVectorImpl<int64_t> *explicitNHWCPad = nullptr) {

RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
if (!inputTy)
@@ -6163,21 +6164,43 @@ getOutputTypeAndPoolingParameters(

if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
// Currently, we can not represent `count_include_pad` with the existing
// TOSA AvgPool2d specification. Without the below check, we produce silent
// wrong answer (SWA) when the `count_include_pad` value is `true.`
//
// Note: We need to check for `count_include_pad` only when the `padding`
// value is non-zero.
// When count_include_pad=true with non-zero padding, we will materialize an
// explicit pad after transposing to NHWC. Track the padding extents and
// zero out the TOSA op padding so the divisor matches the full kernel size.
bool countIncludePad;
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
(!matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)) ||
countIncludePad)) {
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");
if (!explicitNHWCPad)
return rewriter.notifyMatchFailure(
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
"`count_include_pad` value should be `False`.");

// Remember the spatial padding so we can emit an NHWC tosa.pad right
// after the transpose.
explicitNHWCPad->assign(
{paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]});

auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
if (ShapedType::isDynamic(dim))
return ShapedType::kDynamic;
return dim + before + after;
};

// Update the logical input type used for shape computations to include
// the extra zeros supplied by the explicit pad.
SmallVector<int64_t> paddedShape(inputTy.getShape().begin(),
inputTy.getShape().end());
// Height stored at rank-2, width at rank-1 for NCHW shapes.
paddedShape[inputRank - 2] =
addPad(paddedShape[inputRank - 2], paddingInts[0], paddingInts[0]);
paddedShape[inputRank - 1] =
addPad(paddedShape[inputRank - 1], paddingInts[1], paddingInts[1]);
inputTy = RankedTensorType::get(paddedShape, inputTy.getElementType());

paddingInts.assign(/*Count=*/2, /*Value=*/0);
}
}

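The key observation behind this change, that materializing the padding as explicit zeros while keeping the full-kernel divisor is arithmetically identical to `count_include_pad=true`, can be sanity-checked with a small standalone program (an editor's illustration, not part of the patch; kernel 3, stride 1, padding 1 mirrors the tests added below):

```cpp
#include <cassert>
#include <vector>

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};
  const int kernel = 3, pad = 1;

  // Reference: count_include_pad=true divides every window sum by the full
  // kernel size; out-of-bounds positions contribute zero.
  std::vector<double> reference;
  for (int start = -pad; start + kernel <= static_cast<int>(x.size()) + pad;
       ++start) {
    double sum = 0.0;
    for (int k = 0; k < kernel; ++k) {
      int idx = start + k;
      if (idx >= 0 && idx < static_cast<int>(x.size()))
        sum += x[idx];
    }
    reference.push_back(sum / kernel);
  }

  // Rewritten form: pad explicitly with zeros, then pool with pad = 0. Every
  // window is now fully in-bounds, so the divisor is again the kernel size.
  std::vector<double> padded(x.size() + 2 * pad, 0.0);
  for (size_t i = 0; i < x.size(); ++i)
    padded[i + pad] = x[i];
  std::vector<double> rewritten;
  for (size_t start = 0; start + kernel <= padded.size(); ++start) {
    double sum = 0.0;
    for (int k = 0; k < kernel; ++k)
      sum += padded[start + k];
    rewritten.push_back(sum / kernel);
  }

  assert(reference == rewritten); // both are {1.0, 2.0, 5.0/3.0}
  return 0;
}
```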
@@ -6321,15 +6344,23 @@ class ConvertAtenAvgPool2dOp
}

SmallVector<int64_t, 2> dilationArray{1, 1};
SmallVector<int64_t, 4> explicitNHWCPad;
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool2dOp,
tosa::AvgPool2dOp>(
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad)))
op, rewriter, self, dilationArray, outputTy, kernel, stride, pad,
&explicitNHWCPad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, self);
Value transposed =
ConvertAtenPoolingBaseOp<AtenAvgPool2dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, self);

if (!explicitNHWCPad.empty())
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
transposed, explicitNHWCPad);

input = transposed;

return success();
}
@@ -6372,16 +6403,23 @@ class ConvertAtenAvgPool1dOp
.getResult();

SmallVector<int64_t, 2> dilationArray{1, 1};
SmallVector<int64_t, 4> explicitNHWCPad;
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
tosa::AvgPool2dOp>(
op, rewriter, reshapedSelf, dilationArray, outputTy, kernel, stride,
pad)))
pad, &explicitNHWCPad)))
return rewriter.notifyMatchFailure(
op, "invalid pooling parameters or input type");

// Transpose to xHWC
input = ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf);
Value transposed =
ConvertAtenPoolingBaseOp<AtenAvgPool1dOp, tosa::AvgPool2dOp>::
transposePoolingInputToHwc(op, rewriter, reshapedSelf);

if (!explicitNHWCPad.empty())
transposed = tosa::emitExplicitZeroPadNHWC(op->getLoc(), rewriter, op,
transposed, explicitNHWCPad);

input = transposed;

return success();
}
38 changes: 38 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -624,5 +624,43 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
}
}

Value emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter,
Operation *op, Value inputNHWC,
ArrayRef<int64_t> padExtents) {
assert(padExtents.size() == 4 && "expected [top, bottom, left, right]");

if (llvm::all_of(padExtents, [](int64_t v) { return v == 0; }))
return inputNHWC;

SmallVector<int64_t, 8> nhwcPadding = {
0, 0, padExtents[0], padExtents[1], padExtents[2], padExtents[3], 0, 0};
Value nhwcPadShape = tosa::getTosaConstShape(rewriter, loc, nhwcPadding);

auto inputTy = cast<RankedTensorType>(inputNHWC.getType());
SmallVector<int64_t, 4> resultShape(inputTy.getShape().begin(),
inputTy.getShape().end());
auto addPad = [](int64_t dim, int64_t before, int64_t after) -> int64_t {
if (ShapedType::isDynamic(dim))
return ShapedType::kDynamic;
return dim + before + after;
};
resultShape[1] = addPad(resultShape[1], padExtents[0], padExtents[1]);
resultShape[2] = addPad(resultShape[2], padExtents[2], padExtents[3]);

auto resultTy = RankedTensorType::get(resultShape, inputTy.getElementType());

Type elemTy = inputTy.getElementType();
Value padConst;
if (isa<mlir::FloatType>(elemTy)) {
padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
} else {
padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
}

return tosa::PadOp::create(rewriter, loc, resultTy, inputNHWC, nhwcPadShape,
padConst)
.getResult();
}

} // namespace tosa
} // namespace mlir
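As a quick cross-check of the shape bookkeeping (editor's note; the numbers come from the FileCheck tests below): the helper only grows the two spatial dimensions of the NHWC tensor, and because the pooling op then runs with pad = 0, the overall output shape of the lowered op is unchanged. A minimal sketch of the arithmetic:

```cpp
#include <cstdio>

int main() {
  // 2-D test below: NHWC 1x35x35x192, kernel 3x3, stride 1, padding 1.
  const int h = 35, pad = 1, kernel = 3, stride = 1;
  const int paddedH = h + 2 * pad;                  // 37: the tosa.pad result dim
  const int outH = (paddedH - kernel) / stride + 1; // 35: the avg_pool2d result dim
  // The 1-D test follows the same pattern: height 10 -> padded 12 -> pooled back to 10.
  std::printf("padded=%d out=%d\n", paddedH, outH);
  return 0;
}
```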
8 changes: 0 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3532,7 +3532,6 @@
"AtenSymConstrainRangeForSize_basic",
"AtenSymConstrainRange_basic",
"Aten_AssertScalar_basic",
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"ScatterAddDynamicModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
@@ -3655,21 +3654,14 @@
"AtenTopKModule_basic",
"AtenTopKSmallestModule_basic",
"Aten_EmbeddingBagExample_basic",
"AvgPool1dFloatModule_basic",
"AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
112 changes: 79 additions & 33 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2307,24 +2307,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso

// -----

func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false= torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----

func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@@ -2844,21 +2826,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4384,3 +4351,82 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
%out = torch.aten.empty.memory_format %2452, %none, %none, %cpu, %false, %none : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[1,0,256],f32>
return %out : !torch.vtensor<[1,0,256],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.none
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x35x35xf32>) -> tensor<1x35x35x192xf32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x35x35x192xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x37x37x192xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false= torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_10]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x10x1xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[0, 0, 1, 1, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.pad %[[VAL_11]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x10x1x512xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x12x1x512xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
// CHECK: }
func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}