@@ -30,6 +30,7 @@ void set_all_to_value(CTYPE* out_data, size_t step_len, CTYPE value) {
3030
3131template <typename CTYPE>
3232void apply_padding_to_dim (
33+ KernelRuntimeContext& ctx,
3334 size_t ndim,
3435 const CTYPE* self_data,
3536 IntArrayRef self_sizes,
@@ -57,7 +58,20 @@ void apply_padding_to_dim(
5758 size_t out_step_len = out_strides[dim];
5859 size_t in_step_len = self_strides[dim];
5960
60- for ([[maybe_unused]] const auto i : c10::irange (pad_before)) {
61+ // Do not copy padding beyond the out tensor bounds.
62+ if (pad_before > 0 ) {
63+ size_t numel = 1 ;
64+ for (ET_UNUSED const auto i : c10::irange (out_sizes.size ())) {
65+ numel *= out_sizes[i];
66+ }
67+ ET_KERNEL_CHECK_MSG (
68+ ctx,
69+ numel >= pad_before * out_step_len,
70+ InvalidArgument,
71+ /* void */ ,
72+ " Out tensor is too small for the requested padding." );
73+ }
74+ for (ET_UNUSED const auto i : c10::irange (pad_before)) {
6175 set_all_to_value (out_data, out_step_len, value);
6276 out_data += out_step_len;
6377 }
@@ -76,8 +90,9 @@ void apply_padding_to_dim(
7690 }
7791 // Otherwise, call this function recursively
7892 else {
79- for ([[maybe_unused]] const auto i : c10::irange (self_sizes[dim])) {
93+ for (ET_UNUSED const auto i : c10::irange (self_sizes[dim])) {
8094 apply_padding_to_dim (
95+ ctx,
8196 ndim,
8297 self_data,
8398 self_sizes,
@@ -95,14 +110,28 @@ void apply_padding_to_dim(
95110 }
96111 }
97112
98- for ([[maybe_unused]] const auto i : c10::irange (pad_after)) {
113+ // Do not copy padding beyond the out tensor bounds.
114+ if (pad_after > 0 ) {
115+ size_t numel = 1 ;
116+ for (ET_UNUSED const auto i : c10::irange (out_sizes.size ())) {
117+ numel *= out_sizes[i];
118+ }
119+ ET_KERNEL_CHECK_MSG (
120+ ctx,
121+ numel >= pad_after * out_step_len,
122+ InvalidArgument,
123+ /* void */ ,
124+ " Out tensor is too small for the requested padding." );
125+ }
126+ for (ET_UNUSED const auto i : c10::irange (pad_after)) {
99127 set_all_to_value (out_data, out_step_len, value);
100128 out_data += out_step_len;
101129 }
102130}
103131
104132template <typename CTYPE>
105133void constant_pad_nd_out_impl (
134+ KernelRuntimeContext& ctx,
106135 const Tensor& self,
107136 IntArrayRef pad,
108137 CTYPE value_v,
@@ -145,6 +174,7 @@ void constant_pad_nd_out_impl(
145174 IntArrayRef out_strides_ref (out_strides, ndim);
146175
147176 apply_padding_to_dim (
177+ ctx,
148178 ndim,
149179 self_data,
150180 self_sizes_ref,
@@ -192,7 +222,7 @@ Tensor& constant_pad_nd_out(
192222 utils::internal::check_overflow_scalar_cast<CTYPE>(value);
193223 ET_KERNEL_CHECK (ctx, opt_value_casted.has_value (), InvalidArgument, );
194224 auto value_casted = opt_value_casted.value ();
195- constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
225+ constant_pad_nd_out_impl<CTYPE>(ctx, in, pad, value_casted, out);
196226 });
197227
198228 return out;
0 commit comments