
Commit 7da48cd

Randall Smith (rasmith) authored and tjtanaa committed

[Bugfix][CI/Test][Spec Decode] Fix illegal memory access in offline_inference/spec_decode.py (Issue 27619) (vllm-project#28432)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Signed-off-by: George D. Torres <gdavtor@gmail.com>
1 parent 1cec54d commit 7da48cd

File tree

1 file changed: +4 −2 lines changed

vllm/attention/ops/triton_reshape_and_cache_flash.py

Lines changed: 4 additions & 2 deletions

@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
     k_scale: torch.Tensor,  # float32
     v_scale: torch.Tensor,  # float32
 ):
-    num_tokens = key.shape[0]
     num_heads = key.shape[1]
     head_size = key.shape[2]
     block_size = key_cache.shape[1]
@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash(

     # TODO(ngl): maybe replace with static launch grid to avoid overhead if
     # using cudagraphs
-    grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"]))
+    grid = lambda meta: (
+        slot_mapping.shape[0],
+        triton.cdiv(n, meta["TILE_SIZE"]),
+    )

     reshape_and_cache_kernel_flash[grid](
         key_ptr=key,
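A minimal sketch (not part of the commit) of why this grid change can prevent an illegal memory access: axis 0 of the launch grid determines how many per-token programs run, and each program indexes into slot_mapping. If key carries more rows than slot_mapping has entries (for example, when inputs are padded), sizing the grid by key.shape[0] launches programs whose slot_mapping reads fall out of bounds, while sizing it by slot_mapping.shape[0] does not. All concrete shapes below are illustrative assumptions, and launch_grid is a hypothetical stand-in for the Triton grid lambda.

```python
import math

TILE_SIZE = 128  # illustrative tile size (assumption)

def launch_grid(num_programs_axis0: int, n: int, tile_size: int) -> tuple:
    """Mimic the kernel's 2D launch grid: one program per token on
    axis 0, one per TILE_SIZE-wide tile of the n elements on axis 1
    (triton.cdiv is just a ceiling division)."""
    return (num_programs_axis0, math.ceil(n / tile_size))

# Assumed shapes: key padded to 16 rows, but only 10 valid tokens
# are described by slot_mapping.
key_shape = (16, 8, 128)          # (padded_num_tokens, num_heads, head_size)
slot_mapping_len = 10             # number of valid tokens
n = key_shape[1] * key_shape[2]   # elements cached per token

# Old grid: one program per row of key, including the 6 padding rows,
# whose slot_mapping entries do not exist -> out-of-bounds access.
old_grid = launch_grid(key_shape[0], n, TILE_SIZE)

# Fixed grid: one program per slot_mapping entry only.
new_grid = launch_grid(slot_mapping_len, n, TILE_SIZE)

print(old_grid, new_grid)  # (16, 8) (10, 8)
```

Under these assumed shapes, the fixed grid launches 6 fewer token-programs along axis 0, exactly the padding rows that have no slot_mapping entry.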
