 NUM_HEADS = 8
 # Number of attention heads (Key/Value) - for Grouped-Query Attention
 NUM_KV_HEADS = 4
-# Dimension of each attention head
-HEAD_DIM = 64
-# Padded head dimension
-PADDED_HEAD_DIM = 64
 # Total number of blocks in the KV cache
 NUM_BLOCKS = 32
 # Number of tokens per block
@@ -49,7 +45,7 @@ def mesh():
 # ---- Test for `attention` ----
 
 
-def test_attention(monkeypatch, mesh):
+def _test_attention(monkeypatch, mesh, head_dim, use_sinks=False):
     """
     Tests the main `attention` function.
 
@@ -62,23 +58,35 @@ def test_attention(monkeypatch, mesh):
     # Create input tensors
     q_dtype = jnp.float32
     kv_dtype = jnp.float32
-    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, PADDED_HEAD_DIM), dtype=q_dtype)
-    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, PADDED_HEAD_DIM), dtype=kv_dtype)
-    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, PADDED_HEAD_DIM), dtype=kv_dtype)
-
-    kv_cache_shape = get_kv_cache_shape_with_mesh(mesh, NUM_BLOCKS, BLOCK_SIZE,
-                                                  NUM_KV_HEADS, HEAD_DIM,
-                                                  kv_dtype)
+    q = jnp.ones((TOTAL_TOKENS, NUM_HEADS, head_dim), dtype=q_dtype)
+    k = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    v = jnp.ones((TOTAL_TOKENS, NUM_KV_HEADS, head_dim), dtype=kv_dtype)
+    sinks = jnp.ones((NUM_HEADS, ), dtype=jnp.float32) if use_sinks else None
+
+    kv_cache_shape = get_kv_cache_shape_with_mesh(
+        mesh,
+        NUM_BLOCKS,
+        BLOCK_SIZE,
+        NUM_KV_HEADS,
+        head_dim,
+        kv_dtype,
+    )
     kv_cache = jnp.zeros(kv_cache_shape, dtype=kv_dtype)
 
     # Mock ragged_paged_attention to return a tensor of the correct shape
-    mock_paged_attn_kernel = MagicMock(
-        return_value=(jnp.ones((TOTAL_TOKENS, NUM_HEADS, PADDED_HEAD_DIM)),
-                      kv_cache))
-    monkeypatch.setattr(
-        "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
-        mock_paged_attn_kernel,
-    )
+    mock_paged_attn_kernel = MagicMock(return_value=(jnp.ones(
+        (TOTAL_TOKENS, NUM_HEADS, head_dim)), kv_cache), )
+
+    if head_dim == 64:
+        monkeypatch.setattr(
+            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention_hd64",
+            mock_paged_attn_kernel,
+        )
+    else:
+        monkeypatch.setattr(
+            "tpu_inference.layers.jax.attention_interface.ragged_paged_attention",
+            mock_paged_attn_kernel,
+        )
 
     # Create AttentionMetadata
     attention_metadata = AttentionMetadata(
@@ -98,7 +106,8 @@ def test_attention(monkeypatch, mesh):
         v=v,
         attention_metadata=attention_metadata,
         mesh=mesh,
-        head_dim_original=HEAD_DIM,
+        head_dim_original=head_dim,
+        sinks=sinks,
     )
 
     # 3. Assert
@@ -111,3 +120,23 @@ def test_attention(monkeypatch, mesh):
 
     # Check that the output is the one from our mock
     assert jnp.all(output == 1.0)
+
+
+def test_attention(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 128)
+
+
+def test_attention_hd64(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64)
+
+
+def test_attention_sink(monkeypatch, mesh):
+    _test_attention(monkeypatch, mesh, 64, True)
+
+
+def test_attention_sink_no_64_raises_error(monkeypatch, mesh):
+    with pytest.raises(
+            NotImplementedError,
+            match="Attention sink support is only available when head_dim==64"
+    ):
+        _test_attention(monkeypatch, mesh, 128, True)
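A minimal sketch of an alternative layout for the new wrapper tests, assuming the `_test_attention` helper and the `monkeypatch`/`mesh` fixtures above and that `pytest` is already imported in this module (`test_attention_parametrized` is a hypothetical name, not part of the change): the three passing cases could also be expressed with `pytest.mark.parametrize`, keeping the error-raising case as its own test since it needs `pytest.raises`.

@pytest.mark.parametrize("head_dim, use_sinks", [(128, False), (64, False),
                                                 (64, True)])
def test_attention_parametrized(monkeypatch, mesh, head_dim, use_sinks):
    # Delegates to the shared helper; head_dim selects which mocked kernel
    # (ragged_paged_attention_hd64 vs ragged_paged_attention) is patched in.
    _test_attention(monkeypatch, mesh, head_dim, use_sinks)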