1313#include " torch_ipex/csrc/utils.h"
1414#include " dbl/Common.h"
1515#include " dbl/Conv.h"
16+ #include " dbl/Deconv.h"
1617#include " dbl/Pool.h"
1718#include " dbl/DNNLChecker.h"
1819#include " ShadeDataContext.h"
@@ -60,11 +61,11 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
6061 }
6162
6263 dbl::comm::reorder_to_bf16_for_mix_prec (weight);
63- dbl::conv::prepack_conv_weights (input, dil_input,
64+ dbl::conv::prepack_conv_weights (input, dil_input,
6465 weight, stride, padding, dilation, groups);
6566 dil_weight = dbl::comm::try_gen_dil_tensor (weight);
6667
67- dil::tensor dil_output = dbl::conv::conv2d_impl (
68+ dil::tensor dil_output = dbl::conv::convolution_impl (
6869 dil_input,
6970 dil_weight,
7071 dil_bias,
@@ -172,6 +173,53 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_bac
172173 return std::make_tuple (grad_input, grad_weight, grad_bias);
173174}
174175
// Deconvolution forward on the DIL path: converts input, weight and the
// optional bias to dil tensors, calls dbl::deconv::deconvolution_impl, and
// wraps the resulting dil tensor back into an at::Tensor.
//
// Parameters (note the order: padding and output_padding come BEFORE stride):
//   input, weight  - checked with CHECK_DNNL_OP_PRE_COND before use.
//   bias           - optional; when !bias.defined() no bias is passed on.
//   padding, output_padding, stride, dilation, groups
//                  - forwarded unchanged to deconvolution_impl.
// Returns the at::Tensor produced by dbl::comm::gen_aten_tensor_by.
176+ at::Tensor AtenIpexCPUDev::dil_deconvolution (
177+ const at::Tensor & input,
178+ const at::Tensor & weight,
179+ const at::Tensor & bias,
180+ at::IntArrayRef padding,
181+ at::IntArrayRef output_padding,
182+ at::IntArrayRef stride,
183+ at::IntArrayRef dilation,
184+ int64_t groups) {
185+ DEBUG (" AtenIpexCPUDev::dil_deconvolution\n " );
186+ dil::tensor dil_input;
187+ dil::tensor dil_weight;
// Stays nullopt unless a bias tensor was supplied (see the branch below).
188+ c10::optional<dil::tensor> dil_bias{c10::nullopt };
189+
// NOTE(review): CHECK_DNNL_OP_PRE_COND is defined elsewhere; presumably it
// rejects tensors the DNNL path cannot handle - confirm what it does on failure.
190+ CHECK_DNNL_OP_PRE_COND (input);
191+ CHECK_DNNL_OP_PRE_COND (weight);
192+
// Each tensor goes through reorder_to_bf16_for_mix_prec before its dil view
// is generated (presumably a no-op unless mixed precision is on - TODO confirm).
193+ dbl::comm::reorder_to_bf16_for_mix_prec (input);
194+ dil_input = dbl::comm::try_gen_dil_tensor (input);
195+
196+ if (bias.defined ()) {
197+ CHECK_DNNL_OP_PRE_COND (bias);
198+ dbl::comm::reorder_to_bf16_for_mix_prec (bias);
199+ dil_bias = dbl::comm::try_gen_dil_tensor (bias);
200+ }
201+
202+ dbl::comm::reorder_to_bf16_for_mix_prec (weight);
203+
// Weight prepacking is disabled here: the call below is commented out.
204+ // TODO
205+ // dbl::deconv::prepack_deconv_weights(input, dil_input,
206+ // weight, stride, padding, dilation, groups);
207+
// The dil weight has its dimensions 0 and 1 swapped before being passed on.
// NOTE(review): presumably deconvolution_impl expects the two channel
// dimensions in the opposite order from the ATen weight - confirm.
208+ dil_weight = dbl::comm::try_gen_dil_tensor (weight).transpose_ (0 , 1 );
209+
210+ dil::tensor dil_output = dbl::deconv::deconvolution_impl (
211+ dil_input,
212+ dil_weight,
213+ dil_bias,
214+ padding,
215+ output_padding,
216+ stride,
217+ dilation,
218+ groups);
219+
220+ return dbl::comm::gen_aten_tensor_by (std::move (dil_output));
221+ }
222+
175223at::Tensor AtenIpexCPUDev::dil_convolution_overrideable (const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
176224 DEBUG (" AtenIpexCPUDev::convolution_overrideable\n " );
177225
@@ -184,7 +232,11 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
184232 dnnl_input_tensors.push_back (bias);
185233 }
186234 if (dbl::chk::dnnl_support_the_tensors (dnnl_input_tensors))
187- return AtenIpexCPUDev::dil_convolution (input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.defined () && !bias.is_contiguous () ? bias.contiguous () : bias, stride, padding, dilation, groups);
235+ if (transposed) {
236+ return AtenIpexCPUDev::dil_deconvolution (input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.defined () && !bias.is_contiguous () ? bias.contiguous () : bias, padding, output_padding, stride, dilation, groups);
237+ } else {
238+ return AtenIpexCPUDev::dil_convolution (input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.defined () && !bias.is_contiguous () ? bias.contiguous () : bias, stride, padding, dilation, groups);
239+ }
188240 }
189241 } catch (std::exception& e) {
190242#if defined(_DEBUG)
@@ -198,43 +250,34 @@ at::Tensor AtenIpexCPUDev::dil_convolution_overrideable(const at::Tensor & input
198250 auto && _ipex_input = bridge::shallowFallbackToCPUTensor (input);
199251 auto && _ipex_weight = bridge::shallowFallbackToCPUTensor (weight);
200252 auto && _ipex_bias = bridge::shallowFallbackToCPUTensor (bias);
201- auto && _ipex_result = at::mkldnn_convolution (_ipex_input, _ipex_weight, _ipex_bias, padding, stride , dilation, groups);
253+ auto && _ipex_result = at::convolution (_ipex_input, _ipex_weight, _ipex_bias, stride, padding , dilation, transposed, output_padding , groups);
202254 static_cast <void >(_ipex_result); // Avoid warnings in case not used
203255 return bridge::shallowUpgradeToDPCPPTensor (_ipex_result);
204256}
205257
// Every line of this definition carries a '-' marker in this view.
//
// Convolution forward via at::mkldnn_convolution: shallow-converts self,
// weight and bias with bridge::shallowFallbackToCPUTensor, calls
// at::mkldnn_convolution on their .contiguous() forms, and converts the result
// back with bridge::shallowUpgradeToDPCPPTensor.
//
// Parameter order is padding, stride, dilation, groups.
// The TORCH_INTERNAL_ASSERT_DEBUG_ONLY checks require self and weight to be
// defined and strided; bias may be undefined, but must be strided if defined.
206- at::Tensor AtenIpexCPUDev::mkldnn_convolution (const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) {
207- DEBUG (" AtenIpexCPUDev::mkldnn_convolution\n " );
208- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (self.defined ());
209- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (weight.defined ());
210- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (self.layout () == c10::kStrided );
211- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (weight.layout () == c10::kStrided );
212- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (!(bias.defined ()) || (bias.defined () && bias.layout () == c10::kStrided ));
213- auto && _ipex_self = bridge::shallowFallbackToCPUTensor (self);
214- auto && _ipex_weight = bridge::shallowFallbackToCPUTensor (weight);
215- auto && _ipex_bias = bridge::shallowFallbackToCPUTensor (bias);
// NOTE(review): the assertion above allows an undefined bias, yet
// .contiguous() is called on _ipex_bias unconditionally - confirm this is
// safe when no bias is supplied.
216- auto && _ipex_result = at::mkldnn_convolution (_ipex_self.contiguous (), _ipex_weight.contiguous (), _ipex_bias.contiguous (), padding, stride, dilation, groups);
217- static_cast <void >(_ipex_result); // Avoid warnings in case not used
218- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (_ipex_result.is_contiguous ());
219- TORCH_INTERNAL_ASSERT_DEBUG_ONLY (_ipex_result.layout () == c10::kStrided );
220- return bridge::shallowUpgradeToDPCPPTensor (_ipex_result);
221- }
222-
// Convolution backward entry point. This view interleaves two versions of the
// body: lines marked '-' and lines marked '+'.
//
// Both versions branch on check_auto_dnnl():
//   true  -> dil_convolution_backward, passing input, grad_output and weight
//            (each made contiguous only if it is not already);
//   false -> mkldnn_convolution_backward with the tensors passed as-is.
// Both callees receive padding BEFORE stride, the reverse of this function's
// own parameter order. output_padding is not forwarded to either callee.
//
// The '+' version additionally checks `transposed` in each branch and calls
// IPEX_CHECK(false, ...) instead of dispatching; the '-' version never reads
// `transposed`.
// NOTE(review): the transposed branches have no return statement after
// IPEX_CHECK(false, ...); this relies on IPEX_CHECK not returning (presumably
// it throws) - confirm.
223258std::tuple<at::Tensor,at::Tensor,at::Tensor> AtenIpexCPUDev::dil_convolution_backward_overrideable (const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, std::array<bool ,3 > output_mask) {
224259 DEBUG (" AtenIpexCPUDev::convolution_backward_overrideable\n " );
225260 // NOTE: DO NOT always call contiguous. It may break lazy-reorder. Because contiguous will call reorder instantly.
226261 if (check_auto_dnnl ()) {
227- return dil_convolution_backward (
228- input.is_contiguous () ? input : input.contiguous (),
229- grad_output.is_contiguous () ? grad_output : grad_output.contiguous (),
230- weight.is_contiguous () ? weight : weight.contiguous (),
231- padding,
232- stride,
233- dilation,
234- groups,
235- output_mask);
262+ if (transposed) {
263+ IPEX_CHECK (false , " deconvolution backward not support for dnnl path now" );
264+ } else {
265+ return AtenIpexCPUDev::dil_convolution_backward (
266+ input.is_contiguous () ? input : input.contiguous (),
267+ grad_output.is_contiguous () ? grad_output : grad_output.contiguous (),
268+ weight.is_contiguous () ? weight : weight.contiguous (),
269+ padding,
270+ stride,
271+ dilation,
272+ groups,
273+ output_mask);
274+ }
236275 } else {
237- return mkldnn_convolution_backward (input, grad_output, weight, padding, stride, dilation, groups, output_mask);
276+ if (transposed) {
277+ IPEX_CHECK (false , " deconvolution backward not support for native path now" );
278+ } else {
279+ return AtenIpexCPUDev::mkldnn_convolution_backward (input, grad_output, weight, padding, stride, dilation, groups, output_mask);
280+ }
238281 }
239282}
240283
0 commit comments