File tree Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Expand file tree Collapse file tree 1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -173,12 +173,15 @@ struct inner_product_forward : public dnnl::inner_product_forward {
173173 IDEEP_ENFORCE (utils::one_of (weights.get_data_type (),
174174 data_type::f32 , data_type::bf16 ),
175175 " Incorrect data type in weights" );
176-
177- // align weights data type with src
178- dst_data_type = src.get_data_type () == data_type::bf16 ? data_type::bf16
179- : data_type::f32 ;
180- src_desc = src.get_desc ().to_type (dst_data_type);
181- weights_desc = weights.get_desc ().to_type (dst_data_type);
176+ if (dst.is_empty ()) {
177+ // align weights data type with src
178+ dst_data_type = src.get_data_type () == data_type::bf16 ? data_type::bf16
179+ : data_type::f32 ;
180+ } else {
181+ dst_data_type = dst.get_data_type ();
182+ }
183+ src_desc = src.get_desc ().to_type (src.get_data_type ());
184+ weights_desc = weights.get_desc ().to_type (src.get_data_type ());
182185 if (with_bias) {
183186 IDEEP_ENFORCE (utils::one_of (bias.get_data_type (),
184187 data_type::f32 , data_type::bf16 ),
You can’t perform that action at this time.
0 commit comments