diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index 87f5d40ac1da..c9d227599cde 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -4,7 +4,7 @@
 import pytest
 
 from vllm import SamplingParams
-from vllm.logprobs import FlattenLogprobs
+from vllm.logprobs import FlatLogprobs
 
 MODELS = ["distilbert/distilgpt2"]
 MAX_TOKENS = 5
@@ -16,17 +16,17 @@
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("greedy", [True, False])
-@pytest.mark.parametrize("flatten_logprobs", [True, False])
+@pytest.mark.parametrize("flat_logprobs", [True, False])
 def test_ranks(
     vllm_runner,
     model,
     dtype,
     greedy,
-    flatten_logprobs,
+    flat_logprobs,
     example_prompts,
     monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -44,12 +44,8 @@ def test_ranks(
         decode_tokens, _, decode_logprobs, prompt_logprobs = result
 
         # Ensure the return type of logprobs is accurate
-        assert isinstance(
-            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
-        assert isinstance(
-            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
+        assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
+        assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
 
         ########################
         # Check prompt logprobs
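The isinstance assertions above hinge on the factory functions changed in vllm/logprobs.py further down: the environment flag decides which container type gets built. A minimal sketch of that behavior, assuming a vLLM build that already includes this patch and that vllm.envs resolves the variable lazily, as the lambda registry in the vllm/envs.py hunk below suggests:

```python
# Sketch: VLLM_FLAT_LOGPROBS selects the container type built by the
# logprob factories (see the vllm/logprobs.py hunks below). Assumes a
# vLLM build that includes this patch; the flag is read when the
# factory runs, so it only needs to be set before the call.
import os

os.environ["VLLM_FLAT_LOGPROBS"] = "1"

from vllm.logprobs import FlatLogprobs, create_sample_logprobs

assert isinstance(create_sample_logprobs(), FlatLogprobs)

os.environ["VLLM_FLAT_LOGPROBS"] = "0"
assert isinstance(create_sample_logprobs(), list)
```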
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
index 1799d3638178..d26a460d2bca 100644
--- a/tests/test_logprobs.py
+++ b/tests/test_logprobs.py
@@ -5,7 +5,7 @@
 import pytest
 
 from vllm.logprobs import (
-    FlattenLogprobs,
+    FlatLogprobs,
     Logprob,
     LogprobsOnePosition,
     append_logprobs_for_next_position,
@@ -14,8 +14,8 @@
 )
 
 
-def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
 
     prompt_logprobs = create_prompt_logprobs()
     assert isinstance(prompt_logprobs, list)
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
 
     prompt_logprobs = create_prompt_logprobs()
-    assert isinstance(prompt_logprobs, FlattenLogprobs)
+    assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
     assert len(prompt_logprobs.token_ids) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert prompt_logprobs[0] == dict()
 
     sample_logprobs = create_sample_logprobs()
-    assert isinstance(sample_logprobs, FlattenLogprobs)
+    assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
     assert len(sample_logprobs.token_ids) == 0
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_append_logprobs_for_next_position_none_flatten(
+def test_append_logprobs_for_next_position_none_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
     ]
 
 
-def test_append_logprobs_for_next_position_flatten(
+def test_append_logprobs_for_next_position_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
         rank=11,
         num_logprobs=-1,
     )
-    assert isinstance(logprobs, FlattenLogprobs)
+    assert isinstance(logprobs, FlatLogprobs)
     assert logprobs.start_indices == [0, 1]
     assert logprobs.end_indices == [1, 3]
     assert logprobs.token_ids == [1, 2, 3]
@@ -129,8 +129,8 @@ def test_append_logprobs_for_next_position_flatten(
     }
 
 
-def test_flatten_logprobs_append() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_append() -> None:
+    logprobs = FlatLogprobs()
     logprobs.append(LOGPROBS_ONE_POSITION_0)
     logprobs.append(LOGPROBS_ONE_POSITION_1)
     assert logprobs.start_indices == [0, 1]
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
     assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
 
 
-def test_flatten_logprobs_extend() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_extend() -> None:
+    logprobs = FlatLogprobs()
     # Extend with list[LogprobsOnePosition]
     logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
     assert logprobs.start_indices == [0, 3]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.ranks == [40, 50, 60, 10]
     assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
 
-    other_logprobs = FlattenLogprobs()
+    other_logprobs = FlatLogprobs()
     other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
-    # Extend with another FlattenLogprobs
+    # Extend with another FlatLogprobs
     logprobs.extend(other_logprobs)
     assert logprobs.start_indices == [0, 3, 4, 6]
     assert logprobs.end_indices == [3, 4, 6, 7]
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
 
 
-def test_flatten_logprobs_access() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_access() -> None:
+    logprobs = FlatLogprobs()
     logprobs.extend(
         [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
     )
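The start_indices/end_indices values asserted above follow one invariant: appending the logprobs for one position with k ranked entries grows every flat list by k items and records that position's [start, end) range. A self-contained sketch of the bookkeeping; FlatSketch is a hypothetical stand-in, not the real class:

```python
# Hypothetical stand-in for the [start, end) bookkeeping the tests assert;
# the real FlatLogprobs also keeps logprobs, ranks, and decoded_tokens in
# parallel flat lists.
from dataclasses import dataclass, field


@dataclass
class FlatSketch:
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)

    def append(self, position: dict[int, float]) -> None:
        # Each position occupies one contiguous [start, end) slice.
        self.start_indices.append(len(self.token_ids))
        for token_id, logprob in position.items():
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
        self.end_indices.append(len(self.token_ids))


s = FlatSketch()
s.append({1: -0.1})           # position 0: one ranked entry
s.append({2: -0.2, 3: -0.3})  # position 1: two ranked entries
assert s.start_indices == [0, 1]
assert s.end_indices == [1, 3]
assert s.token_ids == [1, 2, 3]
```

This is also why the structure allocates only a constant number of Python objects: growing it appends primitives to existing lists, instead of creating one dict plus one Logprob object per rank for the GC to track.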
diff --git a/vllm/envs.py b/vllm/envs.py
index 30c62e90e9fb..911a59a57790 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -222,7 +222,7 @@
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLATTEN_LOGPROBS: bool = False
+    VLLM_FLAT_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1476,11 +1476,11 @@ def get_vllm_port() -> int | None:
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
-    # the original list[dict[int, Logprob]] approach.
-    # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlattenLogprobs.
-    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
+    # Flag to enable FlatLogprobs, whose GC overhead is significantly smaller than
+    # that of the original list[dict[int, Logprob]] approach.
+    # When enabled, PromptLogprobs and SampleLogprobs are populated as
+    # FlatLogprobs.
+    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 # --8<-- [end:env-vars-definition]
 
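The registered lambda parses the variable strictly as an integer string. A short illustration; parse_flag is a hypothetical helper mirroring the lambda above:

```python
# Mirrors bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))) from the hunk above.
import os


def parse_flag() -> bool:
    return bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0")))


os.environ.pop("VLLM_FLAT_LOGPROBS", None)
assert parse_flag() is False  # unset falls back to "0"

os.environ["VLLM_FLAT_LOGPROBS"] = "1"
assert parse_flag() is True

os.environ["VLLM_FLAT_LOGPROBS"] = "true"
try:
    parse_flag()
except ValueError:
    pass  # non-integer strings such as "true" are rejected, not treated as truthy
```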
diff --git a/vllm/logprobs.py b/vllm/logprobs.py
index bf66e5f75c79..a34398db2c96 100644
--- a/vllm/logprobs.py
+++ b/vllm/logprobs.py
@@ -30,16 +30,16 @@ class Logprob:
 
 
 @dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
     """
-    Flatten logprobs of a request into multiple primitive type lists.
+    Flattens logprobs of a request into multiple primitive type lists.
 
-    Compared to list[dict[int, Logprob]], this data structure reduced
-    GC overhead significantly. As it flattened logprob information
-    for all positions and ranks in to multiple primitive type lists
+    Compared to list[dict[int, Logprob]], this data structure reduces
+    GC overhead significantly, as it flattens the logprob information
+    for all positions and ranks into multiple primitive type lists
     (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens).
     So regardless of the sequence length and top_logprobs setup,
-    FlattenLogprobs would only introduce a constant amount of objects.
+    FlatLogprobs only introduces a constant number of objects.
 
-    As each position might contains different amount of ranks,
+    As each position might contain a different number of ranks,
     start_indices_per_position would be used to access the logprob ranges
@@ -107,7 +107,7 @@ def __len__(self) -> int:
     def __getitem__(self, position: int) -> LogprobsOnePosition: ...
 
     @overload
-    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
 
     def __getitem__(self, index: int | slice):
         """Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ def __getitem__(self, index: int | slice):
         elif isinstance(index, slice):
             min_index = self.start_indices[index][0]
             max_index = self.end_indices[index][-1]
-            return FlattenLogprobs(
+            return FlatLogprobs(
                 # Shift updated start_indices and end_indices to
                 # be 0-indexed
                 start_indices=[i - min_index for i in self.start_indices[index]],
@@ -137,13 +137,13 @@ def __getitem__(self, index: int | slice):
         raise TypeError(f"Invalid index type: {type(index)}")
 
     def __setitem__(self, item, value) -> None:
-        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+        raise TypeError("Cannot set logprobs in FlatLogprobs")
 
     def __delitem__(self, item) -> None:
-        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+        raise TypeError("Cannot delete logprobs from FlatLogprobs")
 
     def insert(self, item) -> None:
-        raise TypeError("Cannot insert logprobs to FlattenLogprobs")
+        raise TypeError("Cannot insert logprobs into FlatLogprobs")
 
     def __iter__(self) -> Iterator[LogprobsOnePosition]:
         """
@@ -156,14 +156,14 @@ def __iter__(self) -> Iterator[LogprobsOnePosition]:
 
 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
 
 
 def create_prompt_logprobs() -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
 
 def create_sample_logprobs() -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
 
 
 def append_logprobs_for_next_position(
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
         topk_ranks = range(1, num_logprobs + 1)
         ranks = itertools.chain((rank,), topk_ranks)
 
-    if isinstance(request_logprobs, FlattenLogprobs):
+    if isinstance(request_logprobs, FlatLogprobs):
         request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
     else:
         request_logprobs.append(
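For reference, the slice branch of __getitem__ shown above re-bases the copied index lists so that the returned FlatLogprobs is 0-indexed again. The arithmetic in isolation, with made-up values:

```python
# Standalone illustration of the re-basing in the slice branch of
# FlatLogprobs.__getitem__ above (example values are made up).
start_indices = [0, 1, 3]  # three positions holding 1, 2, and 3 entries
end_indices = [1, 3, 6]

index = slice(1, 3)  # take positions 1 and 2
min_index = start_indices[index][0]

assert [i - min_index for i in start_indices[index]] == [0, 2]
assert [i - min_index for i in end_indices[index]] == [2, 5]
```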