Commit 62763b5

[Torchax] Add attention sink support in torchax (#1038)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 330cb1b commit 62763b5

File tree

3 files changed: +128 -47 lines changed
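For readers unfamiliar with the term: an attention sink here is a learned per-head logit that joins the softmax normalization without contributing a value vector, letting a head dump probability mass instead of spreading it over real tokens. A minimal JAX sketch of that formulation (an illustration only, not the Pallas kernel added in this change; the function name is hypothetical):

import jax
import jax.numpy as jnp

def reference_attention_with_sink(q, k, v, sink, sm_scale):
    # q: (num_heads, head_dim); k, v: (seq_len, num_heads, head_dim)
    # sink: (num_heads,) learned logit, one per query head.
    logits = jnp.einsum("nd,snd->ns", q, k) * sm_scale           # (heads, seq)
    logits = jnp.concatenate([logits, sink[:, None]], axis=-1)   # append sink logit
    probs = jax.nn.softmax(logits, axis=-1)
    # The sink column absorbs probability mass but contributes no value.
    return jnp.einsum("ns,snd->nd", probs[:, :-1], v)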

tests/layers/vllm/test_attention.py

Lines changed: 49 additions & 11 deletions
@@ -39,31 +39,42 @@
 MAX_BLOCKS_PER_SEQ = 8
 
 
-def create_inputs(mesh: Mesh,
-                  q_dtype: jnp.dtype = jnp.bfloat16,
-                  kv_dtype: jnp.dtype = jnp.bfloat16):
+def create_inputs(
+    mesh: Mesh,
+    q_dtype: jnp.dtype = jnp.bfloat16,
+    kv_dtype: jnp.dtype = jnp.bfloat16,
+    total_tokens: int = TOTAL_TOKENS,
+    num_seqs: int = NUM_SEQS,
+    max_num_seqs: int = MAX_NUM_SEQS,
+    num_heads: int = NUM_HEADS,
+    num_kv_heads: int = NUM_KV_HEADS,
+    head_dim: int = HEAD_DIM,
+    num_blocks: int = NUM_BLOCKS,
+    block_size: int = BLOCK_SIZE,
+    max_blocks_per_seq: int = MAX_BLOCKS_PER_SEQ,
+):
     key = jax.random.key(0)
-    q = jax.random.uniform(key, (TOTAL_TOKENS, NUM_HEADS * HEAD_DIM),
+    q = jax.random.uniform(key, (total_tokens, num_heads * head_dim),
                            dtype=q_dtype)
-    k = jax.random.uniform(key, (TOTAL_TOKENS, NUM_KV_HEADS * HEAD_DIM),
+    k = jax.random.uniform(key, (total_tokens, num_kv_heads * head_dim),
                            dtype=q_dtype)
-    v = jax.random.uniform(key, (TOTAL_TOKENS, NUM_KV_HEADS * HEAD_DIM),
+    v = jax.random.uniform(key, (total_tokens, num_kv_heads * head_dim),
                            dtype=q_dtype)
     q = torch_view(q)
     k = torch_view(k)
     v = torch_view(v)
 
-    kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, NUM_BLOCKS, BLOCK_SIZE,
-                                                  NUM_KV_HEADS, HEAD_DIM,
+    kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
+                                                  num_kv_heads, head_dim,
                                                   kv_dtype)
     kv_cache = jax.random.normal(key, kv_cache_shape, dtype=kv_dtype)
 
-    positions = jnp.ones((TOTAL_TOKENS, ), dtype=jnp.int32)
-    block_tables = jnp.zeros((MAX_NUM_SEQS * MAX_BLOCKS_PER_SEQ),
+    positions = jnp.ones((total_tokens, ), dtype=jnp.int32)
+    block_tables = jnp.zeros((max_num_seqs * max_blocks_per_seq),
                              dtype=jnp.int32).reshape(-1)
     seq_lens = jnp.array([5, 5, 0, 0], dtype=jnp.int32)
     query_start_loc = jnp.array([0, 5, 10, 10, 10], dtype=jnp.int32)
-    request_distribution = jnp.array([0, 0, NUM_SEQS], dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, num_seqs], dtype=jnp.int32)
 
     metadata = AttentionMetadata(
         input_positions=positions,
@@ -276,3 +287,30 @@ def test_forward_with_output_scale_raises_error(self, mesh):
                          torch.tensor([]),
                          metadata,
                          output_scale=output_scale)
+
+    def test_forward_with_attention_sink(self, mesh):
+        head_dim = 64
+        sinks = torch.rand([NUM_HEADS], dtype=torch.float32)
+
+        impl = PallasAttentionBackendImpl(num_heads=NUM_HEADS,
+                                          head_size=head_dim,
+                                          scale=0.088,
+                                          num_kv_heads=NUM_KV_HEADS,
+                                          alibi_slopes=None,
+                                          sliding_window=None,
+                                          kv_cache_dtype="auto",
+                                          attn_type=AttentionType.DECODER,
+                                          sinks=sinks)
+
+        layer = MagicMock()
+        layer.layer_name = "0"
+
+        query, key, value, kv_cache, metadata = create_inputs(
+            mesh, head_dim=head_dim)
+
+        with torchax.default_env(), set_vllm_model_wrapper_context(
+                kv_caches=[kv_cache],
+                mesh=mesh,
+                layer_name_to_kvcache_index={'0': 0}):
+            assert impl.sinks is not None
+            impl.forward(layer, query, key, value, torch.tensor([]), metadata)
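Because create_inputs now exposes every dimension as a keyword argument with the old module-level constants as defaults, a test can override only the dimension it cares about, as the new head_dim=64 case does. A hedged usage sketch (constants as defined earlier in this test module):

# Build inputs for the head_dim == 64 kernel path; all other dimensions keep their defaults.
query, key, value, kv_cache, metadata = create_inputs(mesh, head_dim=64)
assert query.shape == (TOTAL_TOKENS, NUM_HEADS * 64)    # q is (tokens, num_heads * head_dim)
assert key.shape == (TOTAL_TOKENS, NUM_KV_HEADS * 64)   # k/v use the KV-head count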

tpu_inference/layers/jax/attention_interface.py

Lines changed: 55 additions & 23 deletions
@@ -14,6 +14,7 @@
 from jax.sharding import PartitionSpec as P
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
+import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.jax.sharding import ShardingAxisName
@@ -26,6 +27,9 @@
 ragged_paged_attention = rpa.ragged_paged_attention
 get_kv_cache_shape = rpa.get_kv_cache_shape
 
+ragged_paged_attention_hd64 = rpa_hd64.ragged_paged_attention_hd64
+get_kv_cache_shape_hd64 = rpa_hd64.get_kv_cache_shape
+
 
 def sharded_flash_attention(
     mesh: Mesh,
@@ -268,17 +272,27 @@ def sharded_splash_attention(
 
 
 def sharded_ragged_paged_attention(
-    sm_scale: float,
     mesh: Mesh,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    kv_cache: jax.Array,
+    kv_lens: jax.Array,
+    page_indices: jax.Array,
+    cu_q_lens: jax.Array,
+    distribution: jax.Array,
+    attention_sink: jax.Array | None,
+    sm_scale: float,
     attention_chunk_size: int | None = None,
     q_scale: float | None = None,
    k_scale: float | None = None,
     v_scale: float | None = None,
 ):
     """Shards along KV heads."""
 
-    qkv_spec = P(ShardingAxisName.ATTN_DATA, "model", None)
-    kv_cache_spec = P(ShardingAxisName.ATTN_DATA, None, "model")
+    qkv_spec = P(ShardingAxisName.ATTN_DATA, ShardingAxisName.ATTN_HEAD, None)
+    kv_cache_spec = P(ShardingAxisName.ATTN_DATA, None,
+                      ShardingAxisName.ATTN_HEAD, None, None)
     in_specs = (
         qkv_spec,  # q
         qkv_spec,  # k
@@ -291,8 +305,21 @@ def sharded_ragged_paged_attention(
     )
     out_specs = (qkv_spec, kv_cache_spec)
 
+    args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)
+
+    use_hd64 = q.shape[-1] == 64
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+    if attention_sink is not None:
+        if not use_hd64:
+            raise NotImplementedError(
+                "Attention sink support is only available when head_dim==64")
+
+        in_specs += (P(ShardingAxisName.ATTN_HEAD), )
+        args += (attention_sink, )
+
     def _ragged_paged_attention(*args):
-        return ragged_paged_attention(
+        return func(
             *args,
             sm_scale=sm_scale,
             sliding_window=attention_chunk_size,
@@ -301,21 +328,21 @@ def _ragged_paged_attention(*args):
             v_scale=v_scale,
         )
 
-    return jax.jit(
-        shard_map.shard_map(
-            _ragged_paged_attention,
-            mesh=mesh,
-            in_specs=in_specs,
-            out_specs=out_specs,
-            check_rep=False,
-        ))
+    return shard_map.shard_map(
+        _ragged_paged_attention,
+        mesh=mesh,
+        in_specs=in_specs,
+        out_specs=out_specs,
+        check_rep=False,
+    )(*args)
 
 
 def attention(
     kv_cache: jax.Array,
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
+    sinks: jax.Array | None,
     attention_metadata: AttentionMetadata,
     mesh: Mesh,
     head_dim_original: int | None = None,  # before padding,
@@ -343,16 +370,21 @@ def attention(
 
     # (T, N, H)
     output, kv_cache = sharded_ragged_paged_attention(
-        head_dim_original**-0.5, mesh, attention_chunk_size, q_scale, k_scale,
-        v_scale)(
-            q,
-            k,
-            v,
-            kv_cache,
-            md.seq_lens,
-            md.block_tables,
-            md.query_start_loc,
-            md.request_distribution,
-        )
+        mesh,
+        q,
+        k,
+        v,
+        kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+        sinks,
+        sm_scale=head_dim_original**-0.5,
+        attention_chunk_size=attention_chunk_size,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
+    )
 
     return kv_cache, output
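The new in_spec P(ShardingAxisName.ATTN_HEAD) partitions the (num_heads,) sink vector across the same mesh axis as the query/KV heads, so each shard's kernel invocation only sees the sinks for its local heads. A toy shard_map sketch of that behaviour (a standalone illustration with a hypothetical "model" axis, not code from this change):

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model", ))
num_heads = 8 * len(jax.devices())  # assume the head count divides the mesh size
sinks = jnp.arange(num_heads, dtype=jnp.float32)

@functools.partial(shard_map.shard_map,
                   mesh=mesh,
                   in_specs=(P("model"), ),
                   out_specs=P("model"))
def per_shard(local_sinks):
    # Each shard receives only its slice, of shape (num_heads // num_devices,).
    return local_sinks * 2.0

print(per_shard(sinks).shape)  # (num_heads,) after the outputs are reassembled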

tpu_inference/layers/vllm/attention.py

Lines changed: 24 additions & 13 deletions
@@ -4,9 +4,11 @@
 from typing import Optional, Tuple
 
 import jax
+import jax.numpy as jnp
 import torch
 from jax.sharding import Mesh
 from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 
@@ -39,18 +41,14 @@ def __init__(
         head_size: int,
         scale: float,
         num_kv_heads: int,
-        alibi_slopes: Optional[list[float]],
-        sliding_window: Optional[int],
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
         kv_cache_dtype: str,
-        logits_soft_cap: Optional[float] = None,
-        attn_type: str = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[int] = None,
-        use_irope: bool = False,
+        logits_soft_cap: float | None = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+        sinks: torch.Tensor | None = None,
     ) -> None:
-        if use_irope:
-            logger.warning_once(
-                "Using irope in Pallas is not supported yet, it will fall back "
-                "to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -73,6 +71,14 @@ def __init__(
                 "are not implemented for "
                 "PallasAttentionBackendImpl")
 
+        #TODO (kyuyeunk): Shard the sinks along head axis.
+        self.sinks = sinks
+        if self.sinks is not None:
+            self.sinks = t2j(self.sinks, use_dlpack=False).astype(jnp.float32)
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -115,9 +121,12 @@ def forward(
         k_scale = layer._k_scale_float
         v_scale = layer._v_scale_float
 
+        sinks = None if self.sinks is None else jax_view(self.sinks)
+
         new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
-                                               attn_metadata, mesh, self.scale,
-                                               self.head_size, self.num_heads,
+                                               sinks, attn_metadata, mesh,
+                                               self.scale, self.head_size,
+                                               self.num_heads,
                                                self.num_kv_heads, q_scale,
                                                k_scale, v_scale)
         vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
@@ -128,7 +137,7 @@ def forward(
 @functools.partial(
     jax.jit,
     static_argnums=(
-        5, 6, 7, 8, 9, 10, 11, 12
+        6, 7, 8, 9, 10, 11, 12, 13
     ),  # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
     donate_argnums=(0, ),  # donate kv_cache
 )
@@ -137,6 +146,7 @@ def _jax_attn_func(
     q: jax.Array,
     k: jax.Array,
     v: jax.Array,
+    sinks: jax.Array | None,
     attention_metadata: AttentionMetadata,
     mesh: Mesh,
     scale: float,
@@ -168,6 +178,7 @@ def _jax_attn_func(
     q,
     k,
     v,
+    sinks,
     attention_metadata,
     mesh,
     q_scale=q_scale,
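The static_argnums change follows mechanically from the new signature: sinks is inserted as a traced (non-static) positional argument ahead of mesh, so every static index shifts up by one (5..12 becomes 6..13). A small standalone illustration of the indexing rule, with hypothetical argument names:

import functools
import jax
import jax.numpy as jnp

# After inserting `sinks` before `scale`, the static argument sits at index 3, not 2.
@functools.partial(jax.jit, static_argnums=(3, ))
def scaled_dot(q, k, sinks, scale):
    # `sinks` is traced like q and k; only `scale` (index 3) is treated as static.
    logits = (q @ k.T) * scale
    return logits if sinks is None else logits + sinks[:, None]

print(scaled_dot(jnp.ones((2, 4)), jnp.ones((3, 4)), None, 2.0).shape)           # (2, 3)
print(scaled_dot(jnp.ones((2, 4)), jnp.ones((3, 4)), jnp.zeros(2), 2.0).shape)   # (2, 3)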
