from utils import _extract_step_logprobs, _random_prompt, skip_unsupported

from vllm import LLM, SamplingParams
+ from vllm.platforms import current_platform
+
+ BACKENDS: list[str] = [
+     "FLASH_ATTN",
+     "FLASHINFER",
+ ]
+
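+ # Only exercise FLASH_ATTN_MLA on CUDA devices with compute capability 9.0.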
+ if current_platform.is_cuda() and current_platform.is_device_capability(90):
+     BACKENDS.append("FLASH_ATTN_MLA")
+
+ DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+ MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
+
+
+ def resolve_model_name(backend: str) -> str:
+     """Resolve the model name for the given backend, respecting env overrides."""
+     model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
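+     # MLA backends need an MLA-capable model, so fall back to the MLA default
+     # unless the caller explicitly set VLLM_TEST_MODEL.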
+     if backend.endswith("MLA") and model == DEFAULT_MODEL:
+         return MLA_MODEL
+     return model


@skip_unsupported
@pytest.mark.timeout(1000)
@pytest.mark.parametrize(
    "backend",
-     ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+     BACKENDS,
)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
    backend, monkeypatch: pytest.MonkeyPatch
@@ -47,7 +67,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
-     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+     model = resolve_model_name(backend)
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -150,7 +170,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
@skip_unsupported
@pytest.mark.parametrize(
    "backend",
-     ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+     BACKENDS,
)
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@@ -160,7 +180,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
-     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+     model_name = resolve_model_name(backend)
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -369,15 +389,15 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
@skip_unsupported
@pytest.mark.parametrize(
    "backend",
-     ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+     BACKENDS,
)
def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
    """
    Simple test that runs the model with a basic prompt and prints the output.
    Useful for quick smoke testing and debugging.
    """
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+     model = resolve_model_name(backend)

    llm = LLM(
        model=model,
@@ -419,7 +439,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
@skip_unsupported
@pytest.mark.parametrize(
    "backend",
-     ["FLASH_ATTN", "FLASHINFER", "FLASH_ATTN_MLA", "FLASHINFER_MLA", "TRITON_MLA"],
+     BACKENDS,
)
@pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail(
@@ -434,14 +454,17 @@ def test_logprobs_without_batch_invariance_should_fail(
    The test will PASS if we detect differences (proving batch invariance matters).
    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
    """
+     from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
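+     # vllm_is_batch_invariant() caches its environment lookup; clear the cache
+     # so the VLLM_BATCH_INVARIANT=0 override below is picked up in this process.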
+     vllm_is_batch_invariant.cache_clear()
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)

    # CRITICAL: Disable batch invariance for this test
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
-     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+     model_name = resolve_model_name(backend)
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    print(f"\n{'=' * 80}")
@@ -659,7 +682,7 @@ def test_decode_logprobs_match_prefill_logprobs(

    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)
-     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+     model_name = resolve_model_name(backend)
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    from vllm.model_executor.layers.batch_invariant import (