From f5f5b8bec3ab1e45488c34d1e5fe6ea80692fb81 Mon Sep 17 00:00:00 2001
From: Wei Wei
Date: Fri, 7 Nov 2025 23:35:52 -0800
Subject: [PATCH] add scoping for better trace (#28329)

Summary: Add more scoping in the hotspot areas, which can greatly help us
find the cycle-heavy areas.

Reviewed By: henryoier

Differential Revision: D86436375
---
 vllm/v1/core/sched/scheduler.py    | 116 +++++++++++++++--------------
 vllm/v1/engine/core.py             |  99 +++++++++++++-----------
 vllm/v1/engine/llm_engine.py       |  35 +++++----
 vllm/v1/worker/gpu_model_runner.py |  66 ++++++++--------
 4 files changed, 169 insertions(+), 147 deletions(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index f558306e3b2f..d22b47c09af7 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -38,6 +38,7 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext

 logger = init_logger(__name__)

@@ -259,49 +260,52 @@ def schedule(self) -> SchedulerOutput:
                 continue

             # Schedule newly needed KV blocks for the request.
-            while True:
-                new_blocks = self.kv_cache_manager.allocate_slots(
-                    request,
-                    num_new_tokens,
-                    num_lookahead_tokens=self.num_lookahead_tokens,
-                )
-
-                if new_blocks is not None:
-                    # The request can be scheduled.
-                    break
-
-                # The request cannot be scheduled.
-                # Preempt the lowest-priority request.
-                if self.policy == SchedulingPolicy.PRIORITY:
-                    preempted_req = max(
-                        self.running,
-                        key=lambda r: (r.priority, r.arrival_time),
+            with record_function_or_nullcontext("schedule: allocate_slots"):
+                while True:
+                    new_blocks = self.kv_cache_manager.allocate_slots(
+                        request,
+                        num_new_tokens,
+                        num_lookahead_tokens=self.num_lookahead_tokens,
                     )
-                    self.running.remove(preempted_req)
-                    if preempted_req in scheduled_running_reqs:
-                        scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        req_index -= 1
-                else:
-                    preempted_req = self.running.pop()
-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
+                    if new_blocks is not None:
+                        # The request can be scheduled.
+                        break

-                self.waiting.prepend_request(preempted_req)
-                preempted_reqs.append(preempted_req)
-                if preempted_req == request:
-                    # No more request to preempt. Cannot schedule this request.
-                    break
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
+ if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + token_budget += num_scheduled_tokens[ + preempted_req.request_id + ] + req_to_new_blocks.pop(preempted_req.request_id) + num_scheduled_tokens.pop(preempted_req.request_id) + req_index -= 1 + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + preempted_req.num_preemptions += 1 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. Cannot schedule this request. + break if new_blocks is None: # Cannot schedule this request. @@ -596,13 +600,14 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) - if self.running: - any_request = self.running[0] - num_common_prefix_blocks = ( - self.kv_cache_manager.get_num_common_prefix_blocks( - any_request.request_id + with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request.request_id + ) ) - ) # Construct the scheduler output. new_reqs_data = [ @@ -611,13 +616,14 @@ def schedule(self) -> SchedulerOutput: ) for req in scheduled_new_reqs ] - cached_reqs_data = self._make_cached_request_data( - scheduled_running_reqs, - scheduled_resumed_reqs, - num_scheduled_tokens, - scheduled_spec_decode_tokens, - req_to_new_blocks, - ) + with record_function_or_nullcontext("schedule: make_cached_request_data"): + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) # Record the request ids that were scheduled in this step. self.prev_step_scheduled_req_ids.clear() @@ -646,8 +652,8 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta - - self._update_after_schedule(scheduler_output) + with record_function_or_nullcontext("schedule: update_after_schedule"): + self._update_after_schedule(scheduler_output) return scheduler_output def _update_after_schedule( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fba018432e0a..aa540792a04e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -61,6 +61,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -315,17 +316,21 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. 
if not self.scheduler.has_requests(): return {}, False - scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + with record_function_or_nullcontext("core step: schedule"): + scheduler_output = self.scheduler.schedule() + + with record_function_or_nullcontext("core step: execute_model"): + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + with record_function_or_nullcontext("core step: update_from_output"): + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -363,32 +368,37 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - scheduler_output = self.scheduler.schedule() - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) + with record_function_or_nullcontext("core step_with_batch_queue: schedule"): + scheduler_output = self.scheduler.schedule() + with record_function_or_nullcontext("core step_with_batch_queue: execute_model"): + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) model_executed = scheduler_output.total_num_scheduled_tokens > 0 if scheduler_output.pending_structured_output_tokens: - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + with record_function_or_nullcontext("core step_with_batch_queue: pending_structured_output_tokens"): + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + # Block-wait for execute to return (continues running async on the GPU). + with self.log_error_detail(scheduler_output): + exec_result = exec_future.result() + assert exec_result is None else: - # We aren't waiting for any tokens, get any grammar output immediately. - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with record_function_or_nullcontext("core step_with_batch_queue: get_grammar_bitmask"): + # We aren't waiting for any tokens, get any grammar output immediately. + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) # Block-wait for execute to return (continues running async on the GPU). with self.log_error_detail(scheduler_output): exec_result = exec_future.result() if exec_result is None: - # Call sample tokens. - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + with record_function_or_nullcontext("core step_with_batch_queue: sample_tokens"): + # Call sample tokens. 
+                        future = self.model_executor.sample_tokens(
+                            grammar_output, non_block=True
+                        )
                 else:
                     # No sampling required (e.g. all requests finished).
                     future = cast(Future[ModelRunnerOutput], exec_future)
@@ -408,27 +418,28 @@ def step_with_batch_queue(
             # only be called when the scheduler contains requests or the queue
             # is non-empty.
             return None, False
-
-        # Block until the next result is available.
-        future, scheduler_output = batch_queue.pop()
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
+            # Block until the next result is available.
+            future, scheduler_output = batch_queue.pop()
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+        with record_function_or_nullcontext("core step_with_batch_queue: update_from_output"):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )

         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            # We now have the tokens needed to compute the bitmask for the
-            # deferred request. Get the bitmask and call sample tokens.
-            grammar_output = self.scheduler.get_grammar_bitmask(
-                deferred_scheduler_output
-            )
-            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
-            batch_queue.appendleft((future, deferred_scheduler_output))
+            with record_function_or_nullcontext("core step_with_batch_queue: deferred_scheduler_output"):
+                # We now have the tokens needed to compute the bitmask for the
+                # deferred request. Get the bitmask and call sample tokens.
+                grammar_output = self.scheduler.get_grammar_bitmask(
+                    deferred_scheduler_output
+                )
+                future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+                batch_queue.appendleft((future, deferred_scheduler_output))

         return engine_core_outputs, model_executed

diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 995642a8356f..23da2ef4b3ae 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -36,6 +36,7 @@
 from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

@@ -280,27 +281,31 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
             return []

         # 1) Get EngineCoreOutput from the EngineCore.
-        outputs = self.engine_core.get_output()
+        with record_function_or_nullcontext("llm_engine step: get_output"):
+            outputs = self.engine_core.get_output()

         # 2) Process EngineCoreOutputs.
-        iteration_stats = IterationStats() if self.log_stats else None
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs,
-            engine_core_timestamp=outputs.timestamp,
-            iteration_stats=iteration_stats,
-        )
+        with record_function_or_nullcontext("llm_engine step: process_outputs"):
+            iteration_stats = IterationStats() if self.log_stats else None
+            processed_outputs = self.output_processor.process_outputs(
+                outputs.outputs,
+                engine_core_timestamp=outputs.timestamp,
+                iteration_stats=iteration_stats,
+            )

         # 3) Abort any reqs that finished due to stop strings.
-        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
+        with record_function_or_nullcontext("llm_engine step: abort_requests"):
+            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

         # 4) Record stats
-        if self.logger_manager is not None and outputs.scheduler_stats is not None:
-            self.logger_manager.record(
-                scheduler_stats=outputs.scheduler_stats,
-                iteration_stats=iteration_stats,
-                mm_cache_stats=self.processor.stat_mm_cache(),
-            )
-            self.do_log_stats_with_interval()
+        with record_function_or_nullcontext("llm_engine step: record_stats"):
+            if self.logger_manager is not None and outputs.scheduler_stats is not None:
+                self.logger_manager.record(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=iteration_stats,
+                    mm_cache_stats=self.processor.stat_mm_cache(),
+                )
+                self.do_log_stats_with_interval()

         return processed_outputs.request_outputs

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2db4235c89de..c7b0c9341716 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2464,7 +2464,7 @@ def execute_model(
                 "after execute_model() returns None."
             )
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        with record_function_or_nullcontext("Preprocess"):
+        with record_function_or_nullcontext("gpu_model_runner: preprocess"):
             with self.synchronize_input_prep():
                 # Update persistent batch states.
                 self._update_states(scheduler_output)
@@ -2554,7 +2554,7 @@ def execute_model(
                 batch_descriptor=batch_descriptor,
                 ubatch_slices=ubatch_slices,
             ),
-            record_function_or_nullcontext("Forward"),
+            record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
         ):
             model_output = self._model_forward(
@@ -2565,7 +2565,7 @@ def execute_model(
                 **model_kwargs,
             )

-        with record_function_or_nullcontext("Postprocess"):
+        with record_function_or_nullcontext("gpu_model_runner: postprocess"):
             if self.use_aux_hidden_state_outputs:
                 # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
@@ -2662,12 +2662,12 @@ def sample_tokens(
                 scheduler_output, grammar_output, self.input_batch, logits
             )

-        with record_function_or_nullcontext("Sample"):
+        with record_function_or_nullcontext("gpu_model_runner: sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)

         def propose_draft_token_ids(sampled_token_ids):
             assert spec_decode_common_attn_metadata is not None
-            with record_function_or_nullcontext("Draft"):
+            with record_function_or_nullcontext("gpu_model_runner: draft"):
                 self._draft_token_ids = self.propose_draft_token_ids(
                     scheduler_output,
                     sampled_token_ids,
@@ -2705,7 +2705,7 @@ def propose_draft_token_ids(sampled_token_ids):
             # as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids) - with record_function_or_nullcontext("Bookkeep"): + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, logprobs_lists, @@ -2732,37 +2732,37 @@ def propose_draft_token_ids(sampled_token_ids): # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) - with record_function_or_nullcontext("EPLB"): + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() - - output = ModelRunnerOutput( - req_ids=req_ids_output_copy, - req_id_to_index=req_id_to_index_output_copy, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - num_nans_in_logits=num_nans_in_logits, - ) + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + output = ModelRunnerOutput( + req_ids=req_ids_output_copy, + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=valid_sampled_token_ids, + logprobs=logprobs_lists, + prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], + kv_connector_output=kv_connector_output, + num_nans_in_logits=num_nans_in_logits, + ) if not self.use_async_scheduling: return output - - async_output = AsyncGPUModelRunnerOutput( - model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, - logprobs_tensors=sampler_output.logprobs_tensors, - invalid_req_indices=invalid_req_indices, - async_output_copy_stream=self.async_output_copy_stream, - ) - - # Save ref of sampled_token_ids CPU tensor if the batch contains - # any requests with sampling params that that require output ids. - self.input_batch.set_async_sampled_token_ids( - async_output.sampled_token_ids_cpu, - async_output.async_copy_ready_event, - ) + with record_function_or_nullcontext("gpu_model_runner: AsyncGPUModelRunnerOutput"): + async_output = AsyncGPUModelRunnerOutput( + model_runner_output=output, + sampled_token_ids=sampler_output.sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + with record_function_or_nullcontext("gpu_model_runner: set_async_sampled_token_ids"): + # Save ref of sampled_token_ids CPU tensor if the batch contains + # any requests with sampling params that that require output ids. + self.input_batch.set_async_sampled_token_ids( + async_output.sampled_token_ids_cpu, + async_output.async_copy_ready_event, + ) return async_output
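
For context on the helper used throughout this patch: record_function_or_nullcontext
(imported from vllm.v1.utils in the hunks above) is assumed to behave roughly like the
sketch below, i.e. it emits a named torch.profiler.record_function range when profiling
is active and degrades to a no-op contextlib.nullcontext otherwise, so the added scopes
cost nothing when tracing is off. The module-level PROFILING_ENABLED flag and the
standalone function are illustrative assumptions, not the actual vLLM implementation,
which lives in vllm/v1/utils.py and may differ in how it detects that profiling is on.

# Minimal sketch, assuming the helper only gates torch.profiler scoping.
import contextlib
from contextlib import AbstractContextManager

import torch

PROFILING_ENABLED = False  # hypothetical flag; not the actual vLLM switch


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a labeled profiler range when profiling is on, else a no-op context."""
    if PROFILING_ENABLED:
        # The name shows up as a block in torch.profiler / Kineto traces,
        # which is what makes the cycle-heavy areas easy to spot.
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()


# Usage mirroring the patch: the wrapped region appears as "core step: schedule"
# in the captured trace.
with record_function_or_nullcontext("core step: schedule"):
    pass  # e.g. scheduler_output = scheduler.schedule()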