@@ -27,12 +27,48 @@ def __init__(self, target: tvm.target.Target):
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    """Attach the three in-place logit kernels to a copy of ``mod``.

    The incoming module is cloned before anything is added, so the caller's
    ``mod`` is not modified. When the target kind stringifies to ``"llvm"``
    the ``*_cpu`` kernel builders are used; any other target kind goes
    through the builders that take the target as their argument.
    """
    updated = mod.clone()
    if str(self.target.kind) == "llvm":
        kernels = {
            "apply_logit_bias_inplace": _get_apply_logit_bias_inplace_cpu(),
            "apply_penalty_inplace": _get_apply_penalty_inplace_cpu(),
            "apply_bitmask_inplace": _get_apply_bitmask_inplace_cpu(),
        }
    else:
        kernels = {
            "apply_logit_bias_inplace": _get_apply_logit_bias_inplace(self.target),
            "apply_penalty_inplace": _get_apply_penalty_inplace(self.target),
            "apply_bitmask_inplace": _get_apply_bitmask_inplace(self.target),
        }
    # Dict order is insertion order, so the functions are attached in the
    # same sequence as they are listed above.
    for symbol, kernel in kernels.items():
        updated[symbol] = kernel
    return updated
3439
3540
def _get_apply_logit_bias_inplace_cpu():
    """Build the serial TIR kernel that adds logit biases in place.

    Returns the ``T.prim_func`` whose global symbol is
    ``apply_logit_bias_inplace``.
    """

    @T.prim_func
    def _apply_logit_bias_inplace(
        var_logits: T.handle,
        var_pos2seq_id: T.handle,
        var_token_ids: T.handle,
        var_logit_bias: T.handle,
    ) -> None:
        """Function that applies logit bias in place."""
        T.func_attr(
            {
                "global_symbol": "apply_logit_bias_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_token = T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # The three buffers below are parallel arrays of length num_token:
        # entry i adds logit_bias[i] to logits[pos2seq_id[i], token_ids[i]].
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32")

        # Same loop/block shape as the other CPU kernels in this file:
        # one serial loop with a single spatial block axis.
        for pos in T.serial(num_token):
            with T.block("block"):
                vp = T.axis.spatial(num_token, pos)
                logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]

    return _apply_logit_bias_inplace
70+
71+
3672def _get_apply_logit_bias_inplace (target : tvm .target .Target ):
3773 tx = 1024 # default
3874 max_num_threads_per_block = get_max_num_threads_per_block (target )
@@ -74,6 +110,50 @@ def _apply_logit_bias_inplace(
74110 return _apply_logit_bias_inplace
75111
76112
def _get_apply_penalty_inplace_cpu():
    """Build the serial TIR kernel that applies penalties to logits in place.

    Returns the ``T.prim_func`` whose global symbol is
    ``apply_penalty_inplace``.
    """

    @T.prim_func
    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
        var_logits: T.handle,
        var_seq_ids: T.handle,
        var_pos2seq_id: T.handle,
        var_token_ids: T.handle,
        var_token_cnt: T.handle,
        var_penalties: T.handle,
    ) -> None:
        """Function that applies penalties in place."""
        T.func_attr(
            {
                "global_symbol": "apply_penalty_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_token = T.int32(is_size_var=True)
        num_seq = T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # seq_ids maps an index into `penalties` to a row of `logits`.
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # pos2seq_id / token_ids / token_cnt are parallel arrays of length num_token.
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        # Three coefficients per sequence; columns 0/1/2 are used below as a flat
        # offset, a per-count offset and a scale factor.
        # NOTE(review): presumably presence / frequency / repetition penalty -- confirm with caller.
        penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")

        for pos in T.serial(num_token):
            with T.block("block"):
                vp = T.axis.spatial(num_token, pos)
                # Step 1: subtract column 0 plus token_cnt times column 1.
                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[
                    seq_ids[pos2seq_id[vp]], token_ids[vp]
                ] - (penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1])
                # Step 2: scale by column 2 -- multiply when the updated logit is
                # negative, divide otherwise.
                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
                )

    return _apply_penalty_inplace
155+
156+
77157def _get_apply_penalty_inplace (target : tvm .target .Target ):
78158 tx = 1024 # default
79159 max_num_threads_per_block = get_max_num_threads_per_block (target )
@@ -129,6 +209,42 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
129209 return _apply_penalty_inplace
130210
131211
def _get_apply_bitmask_inplace_cpu():
    """Build the serial TIR kernel that applies a vocabulary bitmask in place.

    Returns the ``T.prim_func`` whose global symbol is
    ``apply_bitmask_inplace``.
    """

    @T.prim_func
    def _apply_bitmask_inplace(
        var_logits: T.handle,
        var_seq_ids: T.handle,
        var_bitmask: T.handle,
    ) -> None:
        """Function that applies vocabulary masking in place."""
        T.func_attr(
            {
                "global_symbol": "apply_bitmask_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_seq = T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # seq_ids selects which rows of `logits` / `bitmask` get masked.
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # One bit per vocabulary entry, packed 32 to an int32 word.
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")

        # Iterate (sequence, vocab entry) directly rather than through a fused
        # flat index: same elements, same order, but no per-element
        # floor-div / mod and no num_seq * vocab_size loop extent.
        for seq, vocab in T.grid(num_seq, vocab_size):
            with T.block("block"):
                vs, vv = T.axis.remap("SS", [seq, vocab])
                # Bit (vv % 32) of word (vv // 32): 1 keeps the logit, anything
                # else overwrites it with the smallest float32 value.
                logits[seq_ids[vs], vv] = T.if_then_else(
                    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
                    logits[seq_ids[vs], vv],
                    T.min_value("float32"),
                )

    return _apply_bitmask_inplace
246+
247+
132248def _get_apply_bitmask_inplace (target : tvm .target .Target ):
133249 tx = 1024 # default
134250 max_num_threads_per_block = get_max_num_threads_per_block (target )
0 commit comments