10 changes: 3 additions & 7 deletions tests/samplers/test_logprobs.py
@@ -4,7 +4,7 @@
 import pytest

 from vllm import SamplingParams
-from vllm.logprobs import FlattenLogprobs
+from vllm.logprobs import FlatLogprobs

 MODELS = ["distilbert/distilgpt2"]
 MAX_TOKENS = 5
@@ -44,12 +44,8 @@ def test_ranks(
     decode_tokens, _, decode_logprobs, prompt_logprobs = result

     # Ensure the return type of logprobs is accurate
-    assert isinstance(
-        prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
-    )
-    assert isinstance(
-        decode_logprobs, FlattenLogprobs if flatten_logprobs else list
-    )
+    assert isinstance(prompt_logprobs, FlatLogprobs if flatten_logprobs else list)
+    assert isinstance(decode_logprobs, FlatLogprobs if flatten_logprobs else list)

     ########################
     # Check prompt logprobs
18 changes: 9 additions & 9 deletions tests/test_logprobs.py
@@ -5,7 +5,7 @@
 import pytest

 from vllm.logprobs import (
-    FlattenLogprobs,
+    FlatLogprobs,
     Logprob,
     LogprobsOnePosition,
     append_logprobs_for_next_position,
@@ -32,7 +32,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")

     prompt_logprobs = create_prompt_logprobs()
-    assert isinstance(prompt_logprobs, FlattenLogprobs)
+    assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
     assert len(prompt_logprobs.token_ids) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert prompt_logprobs[0] == dict()

     sample_logprobs = create_sample_logprobs()
-    assert isinstance(sample_logprobs, FlattenLogprobs)
+    assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
     assert len(sample_logprobs.token_ids) == 0
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
         rank=11,
         num_logprobs=-1,
     )
-    assert isinstance(logprobs, FlattenLogprobs)
+    assert isinstance(logprobs, FlatLogprobs)
     assert logprobs.start_indices == [0, 1]
     assert logprobs.end_indices == [1, 3]
     assert logprobs.token_ids == [1, 2, 3]
@@ -130,7 +130,7 @@ def test_append_logprobs_for_next_position_flatten(


 def test_flatten_logprobs_append() -> None:

Contributor review comment (severity: high):
For consistency with the class rename from FlattenLogprobs to FlatLogprobs, this test function should be renamed to test_flat_logprobs_append.
Suggested change:
-def test_flatten_logprobs_append() -> None:
+def test_flat_logprobs_append() -> None:

-    logprobs = FlattenLogprobs()
+    logprobs = FlatLogprobs()
     logprobs.append(LOGPROBS_ONE_POSITION_0)
     logprobs.append(LOGPROBS_ONE_POSITION_1)
     assert logprobs.start_indices == [0, 1]
@@ -150,7 +150,7 @@ def test_flatten_logprobs_append() -> None:


 def test_flatten_logprobs_extend() -> None:

Contributor review comment (severity: high):
To maintain consistency with the refactoring, this test function should be renamed to test_flat_logprobs_extend.
Suggested change:
-def test_flatten_logprobs_extend() -> None:
+def test_flat_logprobs_extend() -> None:

-    logprobs = FlattenLogprobs()
+    logprobs = FlatLogprobs()
     # Extend with list[LogprobsOnePosition]
     logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
     assert logprobs.start_indices == [0, 3]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.ranks == [40, 50, 60, 10]
     assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

-    other_logprobs = FlattenLogprobs()
+    other_logprobs = FlatLogprobs()
     other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
-    # Extend with another FlattenLogprobs
+    # Extend with another FlatLogprobs
     logprobs.extend(other_logprobs)
     assert logprobs.start_indices == [0, 3, 4, 6]
     assert logprobs.end_indices == [3, 4, 6, 7]
@@ -173,7 +173,7 @@ def test_flatten_logprobs_extend() -> None:


 def test_flatten_logprobs_access() -> None:

Contributor review comment (severity: high):
Consistent with the rename of FlattenLogprobs, this function should be renamed to test_flat_logprobs_access.
Suggested change:
-def test_flatten_logprobs_access() -> None:
+def test_flat_logprobs_access() -> None:

-    logprobs = FlattenLogprobs()
+    logprobs = FlatLogprobs()
     logprobs.extend(
         [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
     )
4 changes: 2 additions & 2 deletions vllm/envs.py
@@ -1464,10 +1464,10 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
+    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
     # the original list[dict[int, Logprob]] approach.
     # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlattenLogprobs.
+    # FlatLogprobs.
     "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),

Contributor review comment (severity: high):
While the comment and class name are updated to FlatLogprobs, the environment variable is still named VLLM_FLATTEN_LOGPROBS. This creates an inconsistency. To complete the refactoring, consider renaming the environment variable to VLLM_FLAT_LOGPROBS. This would be a breaking change for users who set this variable. If backward compatibility is a concern, you could support both variables for a transition period or add a note to the comment explaining why the old name is kept.

 }
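If the maintainers wanted the transition path the reviewer describes, the fallback logic could look roughly like the sketch below. This is illustrative only and not part of this PR; VLLM_FLAT_LOGPROBS is the reviewer's proposed name, not a variable vLLM currently defines.

# Sketch: prefer the proposed new variable, fall back to the existing one.
# "VLLM_FLAT_LOGPROBS" is hypothetical here; only VLLM_FLATTEN_LOGPROBS exists today.
import os

def _flat_logprobs_enabled() -> bool:
    value = os.getenv("VLLM_FLAT_LOGPROBS", os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))
    return bool(int(value))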
24 changes: 12 additions & 12 deletions vllm/logprobs.py
@@ -30,7 +30,7 @@ class Logprob:


 @dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
     """
     Flatten logprobs of a request into multiple primitive type lists.

@@ -39,7 +39,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
     all positions and ranks in to multiple primitive type lists (i.e.
     logprobs, token_ids, ranks per token_ids, decoded_tokens).
     So regardless of the sequence length and top_logprobs setup,
-    FlattenLogprobs would only introduce a constant amount of objects.
+    FlatLogprobs would only introduce a constant amount of objects.

     As each position might contains different amount of ranks,
     start_indices_per_position would be used to access the logprob ranges
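For readers skimming the diff, the layout the docstring describes can be pictured with a small standalone example: every position's candidates are packed back to back into flat parallel lists, and start/end indices mark where each position begins and ends. The snippet below only mirrors that idea with plain lists; it is not vLLM code.

# Two positions: position 0 has two ranked candidates, position 1 has one.
token_ids = [101, 205, 42]     # all candidates from all positions, back to back
logprobs = [-0.1, -2.3, -0.4]
ranks = [1, 2, 1]
start_indices = [0, 2]         # position i occupies [start_indices[i], end_indices[i])
end_indices = [2, 3]

# Recover the per-position view that list[dict[int, Logprob]] used to provide.
position = 1
for j in range(start_indices[position], end_indices[position]):
    print(token_ids[j], logprobs[j], ranks[j])   # -> 42 -0.4 1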
@@ -107,7 +107,7 @@ def __len__(self) -> int:
     def __getitem__(self, position: int) -> LogprobsOnePosition: ...

     @overload
-    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

     def __getitem__(self, index: int | slice):
         """Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice):
         elif isinstance(index, slice):
             min_index = self.start_indices[index][0]
             max_index = self.end_indices[index][-1]
-            return FlattenLogprobs(
+            return FlatLogprobs(
                 # Shift updated start_indices and end_indices to
                 # be 0-indexed
                 start_indices=[i - min_index for i in self.start_indices[index]],
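A quick arithmetic check of the index shift in this hunk, using plain lists rather than the real class: slicing keeps only the selected positions' index ranges and re-bases them so the new, smaller container is again 0-indexed.

# Not vLLM code; just the re-basing arithmetic from the slice branch above.
start_indices = [0, 2, 5, 9]
end_indices = [2, 5, 9, 11]
s = slice(1, 3)                                          # keep positions 1 and 2
min_index = start_indices[s][0]                          # 2
new_start = [i - min_index for i in start_indices[s]]    # [0, 3]
new_end = [i - min_index for i in end_indices[s]]        # [3, 7]
# The flat data lists would be cut to [min_index:max_index] (here [2:9]),
# so the shifted indices line up with the truncated lists.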
@@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice):
             raise TypeError(f"Invalid index type: {type(index)}")

     def __setitem__(self, item, value) -> None:
-        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+        raise TypeError("Cannot set logprobs in FlatLogprobs")

     def __delitem__(self, item) -> None:
-        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+        raise TypeError("Cannot delete logprobs from FlatLogprobs")

     def insert(self, item) -> None:
-        raise TypeError("Cannot insert logprobs to FlattenLogprobs")
+        raise TypeError("Cannot insert logprobs to FlatLogprobs")

     def __iter__(self) -> Iterator[LogprobsOnePosition]:
         """
@@ -156,22 +156,22 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]:

 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


 def create_prompt_logprobs() -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    logprobs = FlatLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs


 def create_sample_logprobs() -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    return FlatLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []


 def append_logprobs_for_next_position(
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
     topk_ranks = range(1, num_logprobs + 1)
     ranks = itertools.chain((rank,), topk_ranks)

-    if isinstance(request_logprobs, FlattenLogprobs):
+    if isinstance(request_logprobs, FlatLogprobs):
         request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
     else:
         request_logprobs.append(
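Pulling the pieces together, a minimal usage sketch of the renamed container is shown below. It follows how the tests in this PR exercise the class; the Logprob field names (logprob, rank, decoded_token) and LogprobsOnePosition being a dict of token id to Logprob are assumptions based on vLLM's existing definitions, not something introduced by this diff.

# Hedged usage sketch, not part of this PR.
from vllm.logprobs import FlatLogprobs, Logprob

flat = FlatLogprobs()
flat.append({7: Logprob(logprob=-0.2, rank=1, decoded_token="hi")})
flat.append({
    9: Logprob(logprob=-1.3, rank=1, decoded_token="!"),
    4: Logprob(logprob=-2.0, rank=2, decoded_token="?"),
})

# One container backed by a handful of primitive lists, instead of one dict per position.
print(flat.start_indices, flat.end_indices)   # [0, 1] [1, 3]
print(flat[1])                                # per-position dict view, like the old API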