@@ -137,38 +137,44 @@ def forward(
137137 k : Tensor ,
138138 positions : Tensor ,
139139 ):
140- def _rope (x : te .Tensor , positions : te .Tensor ):
140+ def _rope_fused (x : te .Tensor , positions : te .Tensor ):
141+ _ , _ , _ , d_dim = x .shape
142+ d_dim_half = d_dim // 2
141143 dtype = x .dtype
142144
143145 def compute (b : tir .Var , s : tir .Var , h : tir .Var , d : tir .Var ):
146+ d1 = d // d_dim_half
147+ d2 = d % d_dim_half
148+
144149 cos_freq , sin_freq , var_map = self .rope_fn (
145150 positions [s ], d , self .rotary_dim , self .theta , dtype
146151 )
147- cos = cos_freq * x [b , s , h , d ]
148- sin = sin_freq * tir .if_then_else (
152+ cos = x [b , s , h , d2 * 2 + d1 ] * cos_freq
153+
154+ partner_d = tir .if_then_else (
149155 d < self .rotary_dim // 2 ,
150- - x [b , s , h , d + self .rotary_dim // 2 ],
151- x [b , s , h , d - self .rotary_dim // 2 ],
156+ d + self .rotary_dim // 2 ,
157+ d - self .rotary_dim // 2 ,
158+ )
159+
160+ partner_d1 = partner_d // d_dim_half
161+ partner_d2 = partner_d % d_dim_half
162+ sin = (
163+ x [b , s , h , partner_d2 * 2 + partner_d1 ]
164+ * sin_freq
165+ * tir .if_then_else (
166+ d < self .rotary_dim // 2 , tir .const (- 1 , dtype ), tir .const (1 , dtype )
167+ )
152168 )
153169 expr = cos + sin
154- for var , value in var_map .items ():
155- expr = tir .Let (var , value , expr )
170+ for var , val in var_map .items ():
171+ expr = tir .Let (var , val , expr )
156172 return expr
157173
158174 return te .compute (x .shape , compute , name = "yarn_rope" )
159175
160- b , s , h , d = q .shape
161- q = op .reshape (
162- op .permute_dims (op .reshape (q , (b , s , h , d // 2 , 2 )), [0 , 1 , 2 , 4 , 3 ]), (b , s , h , d )
163- )
164-
165- b , s , h , d = k .shape
166- k = op .reshape (
167- op .permute_dims (op .reshape (k , (b , s , h , d // 2 , 2 )), [0 , 1 , 2 , 4 , 3 ]), (b , s , h , d )
168- )
169-
170- q_embed = op .tensor_expr_op (_rope , "rope" , [q , positions ])
171- k_embed = op .tensor_expr_op (_rope , "rope" , [k , positions ])
176+ q_embed = op .tensor_expr_op (_rope_fused , "rope" , [q , positions ])
177+ k_embed = op .tensor_expr_op (_rope_fused , "rope" , [k , positions ])
172178 return q_embed , k_embed
173179
174180
0 commit comments