@@ -164,7 +164,6 @@ struct inner_product_forward : public dnnl::inner_product_forward {
164164 }
165165 } else {
166166 op_attr = attr;
167- src_desc = {src.get_dims (), data_type::f32 , format_tag::any};
168167 if (src.has_scale ()) {
169168 auto src_scale = src.get_scale ();
170169 src_scale[0 ] = 1 .f / src_scale[0 ];
@@ -178,56 +177,50 @@ struct inner_product_forward : public dnnl::inner_product_forward {
178177 // align weights data type with src
179178 dst_data_type = src.get_data_type () == data_type::bf16 ? data_type::bf16
180179 : data_type::f32 ;
181- src_desc = src.get_desc ().to_type (dst_data_type). to_format_any () ;
182- weights_desc = weights.get_desc ().to_type (dst_data_type). to_format_any () ;
180+ src_desc = src.get_desc ().to_type (dst_data_type);
181+ weights_desc = weights.get_desc ().to_type (dst_data_type);
183182 if (with_bias) {
184183 IDEEP_ENFORCE (utils::one_of (bias.get_data_type (),
185184 data_type::f32 , data_type::bf16 ),
186185 " Incorrect data type in bias" );
187- bias_desc = bias.get_desc (). to_format_any () ;
186+ bias_desc = bias.get_desc ();
188187 }
189188 }
190189
191- tensor::desc dst_desc (dst_dims, dst_data_type, format_tag::any );
190+ tensor::desc dst_desc = dst. get_desc (). to_type (dst_data_type );
192191 auto pd = with_bias
193192 ? primitive_desc ({aprop_kind, src_desc, weights_desc, bias_desc,
194193 dst_desc}, op_attr, aengine)
195194 : primitive_desc ({aprop_kind, src_desc, weights_desc, dst_desc},
196195 op_attr, aengine);
197196
198- auto expected_src = src.reorder_if_differ_in (pd.src_desc (), src_attr);
199- auto expected_weights = weights.reorder_if_differ_in (pd.weights_desc (), weights_attr);
200197 // [ Note output buffer ]
201198 // In this case, dst is an empty ideep tensor, can be re-init
202199 // If dst is not empty, ideep must write result to dst's memory and it is caller's duty to
203200 // make sure dst is big enough to hold the result
204201 if (dst.is_empty ()) {
205202 dst.init (pd.dst_desc ());
206203 }
207- auto expected_dst = dst.reorder_if_differ_in (pd.dst_desc ());
208- if (!dst_scales.empty () && utils::one_of (dst.get_data_type (), data_type::u8 , data_type::s8)) {
209- expected_dst.set_scale (dst_scales_in);
204+
205+ if (!dst_scales.empty () &&
206+ utils::one_of (dst.get_data_type (), data_type::u8 , data_type::s8)) {
207+ dst.set_scale (dst_scales_in);
210208 }
211209
212210 if (with_bias){
213- auto expected_bias = bias.reorder_if_differ_in (pd.bias_desc (), bias_attr);
214- super (pd).execute (stream::default_stream (),
215- {{DNNL_ARG_SRC, expected_src},
216- {DNNL_ARG_WEIGHTS, expected_weights},
217- {DNNL_ARG_BIAS, expected_bias},
218- {DNNL_ARG_DST, expected_dst}});
211+ super (pd).execute (stream::default_stream (), {{DNNL_ARG_SRC, src},
212+ {DNNL_ARG_WEIGHTS, weights},
213+ {DNNL_ARG_BIAS, bias},
214+ {DNNL_ARG_DST, dst}});
219215 } else {
220- super (pd).execute (stream::default_stream (),
221- {{DNNL_ARG_SRC, expected_src},
222- {DNNL_ARG_WEIGHTS, expected_weights},
223- {DNNL_ARG_DST, expected_dst}});
216+ super (pd).execute (stream::default_stream (), {{DNNL_ARG_SRC, src},
217+ {DNNL_ARG_WEIGHTS, weights},
218+ {DNNL_ARG_DST, dst}});
224219 }
225220
226- if (attr.non_negitive_output () && expected_dst .get_data_type () == data_type::s8) {
227- expected_dst .to_type (data_type::u8 );
221+ if (attr.non_negitive_output () && dst .get_data_type () == data_type::s8) {
222+ dst .to_type (data_type::u8 );
228223 }
229- // reorder back to dst's buffer if needed
230- expected_dst.reorder_to_if_differ_from (dst);
231224 }
232225};
233226
@@ -242,11 +235,6 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
242235 tensor& diff_src,
243236 const engine& aengine = engine::cpu_engine()) {
244237 auto weights_ = weights;
245- if (diff_dst.get_data_type () == data_type::bf16 ) {
246- weights_.init (weights.get_desc ().to_type (data_type::bf16 ));
247- weights_.reorder_from (weights);
248- }
249-
250238 // workaround: diff_src and weights from caffe2 may have different dims.
251239 // It would be better for caffe2 to do this reshape anyway.
252240 if (diff_src_dims.size () != weights.ndims ()) {
@@ -255,10 +243,9 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
255243 weights_.reshape (new_dims);
256244 }
257245
258- auto diff_dst_desc = diff_dst.get_desc ().to_format_any ();
259- auto weights_desc = weights_.get_desc ().to_format_any ();
260- auto diff_src_desc =
261- tensor::desc (diff_src_dims, diff_dst.get_data_type (), tag::any);
246+ auto diff_dst_desc = diff_dst.get_desc ();
247+ auto weights_desc = weights_.get_desc ();
248+ auto diff_src_desc = diff_src.get_desc ().to_type (diff_dst.get_data_type ());
262249
263250 auto forward_hints =
264251 inner_product_forward::primitive_desc (
@@ -268,8 +255,6 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
268255 auto pd = primitive_desc (
269256 {diff_src_desc, weights_desc, diff_dst_desc}, aengine, forward_hints);
270257
271- auto expected_diff_dst = diff_dst.reorder_if_differ_in (pd.diff_dst_desc ());
272- auto expected_weights = weights_.reorder_if_differ_in (pd.weights_desc ());
273258 // diff_src's original content is not used, so it can be re-initialized directly
274259 // It's the caller's duty to make sure diff_src's buffer size is the same as what is actually needed
275260 // Here we do not support writing to a given strided buffer since we know the grad is always contiguous
@@ -280,8 +265,8 @@ struct inner_product_backward_data : public dnnl::inner_product_backward_data {
280265 }
281266
282267 super (pd).execute (stream::default_stream (),
283- {{DNNL_ARG_DIFF_DST, expected_diff_dst },
284- {DNNL_ARG_WEIGHTS, expected_weights },
268+ {{DNNL_ARG_DIFF_DST, diff_dst },
269+ {DNNL_ARG_WEIGHTS, weights_ },
285270 {DNNL_ARG_DIFF_SRC, diff_src}});
286271 }
287272};
@@ -319,18 +304,17 @@ struct inner_product_backward_weights
319304 tensor& diff_bias,
320305 const data_type diff_weight_type,
321306 const engine& aengine = engine::cpu_engine()) {
322- auto src_desc = src.get_desc (). to_format_any () ;
323- auto diff_dst_desc = diff_dst.get_desc (). to_format_any () ;
307+ auto src_desc = src.get_desc ();
308+ auto diff_dst_desc = diff_dst.get_desc ();
324309 auto diff_weights_dims = src.get_dims ();
325310 diff_weights_dims[0 ] = diff_dst.get_dim (1 );
326311 data_type diff_dst_type = diff_dst.get_data_type ();
327312 data_type diff_weight_type_in = data_type::undef== diff_weight_type ?
328313 diff_dst_type : diff_weight_type;
329- auto diff_weights_desc =
330- tensor::desc (diff_weights_dims, diff_weight_type_in, tag::any);
331314
332- auto diff_bias_desc =
333- tensor::desc ({diff_dst.get_dim (1 )}, diff_weight_type_in, tag::any);
315+ auto diff_weights_desc =
316+ diff_weights.get_desc ().to_type (diff_weight_type_in);
317+ auto diff_bias_desc = diff_bias.get_desc ().to_type (diff_weight_type_in);
334318
335319 // for the forward hint, weights_desc should have the same data_type
336320 // as the other input descs, except for bias_desc
@@ -349,18 +333,13 @@ struct inner_product_backward_weights
349333 : primitive_desc ({src_desc, diff_weights_desc, diff_dst_desc},
350334 aengine, forward_hints);
351335
352- auto expected_diff_dst = diff_dst.reorder_if_differ_in (pd.diff_dst_desc ());
353- auto expected_src = src.reorder_if_differ_in (pd.src_desc ());
354336 if (diff_weights.is_empty ()){
355337 diff_weights.init (pd.diff_weights_desc ());
356338 }
357- // Here we need to write to given strided buffer, so if given buffer is different with the best format
358- // We need to firstly init a new buffer to store the output, and copy the output to a given buffer
359- tensor expected_diff_weights = diff_weights.get_desc () == pd.diff_weights_desc () ? diff_weights : tensor (pd.diff_weights_desc ());
360339
361- exec_args args {{DNNL_ARG_DIFF_DST, expected_diff_dst },
362- {DNNL_ARG_SRC, expected_src },
363- {DNNL_ARG_DIFF_WEIGHTS ,expected_diff_weights }};
340+ exec_args args{{DNNL_ARG_DIFF_DST, diff_dst },
341+ {DNNL_ARG_SRC, src },
342+ {DNNL_ARG_DIFF_WEIGHTS, diff_weights }};
364343
365344 if (with_diff_bias) {
366345 if (diff_bias.is_empty ()){
@@ -373,7 +352,6 @@ struct inner_product_backward_weights
373352 }
374353
375354 super (pd).execute (stream::default_stream (), args);
376- expected_diff_weights.reorder_to_if_differ_from (diff_weights);
377355 }
378356};
379357
0 commit comments