@@ -95,7 +95,40 @@ enum class joint_dtypes_t {
9595
9696enum class trans_type_t { nn = 0 , nt, tn, tt };
9797
// Bias shape variants for an (m x n) GEMM destination: no bias, a single
// scalar, a per-row vector (m), a per-column vector (n), or a full matrix.
enum class bias_shape_t : uint8_t {
  none = 0,
  scalar = 1,
  m = 2,
  n = 3,
  mn = 4,
};

// Element type of the bias tensor.
enum class bias_data_type_t : uint8_t {
  none = 0,
  f32 = 1,
  f16 = 2,
  bf16 = 3,
  // extend as needed
};

// Packed enum: high byte carries the bias_shape_t, low byte the
// bias_data_type_t, so one 16-bit value can key a primitive cache.
enum class bias_type_t : uint16_t {};

// Encode function (constexpr). [[nodiscard]]: the call is pure, so
// discarding the packed result is always a bug.
[[nodiscard]] constexpr bias_type_t make_bias_type(
    bias_shape_t shape,
    bias_data_type_t dtype) {
  return static_cast<bias_type_t>((uint16_t(shape) << 8) | uint16_t(dtype));
}

// Decode helpers (constexpr); inverse of make_bias_type.
[[nodiscard]] constexpr bias_shape_t get_shape(bias_type_t type) {
  return static_cast<bias_shape_t>((static_cast<uint16_t>(type) >> 8) & 0xFF);
}

[[nodiscard]] constexpr bias_data_type_t get_dtype(bias_type_t type) {
  return static_cast<bias_data_type_t>(static_cast<uint16_t>(type) & 0xFF);
}
99132
100133template <joint_dtypes_t Ts>
101134struct onednn_types_mapper ;
@@ -172,24 +205,50 @@ struct onednn_types_mapper<joint_dtypes_t::bf16_f8_e4m3> {
172205 }
173206};
174207
175- // TODO: bias types maybe not right
176- static inline dnnl::memory::dims get_bias_type (
177- bias_type_t b_dims,
208+ static inline dnnl::memory::dims get_bias_shape_type (
209+ bias_type_t b_type,
178210 const int m,
179211 const int n) {
180- switch (b_dims) {
181- case bias_type_t ::none:
212+ bias_shape_t b_shape = get_shape (b_type);
213+ switch (b_shape) {
214+ case bias_shape_t ::none:
182215 return {0 };
183- case bias_type_t ::scalar:
216+ case bias_shape_t ::scalar:
184217 return {1 , 1 };
185- case bias_type_t ::m:
218+ case bias_shape_t ::m:
186219 return {m, 1 };
187- case bias_type_t ::n:
220+ case bias_shape_t ::n:
188221 return {1 , n};
189- case bias_type_t ::mn:
222+ case bias_shape_t ::mn:
190223 return {m, n};
191224 default :
192- throw std::runtime_error (" unsupported bias type ..." );
225+ throw std::runtime_error (" unsupported bias shape ..." );
226+ }
227+ }
228+
229+ static inline dnnl::memory::data_type get_bias_data_type (bias_type_t b_type) {
230+ bias_data_type_t b_dtype = get_dtype (b_type);
231+ switch (b_dtype) {
232+ case bias_data_type_t ::none:
233+ return dnnl::memory::data_type::undef;
234+ case bias_data_type_t ::f32 :
235+ return dnnl::memory::data_type::f32 ;
236+ case bias_data_type_t ::f16 :
237+ return dnnl::memory::data_type::f16 ;
238+ case bias_data_type_t ::bf16 :
239+ return dnnl::memory::data_type::bf16 ;
240+ default :
241+ throw std::runtime_error (" unsupported bias dtype ..." );
242+ }
243+ }
244+
245+ static inline dnnl::memory::format_tag get_bias_format_type (
246+ bias_type_t b_type) {
247+ bias_shape_t b_shape = get_shape (b_type);
248+ if (b_shape == bias_shape_t ::none) {
249+ return dnnl::memory::format_tag::undef;
250+ } else {
251+ return dnnl::memory::format_tag::ab;
193252 }
194253}
195254
@@ -515,7 +574,7 @@ struct matmul_primitive_cache_t {
515574 const int64_t ldb,
516575 const int64_t ldc,
517576 const bias_type_t
518- b_dims , // for shapeless bias, not put it into template parameter
577+ b_type , // for shapeless bias, not put it into template parameter
519578 const int device_id,
520579 F f_attr,
521580 const int scale_group_size,
@@ -529,7 +588,7 @@ struct matmul_primitive_cache_t {
529588 m,
530589 n,
531590 k,
532- int (b_dims ),
591+ int (b_type ),
533592 scale_group_size,
534593 zp_group_size);
535594 auto iter = cached.find (pri_key);
@@ -546,19 +605,19 @@ struct matmul_primitive_cache_t {
546605 ? dnnl::memory::data_type::f16
547606 : src_dt);
548607 auto dst_md = memory::desc ({m, n}, dst_dt, dst_strides);
549- auto bias_format = b_dims == bias_type_t ::none
550- ? dnnl::memory::format_tag::undef
551- : dnnl::memory::format_tag::ab;
608+
552609 auto bias_md = memory::desc (
553- get_bias_type (b_dims, m, n), dst_dt, bias_format); // {m, n} or {1, n}
610+ get_bias_shape_type (b_type, m, n),
611+ get_bias_data_type (b_type),
612+ get_bias_format_type (b_type)); // {m, n} or {1, n}
554613
555614 primitive_attr pattr;
556615 f_attr (pattr);
557616
558617 dnnl::matmul::primitive_desc matmul_pd;
559618 at::Device curDevice = at::Device (at::kXPU , device_id);
560619 auto aengine = GpuEngineManager::Instance ().get_engine (curDevice);
561- if (b_dims == bias_type_t ::none) {
620+ if (get_shape (b_type) == bias_shape_t ::none) {
562621 matmul_pd = dnnl::matmul::primitive_desc (
563622 aengine, src_md, wei_md, dst_md, pattr);
564623 } else {
@@ -592,7 +651,7 @@ struct matmul_primitive_cache_t {
592651template <joint_dtypes_t Ts, typename F>
593652static inline primitive_ext& matmul_primitive_create_and_cache (
594653 const trans_type_t Tt,
595- const bias_type_t b_dims ,
654+ const bias_type_t b_type ,
596655 const int m,
597656 const int n,
598657 const int k,
@@ -612,7 +671,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
612671 lda,
613672 ldb,
614673 ldc,
615- b_dims ,
674+ b_type ,
616675 device_id,
617676 attr,
618677 scale_group_size,
@@ -625,7 +684,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
625684 lda,
626685 ldb,
627686 ldc,
628- b_dims ,
687+ b_type ,
629688 device_id,
630689 attr,
631690 scale_group_size,
@@ -639,7 +698,7 @@ template <typename F>
639698static inline primitive_ext& matmul_primitive_create_and_cache (
640699 const joint_dtypes_t Ts,
641700 const trans_type_t Tt,
642- const bias_type_t b_dims ,
701+ const bias_type_t b_type ,
643702 const int m,
644703 const int n,
645704 const int k,
@@ -654,7 +713,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
654713 case joint_dtypes_t ::f16_int4:
655714 return matmul_primitive_create_and_cache<joint_dtypes_t ::f16_int4, F>(
656715 Tt,
657- b_dims ,
716+ b_type ,
658717 m,
659718 n,
660719 k,
@@ -668,7 +727,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
668727 case joint_dtypes_t ::bf16_int4:
669728 return matmul_primitive_create_and_cache<joint_dtypes_t ::bf16_int4, F>(
670729 Tt,
671- b_dims ,
730+ b_type ,
672731 m,
673732 n,
674733 k,
@@ -682,7 +741,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
682741 case joint_dtypes_t ::s8_int4:
683742 return matmul_primitive_create_and_cache<joint_dtypes_t ::s8_int4, F>(
684743 Tt,
685- b_dims ,
744+ b_type ,
686745 m,
687746 n,
688747 k,
@@ -696,7 +755,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
696755 case joint_dtypes_t ::u8_int4:
697756 return matmul_primitive_create_and_cache<joint_dtypes_t ::u8_int4, F>(
698757 Tt,
699- b_dims ,
758+ b_type ,
700759 m,
701760 n,
702761 k,
@@ -710,7 +769,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
710769 case joint_dtypes_t ::f16_f8_e5m2:
711770 return matmul_primitive_create_and_cache<joint_dtypes_t ::f16_f8_e5m2, F>(
712771 Tt,
713- b_dims ,
772+ b_type ,
714773 m,
715774 n,
716775 k,
@@ -724,7 +783,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
724783 case joint_dtypes_t ::bf16_f8_e5m2:
725784 return matmul_primitive_create_and_cache<joint_dtypes_t ::bf16_f8_e5m2, F>(
726785 Tt,
727- b_dims ,
786+ b_type ,
728787 m,
729788 n,
730789 k,
@@ -738,7 +797,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
738797 case joint_dtypes_t ::f16_f8_e4m3:
739798 return matmul_primitive_create_and_cache<joint_dtypes_t ::f16_f8_e4m3, F>(
740799 Tt,
741- b_dims ,
800+ b_type ,
742801 m,
743802 n,
744803 k,
@@ -752,7 +811,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
752811 case joint_dtypes_t ::bf16_f8_e4m3:
753812 return matmul_primitive_create_and_cache<joint_dtypes_t ::bf16_f8_e4m3, F>(
754813 Tt,
755- b_dims ,
814+ b_type ,
756815 m,
757816 n,
758817 k,
0 commit comments