-
-
Notifications
You must be signed in to change notification settings - Fork 11.3k
[Misc] FlattenLogprobs -> FlatLogprobs #28335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| import pytest | ||
|
|
||
| from vllm.logprobs import ( | ||
| FlattenLogprobs, | ||
| FlatLogprobs, | ||
| Logprob, | ||
| LogprobsOnePosition, | ||
| append_logprobs_for_next_position, | ||
|
|
@@ -32,7 +32,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: | |
| monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") | ||
|
|
||
| prompt_logprobs = create_prompt_logprobs() | ||
| assert isinstance(prompt_logprobs, FlattenLogprobs) | ||
| assert isinstance(prompt_logprobs, FlatLogprobs) | ||
| assert prompt_logprobs.start_indices == [0] | ||
| assert prompt_logprobs.end_indices == [0] | ||
| assert len(prompt_logprobs.token_ids) == 0 | ||
|
|
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: | |
| assert prompt_logprobs[0] == dict() | ||
|
|
||
| sample_logprobs = create_sample_logprobs() | ||
| assert isinstance(sample_logprobs, FlattenLogprobs) | ||
| assert isinstance(sample_logprobs, FlatLogprobs) | ||
| assert len(sample_logprobs.start_indices) == 0 | ||
| assert len(sample_logprobs.end_indices) == 0 | ||
| assert len(sample_logprobs.token_ids) == 0 | ||
|
|
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten( | |
| rank=11, | ||
| num_logprobs=-1, | ||
| ) | ||
| assert isinstance(logprobs, FlattenLogprobs) | ||
| assert isinstance(logprobs, FlatLogprobs) | ||
| assert logprobs.start_indices == [0, 1] | ||
| assert logprobs.end_indices == [1, 3] | ||
| assert logprobs.token_ids == [1, 2, 3] | ||
|
|
@@ -130,7 +130,7 @@ def test_append_logprobs_for_next_position_flatten( | |
|
|
||
|
|
||
| def test_flatten_logprobs_append() -> None: | ||
| logprobs = FlattenLogprobs() | ||
| logprobs = FlatLogprobs() | ||
| logprobs.append(LOGPROBS_ONE_POSITION_0) | ||
| logprobs.append(LOGPROBS_ONE_POSITION_1) | ||
| assert logprobs.start_indices == [0, 1] | ||
|
|
@@ -150,7 +150,7 @@ def test_flatten_logprobs_append() -> None: | |
|
|
||
|
|
||
| def test_flatten_logprobs_extend() -> None: | ||
|
||
| logprobs = FlattenLogprobs() | ||
| logprobs = FlatLogprobs() | ||
| # Extend with list[LogprobsOnePosition] | ||
| logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) | ||
| assert logprobs.start_indices == [0, 3] | ||
|
|
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None: | |
| assert logprobs.ranks == [40, 50, 60, 10] | ||
| assert logprobs.decoded_tokens == ["40", "50", "60", "10"] | ||
|
|
||
| other_logprobs = FlattenLogprobs() | ||
| other_logprobs = FlatLogprobs() | ||
| other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0]) | ||
| # Extend with another FlattenLogprobs | ||
| # Extend with another FlatLogprobs | ||
| logprobs.extend(other_logprobs) | ||
| assert logprobs.start_indices == [0, 3, 4, 6] | ||
| assert logprobs.end_indices == [3, 4, 6, 7] | ||
|
|
@@ -173,7 +173,7 @@ def test_flatten_logprobs_extend() -> None: | |
|
|
||
|
|
||
| def test_flatten_logprobs_access() -> None: | ||
|
||
| logprobs = FlattenLogprobs() | ||
| logprobs = FlatLogprobs() | ||
| logprobs.extend( | ||
| [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0] | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1464,10 +1464,10 @@ def get_vllm_port() -> int | None: | |
| "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( | ||
| "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] | ||
| ), | ||
| # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than | ||
| # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than | ||
| # the original list[dict[int, Logprob]] approach. | ||
| # When enabled, PromptLogprobs and SampleLogprobs would be populated as | ||
| # FlattenLogprobs. | ||
| # FlatLogprobs. | ||
| "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))), | ||
|
||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency with the class rename from
`FlattenLogprobs` to `FlatLogprobs`, this test function should be renamed to `test_flat_logprobs_append`.