 
 DP_SIZE = int(os.getenv("DP_SIZE", 2))
 
-engine_args = AsyncEngineArgs(
-    model="ibm-research/PowerMoE-3b",
-    enforce_eager=True,
-    tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
-    data_parallel_size=DP_SIZE,
-)
-
 
 async def generate(
     engine: AsyncLLM,
@@ -65,6 +58,13 @@ async def generate(
     return count, request_id
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        "ibm-research/PowerMoE-3b",
+        "hmellor/tiny-random-LlamaForCausalLM",
+    ],
+)
 @pytest.mark.parametrize(
     "output_kind",
     [
@@ -76,7 +76,10 @@ async def generate(
 @pytest.mark.parametrize("async_scheduling", [True, False])
 @pytest.mark.asyncio
 async def test_load(
-    output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool
+    model: str,
+    output_kind: RequestOutputKind,
+    data_parallel_backend: str,
+    async_scheduling: bool,
 ):
     if async_scheduling and data_parallel_backend == "ray":
         # TODO(NickLucche) Re-enable when async scheduling is supported
@@ -107,8 +110,14 @@ def log_engine_initialized(self):
     with ExitStack() as after:
         prompt = "This is a test of data parallel"
 
-        engine_args.data_parallel_backend = data_parallel_backend
-        engine_args.async_scheduling = async_scheduling
+        engine_args = AsyncEngineArgs(
+            model=model,
+            enforce_eager=True,
+            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
+            data_parallel_size=DP_SIZE,
+            data_parallel_backend=data_parallel_backend,
+            async_scheduling=async_scheduling,
+        )
         engine = AsyncLLM.from_engine_args(
             engine_args, stat_loggers=[SimpleStatsLogger]
         )
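
Net effect of the change, as a minimal sketch: the engine arguments are now built inside the test body, so each parametrized model/backend/scheduling combination gets a fresh engine instead of mutating a shared module-level AsyncEngineArgs. The import paths, the "mp" backend value, and the test name below are assumptions; the output_kind parametrization, generate() helper, and SimpleStatsLogger from the real file are omitted here for brevity.

import os
from contextlib import ExitStack

import pytest

# Assumed import paths; the real test file already imports these names.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

DP_SIZE = int(os.getenv("DP_SIZE", 2))


@pytest.mark.parametrize(
    "model",
    ["ibm-research/PowerMoE-3b", "hmellor/tiny-random-LlamaForCausalLM"],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])  # "mp" is assumed
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio
async def test_load_sketch(
    model: str, data_parallel_backend: str, async_scheduling: bool
):
    if async_scheduling and data_parallel_backend == "ray":
        # Async scheduling is not yet supported with the ray backend (see diff).
        pytest.skip("async scheduling not supported with the ray backend yet")

    with ExitStack() as after:
        # Constructed per test invocation, so every parametrized combination
        # gets its own engine configuration.
        engine_args = AsyncEngineArgs(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
            data_parallel_size=DP_SIZE,
            data_parallel_backend=data_parallel_backend,
            async_scheduling=async_scheduling,
        )
        engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)
        # ... submit prompts and assert on outputs, as the full test does
        # with its generate() helper ...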