
Commit cb9cc58

yewentao256 authored and devpatelio committed
[Bug] Fix Batch Invariant MLA test (vllm-project#28967)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 59a504b · commit cb9cc58

File tree

2 files changed: +33 −10 lines

tests/v1/determinism/test_batch_invariance.py
vllm/model_executor/layers/batch_invariant.py

tests/v1/determinism/test_batch_invariance.py

Lines changed: 32 additions & 9 deletions
@@ -9,13 +9,33 @@
 from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
 
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
+
+BACKENDS: list[str] = [
+    "FLASH_ATTN",
+    "FLASHINFER",
+]
+
+if current_platform.is_cuda() and current_platform.is_device_capability(90):
+    BACKENDS.append("FLASH_ATTN_MLA")
+
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+def resolve_model_name(backend: str) -> str:
+    """Resolve the model name for the given backend, respecting env overrides."""
+    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+    if backend.endswith("MLA") and model == DEFAULT_MODEL:
+        return MLA_MODEL
+    return model
 
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     backend, monkeypatch: pytest.MonkeyPatch
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
     max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
     min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -369,15 +389,15 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
     """
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model = resolve_model_name(backend)
 
     llm = LLM(
         model=model,
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
 @skip_unsupported
 @pytest.mark.parametrize(
     "backend",
-    ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+    BACKENDS,
 )
 @pytest.mark.forked
 def test_logprobs_without_batch_invariance_should_fail(
@@ -434,14 +454,17 @@ def test_logprobs_without_batch_invariance_should_fail(
     The test will PASS if we detect differences (proving batch invariance matters).
     The test will FAIL if everything matches (suggesting batch invariance isn't needed).
     """
+    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
+    vllm_is_batch_invariant.cache_clear()
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
 
     # CRITICAL: Disable batch invariance for this test
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     print(f"\n{'=' * 80}")
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(
 
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
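
Taken together, the new module-level setup does two things: it only exercises FLASH_ATTN_MLA on GPUs with device capability 9.0 (Hopper-class), and it routes MLA backends to a model that actually has MLA layers (DeepSeek-V2-Lite) instead of the default Qwen3-1.7B, which does not. Below is a minimal, self-contained sketch of the selection logic for illustration; the "my-org/my-model" override is a hypothetical example value, not something from the commit.

```python
import os

DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
    """Pick the test model for a backend, honoring VLLM_TEST_MODEL overrides."""
    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    # Only substitute the MLA-capable model when no explicit override was given.
    if backend.endswith("MLA") and model == DEFAULT_MODEL:
        return MLA_MODEL
    return model


# Start from a clean environment so the example is reproducible.
os.environ.pop("VLLM_TEST_MODEL", None)
assert resolve_model_name("FLASH_ATTN") == DEFAULT_MODEL
assert resolve_model_name("FLASH_ATTN_MLA") == MLA_MODEL

# An explicit override wins for every backend, MLA or not (hypothetical value).
os.environ["VLLM_TEST_MODEL"] = "my-org/my-model"
assert resolve_model_name("FLASH_ATTN_MLA") == "my-org/my-model"
```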

vllm/model_executor/layers/batch_invariant.py

Lines changed: 1 addition & 1 deletion
@@ -803,11 +803,11 @@ def override_envs_for_invariance():
         "FLASH_ATTN", # best supported backend
         "FLASHINFER",
         "FLASH_ATTN_MLA",
-        "FLASHINFER_MLA",
         "TRITON_MLA",
         # Not yet supported MLA backends
         # "FLASHMLA",
         # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
+        # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
     ]
     if curr_attn_backend not in supported_backends:
         warning = (
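
The other new piece in the test, `vllm_is_batch_invariant.cache_clear()`, matters because the flag is cached after its first lookup, so flipping `VLLM_BATCH_INVARIANT` to "0" later in the same process would otherwise be ignored. The sketch below shows that cached-flag pattern under the assumption of a functools-style cache, which the real helper in `vllm/model_executor/layers/batch_invariant.py` appears to use; it is an illustration, not the actual vLLM implementation.

```python
import os
from functools import cache


@cache
def vllm_is_batch_invariant() -> bool:
    # Read the flag once; later env changes are invisible until cache_clear().
    return os.getenv("VLLM_BATCH_INVARIANT", "0") == "1"


os.environ["VLLM_BATCH_INVARIANT"] = "1"
assert vllm_is_batch_invariant() is True

os.environ["VLLM_BATCH_INVARIANT"] = "0"
assert vllm_is_batch_invariant() is True   # stale cached value

vllm_is_batch_invariant.cache_clear()      # what the updated test does first
assert vllm_is_batch_invariant() is False  # fresh read after clearing
```

Without the clear, a batch-invariant run earlier in the same pytest process could presumably leave a cached True behind and silently re-enable invariance for this negative test.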
