
Commit 9ada31f

* add bias data_type info into primitive cache interface (#5742)
* keep bias and scale dtype

Signed-off-by: baodii <di.bao@intel.com>

1 parent b2541a4

File tree: 3 files changed, 168 additions & 59 deletions

csrc/gpu/oneDNN/DnnlExt.h

Lines changed: 89 additions & 30 deletions
@@ -95,7 +95,40 @@ enum class joint_dtypes_t {
 
 enum class trans_type_t { nn = 0, nt, tn, tt };
 
-enum class bias_type_t { none = 0, scalar, m, n, mn };
+enum class bias_shape_t : uint8_t {
+  none = 0,
+  scalar = 1,
+  m = 2,
+  n = 3,
+  mn = 4,
+};
+
+enum class bias_data_type_t : uint8_t {
+  none = 0,
+  f32 = 1,
+  f16 = 2,
+  bf16 = 3,
+  // extend as needed
+};
+
+// Packed enum
+enum class bias_type_t : uint16_t {};
+
+// Encode function (constexpr)
+constexpr bias_type_t make_bias_type(
+    bias_shape_t shape,
+    bias_data_type_t dtype) {
+  return static_cast<bias_type_t>((uint16_t(shape) << 8) | uint16_t(dtype));
+}
+
+// Decode helpers (constexpr)
+constexpr bias_shape_t get_shape(bias_type_t type) {
+  return static_cast<bias_shape_t>((static_cast<uint16_t>(type) >> 8) & 0xFF);
+}
+
+constexpr bias_data_type_t get_dtype(bias_type_t type) {
+  return static_cast<bias_data_type_t>(static_cast<uint16_t>(type) & 0xFF);
+}
 
 template <joint_dtypes_t Ts>
 struct onednn_types_mapper;
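
The packed encoding keeps bias_type_t a single 16-bit value, so it still travels through the existing interfaces (and into the primitive-cache key) as one integer while carrying both fields: shape in the high byte, data type in the low byte. A minimal compile-time sketch of the round-trip, using only the enums and constexpr helpers added above:

    // The packed value round-trips through the decode helpers.
    constexpr bias_type_t bt =
        make_bias_type(bias_shape_t::n, bias_data_type_t::f16);
    static_assert(get_shape(bt) == bias_shape_t::n, "shape is the high byte");
    static_assert(get_dtype(bt) == bias_data_type_t::f16, "dtype is the low byte");

    // Same shape, different dtype -> different packed values, which is what
    // lets the primitive cache distinguish f16 from bf16 biases.
    static_assert(
        make_bias_type(bias_shape_t::n, bias_data_type_t::f16) !=
            make_bias_type(bias_shape_t::n, bias_data_type_t::bf16),
        "dtype participates in the encoding");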
@@ -172,24 +205,50 @@ struct onednn_types_mapper<joint_dtypes_t::bf16_f8_e4m3> {
   }
 };
 
-// TODO: bias types maybe not right
-static inline dnnl::memory::dims get_bias_type(
-    bias_type_t b_dims,
+static inline dnnl::memory::dims get_bias_shape_type(
+    bias_type_t b_type,
     const int m,
     const int n) {
-  switch (b_dims) {
-    case bias_type_t::none:
+  bias_shape_t b_shape = get_shape(b_type);
+  switch (b_shape) {
+    case bias_shape_t::none:
       return {0};
-    case bias_type_t::scalar:
+    case bias_shape_t::scalar:
       return {1, 1};
-    case bias_type_t::m:
+    case bias_shape_t::m:
       return {m, 1};
-    case bias_type_t::n:
+    case bias_shape_t::n:
       return {1, n};
-    case bias_type_t::mn:
+    case bias_shape_t::mn:
       return {m, n};
     default:
-      throw std::runtime_error("unsupported bias type ...");
+      throw std::runtime_error("unsupported bias shape ...");
+  }
+}
+
+static inline dnnl::memory::data_type get_bias_data_type(bias_type_t b_type) {
+  bias_data_type_t b_dtype = get_dtype(b_type);
+  switch (b_dtype) {
+    case bias_data_type_t::none:
+      return dnnl::memory::data_type::undef;
+    case bias_data_type_t::f32:
+      return dnnl::memory::data_type::f32;
+    case bias_data_type_t::f16:
+      return dnnl::memory::data_type::f16;
+    case bias_data_type_t::bf16:
+      return dnnl::memory::data_type::bf16;
+    default:
+      throw std::runtime_error("unsupported bias dtype ...");
+  }
+}
+
+static inline dnnl::memory::format_tag get_bias_format_type(
+    bias_type_t b_type) {
+  bias_shape_t b_shape = get_shape(b_type);
+  if (b_shape == bias_shape_t::none) {
+    return dnnl::memory::format_tag::undef;
+  } else {
+    return dnnl::memory::format_tag::ab;
   }
 }
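
Taken together, the three helpers decompose one packed bias_type_t into the dims, data type, and format tag that a oneDNN memory descriptor expects. A hypothetical usage sketch (the concrete sizes m = 1 and n = 4096 are illustrative, not from the diff):

    // An f16 bias broadcast along m, i.e. shape {1, n}.
    bias_type_t bt = make_bias_type(bias_shape_t::n, bias_data_type_t::f16);

    dnnl::memory::dims dims = get_bias_shape_type(bt, /*m=*/1, /*n=*/4096); // {1, 4096}
    dnnl::memory::data_type dt = get_bias_data_type(bt);     // data_type::f16
    dnnl::memory::format_tag tag = get_bias_format_type(bt); // format_tag::ab

    // These feed straight into the bias descriptor, as the cache code below does.
    auto bias_md = dnnl::memory::desc(dims, dt, tag);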

@@ -515,7 +574,7 @@ struct matmul_primitive_cache_t {
       const int64_t ldb,
       const int64_t ldc,
       const bias_type_t
-          b_dims, // for shapeless bias, not put it into template parameter
+          b_type, // for shapeless bias; not put into a template parameter
       const int device_id,
       F f_attr,
       const int scale_group_size,
@@ -529,7 +588,7 @@ struct matmul_primitive_cache_t {
         m,
         n,
         k,
-        int(b_dims),
+        int(b_type),
         scale_group_size,
         zp_group_size);
     auto iter = cached.find(pri_key);
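
Because int(b_type) now packs both shape and dtype, two calls that differ only in bias dtype yield different pri_key values and therefore separate cached primitives; with the old shape-only bias_type_t they would have shared a cache entry. A small sketch of the distinction, with the key components following from the encoding above:

    static_assert(
        int(make_bias_type(bias_shape_t::n, bias_data_type_t::f16)) == 0x0302 &&
            int(make_bias_type(bias_shape_t::n, bias_data_type_t::bf16)) == 0x0303,
        "same shape, different dtype -> distinct cache-key components");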
@@ -546,19 +605,19 @@
             ? dnnl::memory::data_type::f16
             : src_dt);
     auto dst_md = memory::desc({m, n}, dst_dt, dst_strides);
-    auto bias_format = b_dims == bias_type_t::none
-        ? dnnl::memory::format_tag::undef
-        : dnnl::memory::format_tag::ab;
+
     auto bias_md = memory::desc(
-        get_bias_type(b_dims, m, n), dst_dt, bias_format); // {m, n} or {1, n}
+        get_bias_shape_type(b_type, m, n),
+        get_bias_data_type(b_type),
+        get_bias_format_type(b_type)); // {m, n} or {1, n}
 
     primitive_attr pattr;
     f_attr(pattr);
 
     dnnl::matmul::primitive_desc matmul_pd;
     at::Device curDevice = at::Device(at::kXPU, device_id);
     auto aengine = GpuEngineManager::Instance().get_engine(curDevice);
-    if (b_dims == bias_type_t::none) {
+    if (get_shape(b_type) == bias_shape_t::none) {
       matmul_pd = dnnl::matmul::primitive_desc(
           aengine, src_md, wei_md, dst_md, pattr);
     } else {
@@ -592,7 +651,7 @@ struct matmul_primitive_cache_t {
 template <joint_dtypes_t Ts, typename F>
 static inline primitive_ext& matmul_primitive_create_and_cache(
     const trans_type_t Tt,
-    const bias_type_t b_dims,
+    const bias_type_t b_type,
     const int m,
     const int n,
     const int k,
@@ -612,7 +671,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
         lda,
         ldb,
         ldc,
-        b_dims,
+        b_type,
         device_id,
         attr,
         scale_group_size,
@@ -625,7 +684,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
         lda,
         ldb,
         ldc,
-        b_dims,
+        b_type,
         device_id,
         attr,
         scale_group_size,
@@ -639,7 +698,7 @@ template <typename F>
 static inline primitive_ext& matmul_primitive_create_and_cache(
     const joint_dtypes_t Ts,
     const trans_type_t Tt,
-    const bias_type_t b_dims,
+    const bias_type_t b_type,
     const int m,
     const int n,
     const int k,
@@ -654,7 +713,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::f16_int4:
       return matmul_primitive_create_and_cache<joint_dtypes_t::f16_int4, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -668,7 +727,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::bf16_int4:
       return matmul_primitive_create_and_cache<joint_dtypes_t::bf16_int4, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -682,7 +741,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::s8_int4:
       return matmul_primitive_create_and_cache<joint_dtypes_t::s8_int4, F>(
          Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -696,7 +755,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::u8_int4:
       return matmul_primitive_create_and_cache<joint_dtypes_t::u8_int4, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -710,7 +769,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::f16_f8_e5m2:
       return matmul_primitive_create_and_cache<joint_dtypes_t::f16_f8_e5m2, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -724,7 +783,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::bf16_f8_e5m2:
       return matmul_primitive_create_and_cache<joint_dtypes_t::bf16_f8_e5m2, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -738,7 +797,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::f16_f8_e4m3:
      return matmul_primitive_create_and_cache<joint_dtypes_t::f16_f8_e4m3, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
@@ -752,7 +811,7 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
     case joint_dtypes_t::bf16_f8_e4m3:
       return matmul_primitive_create_and_cache<joint_dtypes_t::bf16_f8_e4m3, F>(
           Tt,
-          b_dims,
+          b_type,
           m,
           n,
           k,
