16 changes: 6 additions & 10 deletions tests/samplers/test_logprobs.py
@@ -4,7 +4,7 @@
import pytest

from vllm import SamplingParams
-from vllm.logprobs import FlattenLogprobs
+from vllm.logprobs import FlatLogprobs

MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
@@ -16,17 +16,17 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
@pytest.mark.parametrize("flat_logprobs", [True, False])
def test_ranks(
vllm_runner,
model,
dtype,
greedy,
-flatten_logprobs,
+flat_logprobs,
example_prompts,
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -44,12 +44,8 @@ def test_ranks(
decode_tokens, _, decode_logprobs, prompt_logprobs = result

# Ensure the return type of logprobs is accurate
-assert isinstance(
-    prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
-)
-assert isinstance(
-    decode_logprobs, FlattenLogprobs if flatten_logprobs else list
-)
+assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
+assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)

########################
# Check prompt logprobs
40 changes: 20 additions & 20 deletions tests/test_logprobs.py
@@ -5,7 +5,7 @@
import pytest

from vllm.logprobs import (
-FlattenLogprobs,
+FlatLogprobs,
Logprob,
LogprobsOnePosition,
append_logprobs_for_next_position,
@@ -14,8 +14,8 @@
)


-def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")

prompt_logprobs = create_prompt_logprobs()
assert isinstance(prompt_logprobs, list)
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0


-def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")

prompt_logprobs = create_prompt_logprobs()
-assert isinstance(prompt_logprobs, FlattenLogprobs)
+assert isinstance(prompt_logprobs, FlatLogprobs)
assert prompt_logprobs.start_indices == [0]
assert prompt_logprobs.end_indices == [0]
assert len(prompt_logprobs.token_ids) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert prompt_logprobs[0] == dict()

sample_logprobs = create_sample_logprobs()
-assert isinstance(sample_logprobs, FlattenLogprobs)
+assert isinstance(sample_logprobs, FlatLogprobs)
assert len(sample_logprobs.start_indices) == 0
assert len(sample_logprobs.end_indices) == 0
assert len(sample_logprobs.token_ids) == 0
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0


-def test_append_logprobs_for_next_position_none_flatten(
+def test_append_logprobs_for_next_position_none_flat(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
logprobs = create_sample_logprobs()
append_logprobs_for_next_position(
logprobs,
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
]


-def test_append_logprobs_for_next_position_flatten(
+def test_append_logprobs_for_next_position_flat(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
logprobs = create_sample_logprobs()
append_logprobs_for_next_position(
logprobs,
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
rank=11,
num_logprobs=-1,
)
-assert isinstance(logprobs, FlattenLogprobs)
+assert isinstance(logprobs, FlatLogprobs)
assert logprobs.start_indices == [0, 1]
assert logprobs.end_indices == [1, 3]
assert logprobs.token_ids == [1, 2, 3]
@@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten(
}


-def test_flatten_logprobs_append() -> None:
-logprobs = FlattenLogprobs()
+def test_flat_logprobs_append() -> None:
+logprobs = FlatLogprobs()
logprobs.append(LOGPROBS_ONE_POSITION_0)
logprobs.append(LOGPROBS_ONE_POSITION_1)
assert logprobs.start_indices == [0, 1]
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]


-def test_flatten_logprobs_extend() -> None:
-logprobs = FlattenLogprobs()
+def test_flat_logprobs_extend() -> None:
+logprobs = FlatLogprobs()
# Extend with list[LogprobsOnePosition]
logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
assert logprobs.start_indices == [0, 3]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
assert logprobs.ranks == [40, 50, 60, 10]
assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

-other_logprobs = FlattenLogprobs()
+other_logprobs = FlatLogprobs()
other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
-# Extend with another FlattenLogprobs
+# Extend with another FlatLogprobs
logprobs.extend(other_logprobs)
assert logprobs.start_indices == [0, 3, 4, 6]
assert logprobs.end_indices == [3, 4, 6, 7]
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]


-def test_flatten_logprobs_access() -> None:
-logprobs = FlattenLogprobs()
+def test_flat_logprobs_access() -> None:
+logprobs = FlatLogprobs()
logprobs.extend(
[LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
)
8 changes: 4 additions & 4 deletions vllm/envs.py
@@ -220,7 +220,7 @@
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-VLLM_FLATTEN_LOGPROBS: bool = False
+VLLM_FLAT_LOGPROBS: bool = False


def get_default_cache_root():
@@ -1464,11 +1464,11 @@ def get_vllm_port() -> int | None:
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
-# Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
+# Flag to enable FlatLogprobs, whose GC overhead is significantly smaller than
# the original list[dict[int, Logprob]] approach.
# Once enabled, PromptLogprobs and SampleLogprobs are populated as
-# FlattenLogprobs.
-"VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
+# FlatLogprobs.
+"VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
}

# --8<-- [end:env-vars-definition]
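Note for reviewers: a minimal sketch of exercising the renamed flag end to end, assuming the `vllm.logprobs` factories shown in this diff. Since `vllm.envs` resolves `VLLM_FLAT_LOGPROBS` lazily through the lambda above, setting the variable before the container is created should be sufficient.

```python
# Sketch only: assumes the vllm.logprobs API shown in this diff
# (create_sample_logprobs, FlatLogprobs) and lazy env resolution in vllm.envs.
import os

os.environ["VLLM_FLAT_LOGPROBS"] = "1"

from vllm.logprobs import FlatLogprobs, create_sample_logprobs

sample_logprobs = create_sample_logprobs()
# With the flag on, the GC-friendly flat container is returned.
assert isinstance(sample_logprobs, FlatLogprobs)
```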
26 changes: 13 additions & 13 deletions vllm/logprobs.py
@@ -30,16 +30,16 @@ class Logprob:


@dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
"""
-Flatten logprobs of a request into multiple primitive type lists.
+Flattens a request's logprobs into multiple primitive-type lists.

Compared to list[dict[int, Logprob]], this data structure reduces GC
overhead significantly, as it flattens the logprob information for
all positions and ranks into multiple primitive-type lists
(i.e. logprobs, token_ids, ranks and decoded_tokens).
So regardless of the sequence length and the top_logprobs setting,
-FlattenLogprobs would only introduce a constant amount of objects.
+FlatLogprobs introduces only a constant number of objects.

As each position may contain a different number of ranks,
start_indices_per_position is used to access the logprob ranges
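To make the layout described in this docstring concrete, here is a small self-contained sketch; the field names mirror this diff (`start_indices`, `end_indices`, `token_ids`, `logprobs`), but the demo class is illustrative rather than the vLLM implementation.

```python
# Illustrative sketch of the flat layout; not the vLLM implementation.
from dataclasses import dataclass, field


@dataclass
class FlatLayoutDemo:
    # Parallel primitive lists shared by all positions.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)

    def append(self, one_position: dict[int, float]) -> None:
        # Record the [start, end) span this position occupies, then flatten
        # its (token_id, logprob) pairs into the shared lists.
        self.start_indices.append(len(self.token_ids))
        for token_id, logprob in one_position.items():
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
        self.end_indices.append(len(self.token_ids))

    def __getitem__(self, position: int) -> dict[int, float]:
        # Rebuild the per-position dict view on demand.
        start, end = self.start_indices[position], self.end_indices[position]
        return dict(zip(self.token_ids[start:end], self.logprobs[start:end]))


demo = FlatLayoutDemo()
demo.append({1: -0.1, 2: -2.3})  # position 0 occupies slots [0, 2)
demo.append({3: -0.5})           # position 1 occupies slots [2, 3)
assert demo.start_indices == [0, 2] and demo.end_indices == [2, 3]
assert demo[1] == {3: -0.5}
```

Only the handful of lists exists regardless of sequence length, which is where the constant object count, and hence the GC saving, comes from.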
@@ -107,7 +107,7 @@ def __len__(self) -> int:
def __getitem__(self, position: int) -> LogprobsOnePosition: ...

@overload
-def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

def __getitem__(self, index: int | slice):
"""Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice):
elif isinstance(index, slice):
min_index = self.start_indices[index][0]
max_index = self.end_indices[index][-1]
-return FlattenLogprobs(
+return FlatLogprobs(
# Shift updated start_indices and end_indices to
# be 0-indexed
start_indices=[i - min_index for i in self.start_indices[index]],
@@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice):
raise TypeError(f"Invalid index type: {type(index)}")

def __setitem__(self, item, value) -> None:
raise TypeError("Cannot set logprobs in FlattenLogprobs")
raise TypeError("Cannot set logprobs in FlatLogprobs")

def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlattenLogprobs")
raise TypeError("Cannot delete logprobs from FlatLogprobs")

def insert(self, item) -> None:
raise TypeError("Cannot insert logprobs to FlattenLogprobs")
raise TypeError("Cannot insert logprobs to FlatLogprobs")

def __iter__(self) -> Iterator[LogprobsOnePosition]:
"""
@@ -156,22 +156,22 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]:

# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


def create_prompt_logprobs() -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request"""
-logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
# NOTE: logprob of first prompt token is None.
logprobs.append(None)
return logprobs


def create_sample_logprobs() -> SampleLogprobs:
"""Creates a container to store decode logprobs for a request"""
-return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []


def append_logprobs_for_next_position(
Expand All @@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank,), topk_ranks)

-if isinstance(request_logprobs, FlattenLogprobs):
+if isinstance(request_logprobs, FlatLogprobs):
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
else:
request_logprobs.append(
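Putting the renamed pieces together, a hedged usage sketch of the container's `MutableSequence` surface, based on the behavior exercised in `tests/test_logprobs.py` above; the `Logprob(logprob, rank, decoded_token)` arguments follow the existing vLLM dataclass.

```python
# Usage sketch, assuming the vllm.logprobs API shown in this diff.
from vllm.logprobs import FlatLogprobs, Logprob

logprobs = FlatLogprobs()
logprobs.extend(
    [
        {1: Logprob(logprob=-0.1, rank=1, decoded_token="a")},
        {
            2: Logprob(logprob=-0.2, rank=1, decoded_token="b"),
            3: Logprob(logprob=-2.0, rank=2, decoded_token="c"),
        },
    ]
)

# Integer indexing rebuilds a per-position dict view.
assert 2 in logprobs[1]

# Slicing returns another FlatLogprobs whose indices are shifted back to 0.
tail = logprobs[1:]
assert tail.start_indices == [0]

# The sequence is append/extend-only; in-place mutation is rejected.
try:
    del logprobs[0]
except TypeError:
    pass
```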