File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -342,14 +342,14 @@ def attention(
342342 q : jax .Array ,
343343 k : jax .Array ,
344344 v : jax .Array ,
345- sinks : jax .Array | None ,
346345 attention_metadata : AttentionMetadata ,
347346 mesh : Mesh ,
348347 head_dim_original : int | None = None , # before padding,
349348 attention_chunk_size : int | None = None ,
350349 q_scale : float | None = None ,
351350 k_scale : float | None = None ,
352351 v_scale : float | None = None ,
352+ sinks : jax .Array | None = None ,
353353) -> Tuple [jax .Array , jax .Array ]:
354354 # T: seq_len
355355 # N: num_heads
Original file line number Diff line number Diff line change @@ -178,12 +178,12 @@ def _jax_attn_func(
178178 q ,
179179 k ,
180180 v ,
181- sinks ,
182181 attention_metadata ,
183182 mesh ,
184183 q_scale = q_scale ,
185184 k_scale = k_scale ,
186185 v_scale = v_scale ,
186+ sinks = sinks ,
187187 )
188188
189189 # Convert the shape back to vLLM's convention
You can’t perform that action at this time.
0 commit comments