3838from vllm .v1 .request import Request , RequestStatus
3939from vllm .v1 .spec_decode .metrics import SpecDecodingStats
4040from vllm .v1 .structured_output import StructuredOutputManager
41+ from vllm .v1 .utils import record_function_or_nullcontext
4142
4243logger = init_logger (__name__ )
4344
@@ -259,49 +260,50 @@ def schedule(self) -> SchedulerOutput:
259260 continue
260261
261262 # Schedule newly needed KV blocks for the request.
262- while True :
263- new_blocks = self .kv_cache_manager .allocate_slots (
264- request ,
265- num_new_tokens ,
266- num_lookahead_tokens = self .num_lookahead_tokens ,
267- )
268-
269- if new_blocks is not None :
270- # The request can be scheduled.
271- break
272-
273- # The request cannot be scheduled.
274- # Preempt the lowest-priority request.
275- if self .policy == SchedulingPolicy .PRIORITY :
276- preempted_req = max (
277- self .running ,
278- key = lambda r : (r .priority , r .arrival_time ),
263+ with record_function_or_nullcontext ("SCHEDULE: ALLOCATE_SLOTS" ):
264+ while True :
265+ new_blocks = self .kv_cache_manager .allocate_slots (
266+ request ,
267+ num_new_tokens ,
268+ num_lookahead_tokens = self .num_lookahead_tokens ,
279269 )
280- self .running .remove (preempted_req )
281- if preempted_req in scheduled_running_reqs :
282- scheduled_running_reqs .remove (preempted_req )
283- token_budget += num_scheduled_tokens [preempted_req .request_id ]
284- req_to_new_blocks .pop (preempted_req .request_id )
285- num_scheduled_tokens .pop (preempted_req .request_id )
286- req_index -= 1
287- else :
288- preempted_req = self .running .pop ()
289270
290- self .kv_cache_manager .free (preempted_req )
291- self .encoder_cache_manager .free (preempted_req )
292- preempted_req .status = RequestStatus .PREEMPTED
293- preempted_req .num_computed_tokens = 0
294- preempted_req .num_preemptions += 1
295- if self .log_stats :
296- preempted_req .record_event (
297- EngineCoreEventType .PREEMPTED , scheduled_timestamp
298- )
271+ if new_blocks is not None :
272+ # The request can be scheduled.
273+ break
299274
300- self .waiting .prepend_request (preempted_req )
301- preempted_reqs .append (preempted_req )
302- if preempted_req == request :
303- # No more request to preempt. Cannot schedule this request.
304- break
275+ # The request cannot be scheduled.
276+ # Preempt the lowest-priority request.
277+ if self .policy == SchedulingPolicy .PRIORITY :
278+ preempted_req = max (
279+ self .running ,
280+ key = lambda r : (r .priority , r .arrival_time ),
281+ )
282+ self .running .remove (preempted_req )
283+ if preempted_req in scheduled_running_reqs :
284+ scheduled_running_reqs .remove (preempted_req )
285+ token_budget += num_scheduled_tokens [preempted_req .request_id ]
286+ req_to_new_blocks .pop (preempted_req .request_id )
287+ num_scheduled_tokens .pop (preempted_req .request_id )
288+ req_index -= 1
289+ else :
290+ preempted_req = self .running .pop ()
291+
292+ self .kv_cache_manager .free (preempted_req )
293+ self .encoder_cache_manager .free (preempted_req )
294+ preempted_req .status = RequestStatus .PREEMPTED
295+ preempted_req .num_computed_tokens = 0
296+ preempted_req .num_preemptions += 1
297+ if self .log_stats :
298+ preempted_req .record_event (
299+ EngineCoreEventType .PREEMPTED , scheduled_timestamp
300+ )
301+
302+ self .waiting .prepend_request (preempted_req )
303+ preempted_reqs .append (preempted_req )
304+ if preempted_req == request :
305+ # No more request to preempt. Cannot schedule this request.
306+ break
305307
306308 if new_blocks is None :
307309 # Cannot schedule this request.
@@ -596,13 +598,14 @@ def schedule(self) -> SchedulerOutput:
596598 # Get the longest common prefix among all requests in the running queue.
597599 # This can be potentially used for cascade attention.
598600 num_common_prefix_blocks = [0 ] * len (self .kv_cache_config .kv_cache_groups )
599- if self .running :
600- any_request = self .running [0 ]
601- num_common_prefix_blocks = (
602- self .kv_cache_manager .get_num_common_prefix_blocks (
603- any_request .request_id
601+ with record_function_or_nullcontext ("SCHEDULE: GET_NUM_COMMON_PREFIX_BLOCKS" ):
602+ if self .running :
603+ any_request = self .running [0 ]
604+ num_common_prefix_blocks = (
605+ self .kv_cache_manager .get_num_common_prefix_blocks (
606+ any_request .request_id
607+ )
604608 )
605- )
606609
607610 # Construct the scheduler output.
608611 new_reqs_data = [
@@ -611,13 +614,14 @@ def schedule(self) -> SchedulerOutput:
611614 )
612615 for req in scheduled_new_reqs
613616 ]
614- cached_reqs_data = self ._make_cached_request_data (
615- scheduled_running_reqs ,
616- scheduled_resumed_reqs ,
617- num_scheduled_tokens ,
618- scheduled_spec_decode_tokens ,
619- req_to_new_blocks ,
620- )
617+ with record_function_or_nullcontext ("SCHEDULE: MAKE_CACHED_REQUEST_DATA" ):
618+ cached_reqs_data = self ._make_cached_request_data (
619+ scheduled_running_reqs ,
620+ scheduled_resumed_reqs ,
621+ num_scheduled_tokens ,
622+ scheduled_spec_decode_tokens ,
623+ req_to_new_blocks ,
624+ )
621625
622626 # Record the request ids that were scheduled in this step.
623627 self .prev_step_scheduled_req_ids .clear ()
@@ -646,8 +650,8 @@ def schedule(self) -> SchedulerOutput:
646650 if self .connector is not None :
647651 meta = self .connector .build_connector_meta (scheduler_output )
648652 scheduler_output .kv_connector_metadata = meta
649-
650- self ._update_after_schedule (scheduler_output )
653+ with record_function_or_nullcontext ( "SCHEDULE: UPDATE_AFTER_SCHEDULE" ):
654+ self ._update_after_schedule (scheduler_output )
651655 return scheduler_output
652656
653657 def _update_after_schedule (
0 commit comments