Commit bf8aa2e

[Bugfix] Fix sinks related argument error (#1044)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent: b697484

2 files changed: +2 -2 lines changed

tpu_inference/layers/jax/attention_interface.py

Lines changed: 1 addition & 1 deletion
@@ -342,14 +342,14 @@ def attention(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
-    sinks: jax.Array | None,
     attention_metadata: AttentionMetadata,
     mesh: Mesh,
     head_dim_original: int | None = None,  # before padding,
     attention_chunk_size: int | None = None,
     q_scale: float | None = None,
     k_scale: float | None = None,
     v_scale: float | None = None,
+    sinks: jax.Array | None = None,
 ) -> Tuple[jax.Array, jax.Array]:
     # T: seq_len
     # N: num_heads
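
The likely failure mode, sketched below with hypothetical stand-in signatures (attention_old and attention_new are illustrative names, not the real tpu_inference code): with `sinks` as a required positional parameter in the middle of the signature, any caller that does not pass it has every later positional argument shift by one slot and fails with a TypeError. Moving `sinks` to the end of the signature with a `None` default keeps such call sites working.

# Hypothetical, simplified stand-ins for the real attention() signature;
# the argument names mirror the diff, the bodies are placeholders.

def attention_old(q, k, v, sinks, attention_metadata, mesh):
    # Old signature: `sinks` is required and positional.
    return q

def attention_new(q, k, v, attention_metadata, mesh, sinks=None):
    # Fixed signature: `sinks` is optional and trails the other arguments.
    return q

q = k = v = "tensor"
metadata, mesh = "metadata", "mesh"

# A caller written without `sinks` binds `metadata` to `sinks` and `mesh`
# to `attention_metadata`, then fails because `mesh` is left unfilled:
try:
    attention_old(q, k, v, metadata, mesh)
except TypeError as e:
    print(e)  # attention_old() missing 1 required positional argument: 'mesh'

# Against the fixed signature the same call binds as intended,
# with `sinks` defaulting to None:
attention_new(q, k, v, metadata, mesh)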

tpu_inference/layers/vllm/attention.py

Lines changed: 1 addition & 1 deletion
@@ -178,12 +178,12 @@ def _jax_attn_func(
         q,
         k,
         v,
-        sinks,
         attention_metadata,
         mesh,
         q_scale=q_scale,
         k_scale=k_scale,
         v_scale=v_scale,
+        sinks=sinks,
     )

     # Convert the shape back to vLLM's convention
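
At the call site the same shift applies in reverse: once `sinks` moves to the tail of the signature, passing it positionally in fourth place would bind it to `attention_metadata` and push every later positional argument into the wrong slot. A sketch under the same assumptions (simplified signature, placeholder values; not the real code):

# Simplified stand-in for the reordered attention() signature.
def attention(q, k, v, attention_metadata, mesh,
              q_scale=None, k_scale=None, v_scale=None, sinks=None):
    return attention_metadata  # placeholder body

q = k = v = sinks = "tensor"
metadata, mesh = "metadata", "mesh"

# The old call shape now misbinds: `sinks` lands in `attention_metadata`
# and `mesh` lands in `q_scale`, colliding with the explicit keyword:
try:
    attention(q, k, v, sinks, metadata, mesh, q_scale=1.0)
except TypeError as e:
    print(e)  # attention() got multiple values for argument 'q_scale'

# The fixed call shape from the diff passes `sinks` by keyword:
attention(q, k, v, metadata, mesh,
          q_scale=1.0, k_scale=1.0, v_scale=1.0, sinks=sinks)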
