
Commit 88f2071

[KV] Change padding logic when head_dim is 64 (#1064)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 0c66fde commit 88f2071

File tree

2 files changed: +9 -3 lines changed


tests/runner/test_kv_cache_manager.py

Lines changed: 5 additions & 3 deletions
@@ -3,6 +3,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+import pytest
 import torch
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
@@ -200,14 +201,15 @@ def test_insert_request_with_kv_cache(self):
         np.testing.assert_array_equal(updated_block_content,
                                       expected_padded_slice)
 
-    def test_get_kv_cache_spec_with_compilation_cfg(self):
+    @pytest.mark.parametrize("num_kv_heads", [16, 32])
+    @pytest.mark.parametrize("head_size", [64, 100, 200])
+    def test_get_kv_cache_spec_with_compilation_cfg(self, num_kv_heads,
+                                                    head_size):
         # tests we create kv cache spec from compilation config
         # create a static forward context with
         # 10 full attention layers +
         # 10 sliding window attention layers
         # 1 layer with shared kv cache.
-        num_kv_heads = 16
-        head_size = 128
         attn_type = AttentionType.DECODER
         sliding_window = 10
         static_forward_context = {}
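
For reference, a minimal sketch (not part of the commit) of how the stacked parametrize decorators behave: pytest expands them into the cross-product of the value lists, so the test body runs for all 2 x 3 combinations of num_kv_heads and head_size, and head_size=64 now exercises the no-padding path added in tpu_inference/utils.py. The test name below is hypothetical; only get_padded_head_dim comes from this commit.

import pytest

from tpu_inference.utils import get_padded_head_dim


@pytest.mark.parametrize("num_kv_heads", [16, 32])
@pytest.mark.parametrize("head_size", [64, 100, 200])
def test_padded_head_dim_across_cases(num_kv_heads, head_size):
    # Each of the 2 x 3 = 6 (num_kv_heads, head_size) combinations runs as
    # its own test case; num_kv_heads is carried along only to show the
    # cross-product expansion.
    padded = get_padded_head_dim(head_size)
    if head_size == 64:
        # Special-cased: the head_dim-64 kernel needs no padding.
        assert padded == 64
    else:
        # Everything else is rounded up to a multiple of 128.
        assert padded % 128 == 0 and padded >= head_size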

tpu_inference/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
150150

151151
def get_padded_head_dim(head_dim: int) -> int:
152152
"""Pads head_dim up to the nearest multiple of 128 for kernel performance."""
153+
# When head_dim == 64, we use kernel specificly optimized for it which does
154+
# not require any padding.
155+
if head_dim == 64:
156+
return 64
153157
return (head_dim + 127) // 128 * 128
154158

155159
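
A minimal usage sketch (not part of the commit), assuming get_padded_head_dim is importable from tpu_inference.utils as defined above:

from tpu_inference.utils import get_padded_head_dim

assert get_padded_head_dim(64) == 64    # special-cased: the head_dim-64 kernel needs no padding
assert get_padded_head_dim(100) == 128  # rounded up to the next multiple of 128
assert get_padded_head_dim(128) == 128  # already a multiple of 128
assert get_padded_head_dim(200) == 256  # (200 + 127) // 128 * 128 == 256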

0 commit comments