
Commit da786e3

[Core] Rework handling of async scheduling config (#28250)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 1890321 commit da786e3

6 files changed, +121 -71 lines

tests/v1/engine/test_engine_core.py

Lines changed: 17 additions & 15 deletions
@@ -66,7 +66,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1

@@ -75,7 +75,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

@@ -85,12 +85,12 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 4

     # Loop through until they are all done.
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
+    while (outs := engine_core.step_fn()[0].get(0)) and outs.outputs:
         pass

     assert len(engine_core.scheduler.waiting) == 0
@@ -107,7 +107,7 @@ def test_engine_core():
     assert engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1
     assert engine_core.scheduler.has_unfinished_requests()
@@ -119,7 +119,7 @@ def test_engine_core():
     assert not engine_core.scheduler.has_unfinished_requests()
     assert engine_core.scheduler.has_finished_requests()

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert not engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()

@@ -133,15 +133,15 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

     engine_core.add_request(*engine_core.preprocess_add_request(req2))
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 3

@@ -150,7 +150,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

@@ -165,12 +165,12 @@ def test_engine_core():
     req0.request_id = req1.request_id = "test"
     engine_core.add_request(*engine_core.preprocess_add_request(req0))

-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()

     engine_core.add_request(*engine_core.preprocess_add_request(req1))
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()

     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0
@@ -208,8 +208,8 @@ def _check_engine_state():
         assert len(engine_core.scheduler.waiting) == 1
         assert len(engine_core.scheduler.running) == 0
         # Loop through until they are all done.
-        while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-            pass
+        while engine_core.scheduler.has_requests():
+            engine_core.step_fn()
         assert len(engine_core.scheduler.waiting) == 0
         assert len(engine_core.scheduler.running) == 0

@@ -297,6 +297,8 @@ def shutdown(self):
         max_num_batched_tokens=10,
         # Reduce startup time.
         enforce_eager=True,
+        # Test concurrent batch behaviour independently of async scheduling.
+        async_scheduling=False,
     )
     vllm_config = engine_args.create_engine_config()
     with set_default_torch_num_threads(1):
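The tests above switch every direct `engine_core.step()` call to `engine_core.step_fn()`, so the same test body works whether or not async scheduling is in effect. Below is a minimal, self-contained sketch of the dispatch pattern this implies; the class and method names other than `step_fn` are illustrative stand-ins, not vLLM's actual API.

class MiniEngineCore:
    """Illustrative stand-in for EngineCore; not vLLM's actual class."""

    def __init__(self, async_scheduling: bool) -> None:
        # Bind the step implementation once at construction time, so
        # callers never need to know which scheduling mode is active.
        self.step_fn = self.step_with_batch_queue if async_scheduling else self.step

    def step(self) -> str:
        # Synchronous path: schedule, execute, and process output serially.
        return "sync step"

    def step_with_batch_queue(self) -> str:
        # Async path: overlap scheduling of the next batch with execution.
        return "async step"


assert MiniEngineCore(async_scheduling=True).step_fn() == "async step"
assert MiniEngineCore(async_scheduling=False).step_fn() == "sync step"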

vllm/config/scheduler.py

Lines changed: 34 additions & 9 deletions
@@ -4,7 +4,7 @@
 import hashlib
 from collections.abc import Callable
 from dataclasses import InitVar
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast

 from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
@@ -17,6 +17,10 @@
     MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
     POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
 )
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.interface import SchedulerInterface

 logger = init_logger(__name__)

@@ -120,7 +124,7 @@ class SchedulerConfig:

     # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
     # (default) or "mod.custom_class".
-    scheduler_cls: str | type[object] = "vllm.v1.core.sched.scheduler.Scheduler"
+    scheduler_cls: str | type[object] = Field(default=None)
     """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
     the default scheduler. Can be a class directly or the path to a class of
     form "mod.custom_class"."""
@@ -132,12 +136,34 @@ class SchedulerConfig:
     """

     async_scheduling: bool = False
-    """EXPERIMENTAL: If set to True, perform async scheduling. This may help
-    reduce the CPU overheads, leading to better latency and throughput. However,
-    async scheduling is currently not supported with some features such as
-    structured outputs, speculative decoding, and pipeline parallelism.
+    """If set to True, perform async scheduling. This helps to avoid gaps in
+    GPU utilization, leading to better latency and throughput.
+    Async scheduling is currently not supported with some features such as
+    speculative decoding and pipeline parallelism.
     """

+    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
+        if self.scheduler_cls is None:
+            if self.async_scheduling:
+                from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+
+                return AsyncScheduler
+            from vllm.v1.core.sched.scheduler import Scheduler
+
+            return Scheduler
+
+        # This warning can be removed once the Scheduler interface is
+        # finalized and we can maintain support for scheduler classes that
+        # implement it
+        logger.warning_once(
+            "Using custom scheduler class %s. This scheduler interface is "
+            "not public and compatibility may not be maintained.",
+            self.scheduler_cls,
+        )
+        if not isinstance(self.scheduler_cls, str):
+            return cast(type["SchedulerInterface"], self.scheduler_cls)
+        return resolve_obj_by_qualname(self.scheduler_cls)
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -161,6 +187,8 @@ def compute_hash(self) -> str:
         "max_num_seqs",
         "max_model_len",
         "enable_chunked_prefill",
+        "scheduler_cls",
+        "async_scheduling",
         mode="wrap",
     )
     @classmethod
@@ -242,9 +270,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 self.long_prefill_token_threshold,
             )

-        if self.async_scheduling:
-            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
-
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
         if (
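The new `get_scheduler_cls` helper centralizes scheduler resolution: `None` selects the built-in `Scheduler` or `AsyncScheduler` depending on `async_scheduling`, a class object is used as-is, and a string is imported by qualified name. A self-contained sketch of the same three-way resolution follows, using stand-in classes and a simplified resolver rather than vLLM's own.

import importlib


class SyncSched: ...   # stand-in for vllm's Scheduler
class AsyncSched: ...  # stand-in for vllm's AsyncScheduler


def resolve_by_qualname(qualname: str) -> type:
    # "pkg.module.ClassName" -> the ClassName object, analogous to
    # resolve_obj_by_qualname in the diff above.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def pick_scheduler(scheduler_cls: str | type | None, async_scheduling: bool) -> type:
    if scheduler_cls is None:
        # Unset: fall back to the built-in matching the async setting.
        return AsyncSched if async_scheduling else SyncSched
    if not isinstance(scheduler_cls, str):
        return scheduler_cls  # already a class object
    return resolve_by_qualname(scheduler_cls)


assert pick_scheduler(None, async_scheduling=True) is AsyncSched
assert pick_scheduler("collections.OrderedDict", False).__name__ == "OrderedDict"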

vllm/config/vllm.py

Lines changed: 49 additions & 2 deletions
@@ -353,6 +353,53 @@ def __post_init__(self):
             self.model_config, self.load_config
         )

+        executor_backend = self.parallel_config.distributed_executor_backend
+        executor_supports_async_sched = executor_backend in (
+            "mp",
+            "uni",
+            "external_launcher",
+        )
+
+        if self.scheduler_config.async_scheduling:
+            # Async scheduling explicitly enabled, hard fail any incompatibilities.
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with "
+                    "pipeline_parallel_size > 1."
+                )
+            if self.speculative_config is not None:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with speculative decoding."
+                )
+            if not executor_supports_async_sched:
+                raise ValueError(
+                    "Currently, async scheduling only supports `mp`, `uni`, or "
+                    "`external_launcher` distributed executor backend, but you chose "
+                    f"`{executor_backend}`."
+                )
+        elif self.scheduler_config.async_scheduling is None:
+            # Enable async scheduling unless there is an incompatible option.
+            # NOTE: we won't reach here until async scheduling is enabled by default.
+            if (
+                self.parallel_config.pipeline_parallel_size > 1
+                or self.speculative_config is not None
+            ):
+                logger.warning(
+                    "Async scheduling is not yet supported with speculative decoding "
+                    "or pipeline_parallel_size > 1 and will be disabled."
+                )
+                self.scheduler_config.async_scheduling = False
+            elif not executor_supports_async_sched:
+                logger.warning(
+                    "Async scheduling will be disabled because it is not supported "
+                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
+                    "`external_launcher` are supported).",
+                    executor_backend,
+                )
+                self.scheduler_config.async_scheduling = False
+            else:
+                self.scheduler_config.async_scheduling = True
+
         from vllm.platforms import current_platform

         if (
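Validation now lives in `VllmConfig.__post_init__` and treats `async_scheduling` as a tri-state: `True` means explicitly enabled (incompatibilities raise), `None` means auto (incompatibilities downgrade with a warning), and `False` means disabled. A condensed sketch of that resolution, with the concrete compatibility checks collapsed into one flag for illustration:

def resolve_async_scheduling(requested: bool | None, compatible: bool) -> bool:
    """Mirror the True/None/False handling above (illustrative only)."""
    if requested:
        # Explicitly enabled: any incompatibility is a hard error.
        if not compatible:
            raise ValueError("async scheduling incompatible with this config")
        return True
    if requested is None:
        # Auto mode: enable when possible, otherwise warn and fall back.
        if not compatible:
            print("warning: async scheduling disabled (incompatible option)")
            return False
        return True
    return False  # explicitly disabled


assert resolve_async_scheduling(None, compatible=True) is True
assert resolve_async_scheduling(None, compatible=False) is False
assert resolve_async_scheduling(False, compatible=True) is False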
@@ -467,7 +514,7 @@ def __post_init__(self):
             self.speculative_config is not None
             and self.speculative_config.use_eagle()
         ):
-            raise NotImplementedError(
+            raise ValueError(
                 "Fast prefill optimization for KV sharing is not "
                 "compatible with EAGLE as EAGLE requires correct logits "
                 "for all tokens while fast prefill gives incorrect logits "
@@ -491,7 +538,7 @@ def __post_init__(self):
             )
             if not getattr(self.model_config.hf_config, "is_causal", True):
                 disable_chunked_prefill_reasons.append(
-                    "Only models using causal attention supports chunked "
+                    "Only models using causal attention support chunked "
                     "prefill and prefix caching; disabling both."
                 )
             elif self.model_config.is_encoder_decoder:

vllm/engine/arg_utils.py

Lines changed: 2 additions & 26 deletions
@@ -513,7 +513,7 @@ class EngineArgs:
         ObservabilityConfig.collect_detailed_traces
     )
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
-    scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
+    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls

     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     override_pooler_config: dict | PoolerConfig | None = (
@@ -552,7 +552,7 @@ class EngineArgs:
     )
     """Custom logitproc types"""

-    async_scheduling: bool = SchedulerConfig.async_scheduling
+    async_scheduling: bool | None = SchedulerConfig.async_scheduling

     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill

@@ -1479,20 +1479,6 @@ def create_engine_config(
             else ParallelConfig.data_parallel_rpc_port
         )

-        if self.async_scheduling:
-            if self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Async scheduling is not supported with pipeline-parallel-size > 1."
-                )
-
-            # Currently, async scheduling does not support speculative decoding.
-            # TODO(woosuk): Support it.
-            if self.speculative_config is not None:
-                raise ValueError(
-                    "Currently, speculative decoding is not supported with "
-                    "async scheduling."
-                )
-
         # Forward the deprecated CLI args to the EPLB config.
         if self.num_redundant_experts is not None:
             self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1536,16 +1522,6 @@ def create_engine_config(
             _api_process_rank=self._api_process_rank,
        )

-        if self.async_scheduling and (
-            parallel_config.distributed_executor_backend
-            not in ("mp", "uni", "external_launcher")
-        ):
-            raise ValueError(
-                "Currently, async scheduling only supports `mp`, `uni` or "
-                "`external_launcher` distributed executor backend, but you choose "
-                f"`{parallel_config.distributed_executor_backend}`."
-            )
-
         speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
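The deleted blocks are not lost functionality: both checks moved into `VllmConfig.__post_init__` (see vllm/config/vllm.py above), so they apply however the config is constructed, and `EngineArgs` only needs the widened `bool | None` type. A hedged usage sketch follows; the downstream behaviour shown in the comments is inferred from this diff, not guaranteed by it.

# Hypothetical usage: leaving async_scheduling unset (None) defers the
# decision to VllmConfig.__post_init__; passing True forces it on and
# lets the config-level validation raise on incompatible options.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", async_scheduling=None)
vllm_config = args.create_engine_config()  # validation now happens downstream
print(vllm_config.scheduler_config.async_scheduling)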

vllm/v1/core/sched/interface.py

Lines changed: 18 additions & 0 deletions
@@ -4,16 +4,34 @@
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Optional

+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+
 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
     from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
+    from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.metrics.stats import SchedulerStats
     from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
     from vllm.v1.request import Request, RequestStatus
+    from vllm.v1.structured_output import StructuredOutputManager


 class SchedulerInterface(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        vllm_config: "VllmConfig",
+        kv_cache_config: "KVCacheConfig",
+        structured_output_manager: "StructuredOutputManager",
+        block_size: int,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def schedule(self) -> "SchedulerOutput":
         """Schedule the requests to process in this scheduling step.
