Skip to content

Commit 4a8d6bd

Browse files
authored
Fix cu_num_generated_tokens slicing logic in LogprobsLists.slice() method (#28214)
Signed-off-by: Bradley <bradley.b.pitt@gmail.com>
1 parent 636efd1 commit 4a8d6bd

File tree

2 files changed

+111
-3
lines changed

2 files changed

+111
-3
lines changed

tests/v1/test_outputs.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from unittest import TestCase
4+
5+
from vllm.v1.outputs import LogprobsLists
6+
7+
8+
class TestLogprobsLists(TestCase):
9+
def setUp(self):
10+
self.logprobsLists = LogprobsLists(
11+
logprob_token_ids=[
12+
[1, 2], # Request 0 token 0
13+
[3, 4], # Request 0 token 1
14+
[5, 6], # Request 1 token 0
15+
[7, 8], # Request 1 token 1
16+
[9, 10], # Request 1 token 2
17+
[11, 12], # Request 2 token 0
18+
[13, 14], # Request 2 token 1
19+
[15, 16], # Request 2 token 2
20+
[17, 18], # Request 2 token 3
21+
],
22+
logprobs=[
23+
[0.1, 0.2],
24+
[0.3, 0.4],
25+
[0.5, 0.6],
26+
[0.7, 0.8],
27+
[0.9, 1.0],
28+
[1.1, 1.2],
29+
[1.3, 1.4],
30+
[1.5, 1.6],
31+
[1.7, 1.8],
32+
],
33+
sampled_token_ranks=[1, 3, 5, 7, 9, 11, 13, 15, 17],
34+
cu_num_generated_tokens=[0, 2, 5, 9],
35+
)
36+
37+
def test_slice_without_cu_num_generated_tokens(self):
38+
"""Test slicing without cu_num_generated_tokens"""
39+
logprobsLists = LogprobsLists(
40+
logprob_token_ids=[[1], [2], [3]],
41+
logprobs=[[0.1], [0.2], [0.3]],
42+
sampled_token_ranks=[1, 2, 3],
43+
cu_num_generated_tokens=None,
44+
)
45+
46+
sliced = logprobsLists.slice(1, 3)
47+
assert sliced.logprob_token_ids == [[2], [3]]
48+
assert sliced.logprobs == [[0.2], [0.3]]
49+
assert sliced.sampled_token_ranks == [2, 3]
50+
assert sliced.cu_num_generated_tokens is None
51+
52+
def test_slice_from_start(self):
53+
"""Test slicing from the start position"""
54+
sliced = self.logprobsLists.slice(0, 2)
55+
assert len(sliced.logprob_token_ids) == 5
56+
assert sliced.logprob_token_ids == [
57+
[1, 2],
58+
[3, 4],
59+
[5, 6],
60+
[7, 8],
61+
[9, 10],
62+
]
63+
assert sliced.cu_num_generated_tokens == [0, 2, 5]
64+
65+
def test_slice_from_middle(self):
66+
"""Test slicing from the middle position"""
67+
sliced = self.logprobsLists.slice(1, 3)
68+
assert len(sliced.logprob_token_ids) == 7
69+
assert sliced.logprob_token_ids == [
70+
[5, 6],
71+
[7, 8],
72+
[9, 10],
73+
[11, 12],
74+
[13, 14],
75+
[15, 16],
76+
[17, 18],
77+
]
78+
assert sliced.cu_num_generated_tokens == [0, 3, 7]
79+
80+
def test_slice_single_request(self):
81+
"""Test slicing a single request"""
82+
sliced = self.logprobsLists.slice(1, 2)
83+
assert len(sliced.logprob_token_ids) == 3
84+
assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
85+
assert sliced.cu_num_generated_tokens == [0, 3]
86+
87+
def test_slice_last_request(self):
88+
"""Test slicing the last request"""
89+
sliced = self.logprobsLists.slice(2, 3)
90+
assert len(sliced.logprob_token_ids) == 4
91+
assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
92+
assert sliced.cu_num_generated_tokens == [0, 4]
93+
94+
def test_slice_all_requests(self):
95+
"""Test slicing all requests (full slice)"""
96+
sliced = self.logprobsLists.slice(0, 3)
97+
assert len(sliced.logprob_token_ids) == 9 # All tokens
98+
assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
99+
assert (
100+
sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
101+
)

vllm/v1/outputs.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,23 @@ def slice(self, start_req_idx: int, end_req_idx: int):
3030
if self.cu_num_generated_tokens:
3131
start = self.cu_num_generated_tokens[start_req_idx]
3232
end = self.cu_num_generated_tokens[end_req_idx]
33+
# Recompute cumulative array starting from 0
34+
cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
35+
sliced_cu_num_generated_tokens = [
36+
cu_num - cu_num_offset
37+
for cu_num in self.cu_num_generated_tokens[
38+
start_req_idx : end_req_idx + 1
39+
]
40+
]
3341
else:
3442
start = start_req_idx
3543
end = end_req_idx
44+
sliced_cu_num_generated_tokens = None
3645
return LogprobsLists(
3746
self.logprob_token_ids[start:end],
3847
self.logprobs[start:end],
3948
self.sampled_token_ranks[start:end],
40-
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
41-
if self.cu_num_generated_tokens
42-
else None,
49+
sliced_cu_num_generated_tokens,
4350
)
4451

4552

0 commit comments

Comments
 (0)