@@ -9,6 +9,98 @@ namespace torch_ipex {
 namespace cpu {
 
 IPEX_DEFINE_DISPATCH(fused_experts_impl_stub);
+template <typename T>
+inline void copy_and_fill(
+    T* __restrict__ out,
+    const T* __restrict__ input,
+    int size,
+    int pad_size,
+    T fill_value) {
+  using Vec = at::vec::Vectorized<T>;
+  int d = 0;
+  // vectorized copy over the full Vec-sized chunks of the input
+  if (size >= Vec::size()) {
+#pragma GCC unroll 4
+    for (; d < size - (size % Vec::size()); d += Vec::size()) {
+      Vec data = Vec::loadu(input + d);
+      data.store(out + d);
+    }
+  }
+  // scalar copy for the remaining tail elements
+  for (; d < size; ++d) {
+    out[d] = input[d];
+  }
+  // scalar padding suffices: pad_size - size is less than one vector here
+  for (; d < pad_size; ++d) {
+    out[d] = fill_value;
+  }
+}
+
+at::Tensor fused_experts_with_shared(
+    const at::Tensor& hidden_states,
+    const at::Tensor& w1,
+    const at::Tensor& w2,
+    const at::Tensor& topk_weights,
+    const at::Tensor& topk_ids,
+    bool inplace,
+    bool is_vnni,
+    bool is_distributed,
+    bool is_woq,
+    int64_t woq_weight_dtype,
+    int64_t woq_group_size,
+    int64_t woq_lowp_mode,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w1_zp,
+    const std::optional<at::Tensor>& w1_compensation,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& w2_zp,
+    const std::optional<at::Tensor>& w2_compensation) {
+  RECORD_FUNCTION(
+      "ipex::fused_experts_with_shared", c10::ArrayRef<c10::IValue>({}));
+  int32_t num_tokens = topk_weights.size(0);
+  int32_t num_topk_experts = topk_weights.size(1);
+  int32_t num_topk_experts_pad = num_topk_experts + 1;
+  int32_t num_experts = w1.size(0);
+  auto pad_weight =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_weights.options());
+  auto pad_ids =
+      at::empty({num_tokens, num_topk_experts_pad}, topk_ids.options());
+  // pad one shared expert onto each token's routed experts:
+  // its topk_id is num_experts - 1 and its topk weight is 1.0
+  for (int id = 0; id < num_tokens; id++) {
+    copy_and_fill<int32_t>(
+        pad_ids.data_ptr<int32_t>() + id * num_topk_experts_pad,
+        topk_ids.data_ptr<int32_t>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        num_experts - 1);
+    copy_and_fill<float>(
+        pad_weight.data_ptr<float>() + id * num_topk_experts_pad,
+        topk_weights.data_ptr<float>() + id * num_topk_experts,
+        num_topk_experts,
+        num_topk_experts_pad,
+        1.0f);
+  }
+  return fused_experts_impl_stub(
+      kCPU,
+      hidden_states,
+      w1,
+      w2,
+      pad_weight,
+      pad_ids,
+      inplace,
+      is_vnni,
+      is_distributed,
+      is_woq,
+      woq_weight_dtype,
+      woq_group_size,
+      woq_lowp_mode,
+      w1_scale,
+      w1_zp,
+      w1_compensation,
+      w2_scale,
+      w2_zp,
+      w2_compensation);
+}
+
 at::Tensor fused_experts(
     const at::Tensor& hidden_states,
     const at::Tensor& w1,
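
To make the padding scheme above concrete, here is a minimal, self-contained scalar sketch (not part of the patch; the helper name pad_topk_with_shared and the example sizes are illustrative). It mirrors what the per-token loop does: each token keeps its routed top-k entries and gains one extra slot that routes to the shared expert, which the kernel expects at index num_experts - 1 of w1/w2, with a fixed weight of 1.0.

#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference for the per-token padding done in fused_experts_with_shared:
// copy the routed top-k ids/weights, then append the shared expert.
void pad_topk_with_shared(
    const std::vector<int32_t>& topk_ids,   // [num_topk_experts]
    const std::vector<float>& topk_weights, // [num_topk_experts]
    int32_t num_experts,                    // shared expert lives at num_experts - 1
    std::vector<int32_t>& pad_ids,          // out: [num_topk_experts + 1]
    std::vector<float>& pad_weights) {      // out: [num_topk_experts + 1]
  pad_ids.assign(topk_ids.begin(), topk_ids.end());
  pad_weights.assign(topk_weights.begin(), topk_weights.end());
  pad_ids.push_back(num_experts - 1); // route every token to the shared expert
  pad_weights.push_back(1.0f);        // shared-expert output is added unscaled
}

int main() {
  std::vector<int32_t> ids{3, 17};
  std::vector<float> weights{0.6f, 0.4f};
  std::vector<int32_t> pad_ids;
  std::vector<float> pad_weights;
  pad_topk_with_shared(ids, weights, /*num_experts=*/64, pad_ids, pad_weights);
  for (size_t i = 0; i < pad_ids.size(); ++i)
    std::printf("expert %d weight %.1f\n", (int)pad_ids[i], pad_weights[i]);
  // prints: expert 3 / 0.6, expert 17 / 0.4, expert 63 / 1.0
}

Folding the shared expert into the routed path this way lets the existing fused_experts_impl_stub kernel handle both in one pass, instead of running a separate GEMM for the shared expert and summing afterwards.
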
@@ -334,6 +426,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
   m.impl(
       "fused_experts", c10::DispatchKey::CPU, torch_ipex::cpu::fused_experts);
+  m.def(
+      "fused_experts_with_shared(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, \
+      Tensor topk_ids, bool inplace, bool is_vnni, \
+      bool is_distributed, bool is_woq, int woq_weight_dtype, int woq_group_size, int woq_lowp_mode, \
+      Tensor? w1_scale, Tensor? w1_zp, Tensor? w1_compensation, Tensor? w2_scale, Tensor? w2_zp, Tensor? w2_compensation) -> Tensor");
+  m.impl(
+      "fused_experts_with_shared",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::fused_experts_with_shared);
   m.def(
       "grouped_topk(Tensor hidden_states, Tensor gating_output, \
       int topk, bool renormalize, int num_expert_group, int topk_group, Tensor e_score_correction_bias, Tensor routed_scaling_factor) -> (Tensor, Tensor)");
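
For reference, a caller-side sketch of invoking the newly registered op through the PyTorch dispatcher. This is a usage example under assumptions, not part of the patch: the typed signature mirrors the schema above, the flag values (is_vnni and friends) are placeholders, and a non-quantized call passes std::nullopt for every WOQ argument.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <optional>

// Hypothetical caller: resolves the schema registered above and invokes it
// without weight-only quantization (all WOQ arguments disabled/empty).
at::Tensor call_fused_experts_with_shared(
    const at::Tensor& hidden_states,
    const at::Tensor& w1,
    const at::Tensor& w2,
    const at::Tensor& topk_weights,
    const at::Tensor& topk_ids) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torch_ipex::fused_experts_with_shared", "")
          .typed<at::Tensor(
              const at::Tensor&, const at::Tensor&, const at::Tensor&,
              const at::Tensor&, const at::Tensor&,
              bool, bool, bool, bool, int64_t, int64_t, int64_t,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&,
              const std::optional<at::Tensor>&, const std::optional<at::Tensor>&)>();
  return op.call(
      hidden_states, w1, w2, topk_weights, topk_ids,
      /*inplace=*/false, /*is_vnni=*/true, /*is_distributed=*/false,
      /*is_woq=*/false, /*woq_weight_dtype=*/0, /*woq_group_size=*/-1,
      /*woq_lowp_mode=*/0,
      std::nullopt, std::nullopt, std::nullopt,
      std::nullopt, std::nullopt, std::nullopt);
}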