Skip to content

Commit ca90f50

Browse files
[Test] Add non-MoE DP test coverage (vllm-project#28235)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
1 parent da855b4 commit ca90f50

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

tests/v1/distributed/test_async_llm_dp.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020

2121
DP_SIZE = int(os.getenv("DP_SIZE", 2))
2222

23-
engine_args = AsyncEngineArgs(
24-
model="ibm-research/PowerMoE-3b",
25-
enforce_eager=True,
26-
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
27-
data_parallel_size=DP_SIZE,
28-
)
29-
3023

3124
async def generate(
3225
engine: AsyncLLM,
@@ -65,6 +58,13 @@ async def generate(
6558
return count, request_id
6659

6760

61+
@pytest.mark.parametrize(
62+
"model",
63+
[
64+
"ibm-research/PowerMoE-3b",
65+
"hmellor/tiny-random-LlamaForCausalLM",
66+
],
67+
)
6868
@pytest.mark.parametrize(
6969
"output_kind",
7070
[
@@ -76,7 +76,10 @@ async def generate(
7676
@pytest.mark.parametrize("async_scheduling", [True, False])
7777
@pytest.mark.asyncio
7878
async def test_load(
79-
output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool
79+
model: str,
80+
output_kind: RequestOutputKind,
81+
data_parallel_backend: str,
82+
async_scheduling: bool,
8083
):
8184
if async_scheduling and data_parallel_backend == "ray":
8285
# TODO(NickLucche) Re-enable when async scheduling is supported
@@ -107,8 +110,14 @@ def log_engine_initialized(self):
107110
with ExitStack() as after:
108111
prompt = "This is a test of data parallel"
109112

110-
engine_args.data_parallel_backend = data_parallel_backend
111-
engine_args.async_scheduling = async_scheduling
113+
engine_args = AsyncEngineArgs(
114+
model=model,
115+
enforce_eager=True,
116+
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
117+
data_parallel_size=DP_SIZE,
118+
data_parallel_backend=data_parallel_backend,
119+
async_scheduling=async_scheduling,
120+
)
112121
engine = AsyncLLM.from_engine_args(
113122
engine_args, stat_loggers=[SimpleStatsLogger]
114123
)

0 commit comments

Comments
 (0)