
Commit e05d4b1

[Bugfix] Fix attention interface unit test (#1048)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent bf8aa2e commit e05d4b1

File tree: 1 file changed (+49, -20 lines)


tests/models/jax/test_attention_interface.py

Lines changed: 49 additions & 20 deletions
@@ -22,10 +22,6 @@
 NUM_HEADS = 8
 # Number of attention heads (Key/Value) - for Grouped-Query Attention
 NUM_KV_HEADS = 4
-# Dimension of each attention head
-HEAD_DIM = 64
-# Padded head dimension
-PADDED_HEAD_DIM = 64
 # Total number of blocks in the KV cache
 NUM_BLOCKS = 32
 # Number of tokens per block
@@ -49,7 +45,7 @@ def mesh():
 # ---- Test for `attention` ----
 
 
-def test_attention(monkeypatch, mesh):
+def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
     """
     Tests the main `attention` function.
 
@@ -62,23 +58,35 @@ def test_attention(monkeypatch, mesh):
     # Create input tensors
     q_dtype = jnp.float32
     kv_dtype = jnp.float32
-    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, PADDED_HEAD_DIM), dtype=q_dtype)
-    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, PADDED_HEAD_DIM), dtype=kv_dtype)
-    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, PADDED_HEAD_DIM), dtype=kv_dtype)
-
-    kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, NUM_BLOCKS, BLOCK_SIZE,
-                                                  NUM_KV_HEADS, HEAD_DIM,
-                                                  kv_dtype)
+    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
+    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None
+
+    kv_cache_shape = get_kv_cache_shape_with_mesh(
+        mesh,
+        NUM_BLOCKS,
+        BLOCK_SIZE,
+        NUM_KV_HEADS,
+        head_dim,
+        kv_dtype,
+    )
     kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)
 
     # Mock ragged_paged_attention to return a tensor of the correct shape
-    mock_paged_attn_kernel = MagicMock(
-        return_value=(jnp.ones((TOTAL_TOKENS, NUM_HEADS, PADDED_HEAD_DIM)),
-                      kv_cache))
-    monkeypatch.setattr(
-        "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
-        mock_paged_attn_kernel,
-    )
+    mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
+        (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )
+
+    if head_dim == 64:
+        monkeypatch.setattr(
+            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention_hd64",
+            mock_paged_attn_kernel,
+        )
+    else:
+        monkeypatch.setattr(
+            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
+            mock_paged_attn_kernel,
+        )
 
     # Create AttentionMetadata
     attention_metadata = AttentionMetadata(
@@ -98,7 +106,8 @@ def test_attention(monkeypatch, mesh):
         v=v,
         attention_metadata=attention_metadata,
         mesh=mesh,
-        head_dim_original=HEAD_DIM,
+        head_dim_original=head_dim,
+        sinks=sinks,
     )
 
     # 3. Assert
@@ -111,3 +120,23 @@ def test_attention(monkeypatch, mesh):
 
     # Check that the output is the one from our mock
     assert jnp.all(output == 1.0)
+
+
+def test_attention(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 128)
+
+
+def test_attention_hd64(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64)
+
+
+def test_attention_sink(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64, True)
+
+
+def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
+    with pytest.raises(
+            NotImplementedError,
+            match="Attention sink support is only available when head_dim==64"
+    ):
+        _test_attention(monkeypatch, mesh, 128, True)
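
Taken together, the new cases encode a dispatch rule the attention interface is expected to follow: the test patches ragged_paged_attention_hd64 when head_dim == 64, the generic ragged_paged_attention otherwise, and passing sinks with any other head dimension must raise NotImplementedError. The snippet below is a minimal, hypothetical sketch of that rule for readers of this diff; select_kernel_target is an illustrative helper, not part of tpu_inference, and the real interface may implement the check differently. Only the two monkeypatch targets and the error message come from the diff itself.

# Hypothetical sketch of the kernel-selection rule exercised by the tests above.
# The two dotted paths and the error message are taken from the diff; the helper
# itself is illustrative and not part of tpu_inference.

_HD64_KERNEL = "tpu_inference.layers.jax.attention_interface.ragged_paged_attention_hd64"
_GENERIC_KERNEL = "tpu_inference.layers.jax.attention_interface.ragged_paged_attention"


def select_kernel_target(head_dim: int, sinks=None) -> str:
    """Return the dotted path of the kernel a test would monkeypatch."""
    if sinks is not None and head_dim != 64:
        # Matches the error asserted by test_attention_sink_no_64_raises_error.
        raise NotImplementedError(
            "Attention sink support is only available when head_dim==64")
    return _HD64_KERNEL if head_dim == 64 else _GENERIC_KERNEL

Under that assumption, select_kernel_target(64) and select_kernel_target(128) correspond to the two monkeypatch branches in _test_attention, and the new cases can be run with pytest tests/models/jax/test_attention_interface.py -k "test_attention".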
