Commit eaacf3b

frank-wei authored and facebook-github-bot committed
add scoping for better trace
Summary: Add more scoping in the hotspot areas, which can greatly help us find the cycle-heavy areas.

Differential Revision: D86436375
1 parent 67a2da8 commit eaacf3b

File tree: 3 files changed (+99 -86 lines)

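All three files wrap hot sections in record_function_or_nullcontext, imported from vllm.v1.utils. The helper's definition is not part of this diff; the sketch below only illustrates the usual shape of such a helper, assuming it returns a torch.profiler.record_function range when scoped tracing is wanted and a no-op context otherwise (the VLLM_TRACE_SCOPES gate is hypothetical, for illustration only).

    # Illustrative sketch; the real helper lives in vllm/v1/utils.py and may differ.
    import os
    from contextlib import nullcontext

    from torch.profiler import record_function

    # Hypothetical gate for this sketch; the diff does not show how the helper is enabled.
    _TRACE_SCOPES_ENABLED = os.environ.get("VLLM_TRACE_SCOPES", "0") == "1"


    def record_function_or_nullcontext(name: str):
        """Return a named profiler range when tracing is enabled, else a cheap no-op."""
        if _TRACE_SCOPES_ENABLED:
            return record_function(name)
        return nullcontext()

Either way, the wrapped code runs unchanged; the context manager only adds a named span around it.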

vllm/v1/core/sched/scheduler.py

Lines changed: 61 additions & 57 deletions
@@ -18,8 +18,8 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (
-    EncoderCacheManager,
     compute_encoder_budget,
+    EncoderCacheManager,
 )
 from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -29,7 +29,7 @@
     NewRequestData,
     SchedulerOutput,
 )
-from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
+from vllm.v1.core.sched.request_queue import create_request_queue, SchedulingPolicy
 from vllm.v1.core.sched.utils import check_stop, remove_all
 from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -38,6 +38,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext
 
 logger = init_logger(__name__)
 
@@ -259,49 +260,50 @@ def schedule(self) -> SchedulerOutput:
                 continue
 
             # Schedule newly needed KV blocks for the request.
-            while True:
-                new_blocks = self.kv_cache_manager.allocate_slots(
-                    request,
-                    num_new_tokens,
-                    num_lookahead_tokens=self.num_lookahead_tokens,
-                )
-
-                if new_blocks is not None:
-                    # The request can be scheduled.
-                    break
-
-                # The request cannot be scheduled.
-                # Preempt the lowest-priority request.
-                if self.policy == SchedulingPolicy.PRIORITY:
-                    preempted_req = max(
-                        self.running,
-                        key=lambda r: (r.priority, r.arrival_time),
+            with record_function_or_nullcontext("SCHEDULE: ALLOCATE_SLOTS"):
+                while True:
+                    new_blocks = self.kv_cache_manager.allocate_slots(
+                        request,
+                        num_new_tokens,
+                        num_lookahead_tokens=self.num_lookahead_tokens,
                     )
-                    self.running.remove(preempted_req)
-                    if preempted_req in scheduled_running_reqs:
-                        scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        req_index -= 1
-                else:
-                    preempted_req = self.running.pop()
 
-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
+                    if new_blocks is not None:
+                        # The request can be scheduled.
+                        break
 
-                self.waiting.prepend_request(preempted_req)
-                preempted_reqs.append(preempted_req)
-                if preempted_req == request:
-                    # No more request to preempt. Cannot schedule this request.
-                    break
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
+                    if self.policy == SchedulingPolicy.PRIORITY:
+                        preempted_req = max(
+                            self.running,
+                            key=lambda r: (r.priority, r.arrival_time),
+                        )
+                        self.running.remove(preempted_req)
+                        if preempted_req in scheduled_running_reqs:
+                            scheduled_running_reqs.remove(preempted_req)
+                            token_budget += num_scheduled_tokens[preempted_req.request_id]
+                            req_to_new_blocks.pop(preempted_req.request_id)
+                            num_scheduled_tokens.pop(preempted_req.request_id)
+                            req_index -= 1
+                    else:
+                        preempted_req = self.running.pop()
+
+                    self.kv_cache_manager.free(preempted_req)
+                    self.encoder_cache_manager.free(preempted_req)
+                    preempted_req.status = RequestStatus.PREEMPTED
+                    preempted_req.num_computed_tokens = 0
+                    preempted_req.num_preemptions += 1
+                    if self.log_stats:
+                        preempted_req.record_event(
+                            EngineCoreEventType.PREEMPTED, scheduled_timestamp
+                        )
+
+                    self.waiting.prepend_request(preempted_req)
+                    preempted_reqs.append(preempted_req)
+                    if preempted_req == request:
+                        # No more request to preempt. Cannot schedule this request.
+                        break
 
             if new_blocks is None:
                 # Cannot schedule this request.
@@ -596,13 +598,14 @@ def schedule(self) -> SchedulerOutput:
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
-        if self.running:
-            any_request = self.running[0]
-            num_common_prefix_blocks = (
-                self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request.request_id
+        with record_function_or_nullcontext("SCHEDULE: GET_NUM_COMMON_PREFIX_BLOCKS"):
+            if self.running:
+                any_request = self.running[0]
+                num_common_prefix_blocks = (
+                    self.kv_cache_manager.get_num_common_prefix_blocks(
+                        any_request.request_id
+                    )
                 )
-            )
 
         # Construct the scheduler output.
         new_reqs_data = [
@@ -611,13 +614,14 @@ def schedule(self) -> SchedulerOutput:
             )
             for req in scheduled_new_reqs
         ]
-        cached_reqs_data = self._make_cached_request_data(
-            scheduled_running_reqs,
-            scheduled_resumed_reqs,
-            num_scheduled_tokens,
-            scheduled_spec_decode_tokens,
-            req_to_new_blocks,
-        )
+        with record_function_or_nullcontext("SCHEDULE: MAKE_CACHED_REQUEST_DATA"):
+            cached_reqs_data = self._make_cached_request_data(
+                scheduled_running_reqs,
+                scheduled_resumed_reqs,
+                num_scheduled_tokens,
+                scheduled_spec_decode_tokens,
+                req_to_new_blocks,
+            )
 
         # Record the request ids that were scheduled in this step.
         self.prev_step_scheduled_req_ids.clear()
@@ -646,8 +650,8 @@ def schedule(self) -> SchedulerOutput:
         if self.connector is not None:
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
-
-        self._update_after_schedule(scheduler_output)
+        with record_function_or_nullcontext("SCHEDULE: UPDATE_AFTER_SCHEDULE"):
+            self._update_after_schedule(scheduler_output)
         return scheduler_output
 
     def _update_after_schedule(
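
For orientation (not part of the commit): record_function ranges nest, so the SCHEDULE: * scopes added above show up as children of the CORE STEP: SCHEDULE scope added in vllm/v1/engine/core.py below. A standalone sketch of how nested labels surface in a CPU profile:

    # Standalone illustration of nested named ranges; unrelated to vLLM internals.
    import torch
    from torch.profiler import ProfilerActivity, profile, record_function

    with profile(activities=[ProfilerActivity.CPU]) as prof:
        with record_function("CORE STEP: SCHEDULE"):
            with record_function("SCHEDULE: ALLOCATE_SLOTS"):
                torch.ones(512, 512) @ torch.ones(512, 512)  # stand-in work

    # Each label gets its own row, so per-phase time is directly attributable.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))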

vllm/v1/engine/core.py

Lines changed: 17 additions & 12 deletions
@@ -9,10 +9,10 @@
 from collections import deque
 from collections.abc import Callable, Generator
 from concurrent.futures import Future
-from contextlib import ExitStack, contextmanager
+from contextlib import contextmanager, ExitStack
 from inspect import isclass, signature
 from logging import DEBUG
-from typing import Any, TypeVar, cast
+from typing import Any, cast, TypeVar
 
 import msgspec
 import zmq
@@ -61,6 +61,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        scheduler_output = self.scheduler.schedule()
-        future = self.model_executor.execute_model(scheduler_output, non_block=True)
-        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-            if model_output is None:
-                model_output = self.model_executor.sample_tokens(grammar_output)
+        with record_function_or_nullcontext("CORE STEP: SCHEDULE"):
+            scheduler_output = self.scheduler.schedule()
 
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("CORE STEP: EXECUTE_MODEL"):
+            future = self.model_executor.execute_model(scheduler_output, non_block=True)
+            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+                if model_output is None:
+                    model_output = self.model_executor.sample_tokens(grammar_output)
+
+        with record_function_or_nullcontext("CORE STEP: UPDATE_FROM_OUTPUT"):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )
 
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
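
One way these scopes can be consumed (an assumption about usage, not something this commit adds): capture a profile around a few engine-core steps and export a Chrome trace, where the CORE STEP: * labels appear as named slices. The engine_core handle and step loop below are placeholders.

    # Hypothetical harness; engine_core stands in for an already-constructed engine object.
    from torch.profiler import ProfilerActivity, profile


    def capture_step_trace(engine_core, num_steps: int = 5, out_path: str = "core_step_trace.json"):
        # Add ProfilerActivity.CUDA to the list to also capture GPU kernels.
        with profile(activities=[ProfilerActivity.CPU]) as prof:
            for _ in range(num_steps):
                engine_core.step()  # the new scopes wrap schedule/execute/update inside each step
        prof.export_chrome_trace(out_path)  # open in Perfetto or chrome://tracing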

vllm/v1/engine/llm_engine.py

Lines changed: 21 additions & 17 deletions
@@ -8,7 +8,7 @@
 
 import torch.nn as nn
 from typing_extensions import TypeVar
-
+from vllm.v1.utils import record_function_or_nullcontext
 import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group
@@ -273,34 +273,38 @@ def add_request(
         # Add the request to EngineCore.
         self.engine_core.add_request(child_request)
 
-    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
+    def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
         if self.should_execute_dummy_batch:
             self.should_execute_dummy_batch = False
             self.engine_core.execute_dummy_batch()
             return []
 
         # 1) Get EngineCoreOutput from the EngineCore.
-        outputs = self.engine_core.get_output()
+        with record_function_or_nullcontext("LLM_ENGINE STEP: GET_OUTPUT"):
+            outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
-        iteration_stats = IterationStats() if self.log_stats else None
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs,
-            engine_core_timestamp=outputs.timestamp,
-            iteration_stats=iteration_stats,
-        )
+        with record_function_or_nullcontext("LLM_ENGINE STEP: PROCESS_OUTPUTS"):
+            iteration_stats = IterationStats() if self.log_stats else None
+            processed_outputs = self.output_processor.process_outputs(
+                outputs.outputs,
+                engine_core_timestamp=outputs.timestamp,
+                iteration_stats=iteration_stats,
+            )
 
         # 3) Abort any reqs that finished due to stop strings.
-        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
+        with record_function_or_nullcontext("LLM_ENGINE STEP: ABORT_REQUESTS"):
+            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
         # 4) Record stats
-        if self.logger_manager is not None and outputs.scheduler_stats is not None:
-            self.logger_manager.record(
-                scheduler_stats=outputs.scheduler_stats,
-                iteration_stats=iteration_stats,
-                mm_cache_stats=self.processor.stat_mm_cache(),
-            )
-        self.do_log_stats_with_interval()
+        with record_function_or_nullcontext("LLM_ENGINE STEP:: RECORD_STATS"):
+            if self.logger_manager is not None and outputs.scheduler_stats is not None:
+                self.logger_manager.record(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=iteration_stats,
+                    mm_cache_stats=self.processor.stat_mm_cache(),
+                )
+            self.do_log_stats_with_interval()
 
         return processed_outputs.request_outputs