diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 4ea0fdd4e..c5b38bda6 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -9,7 +9,7 @@
 import sys
 import time
 from dataclasses import dataclass, field, fields
-from typing import (TYPE_CHECKING, Any, Callable, Optional, TypeAlias, Union, cast)
+from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional, TypeAlias, Union, cast)
 
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
@@ -84,6 +84,7 @@
 from vllm_gaudi.extension.ops import LoraMask as LoraMask
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
+from vllm.v1.core.sched.output import GrammarOutput
 
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -696,6 +697,30 @@ def get_dp_padding(num_tokens: int, dp_size: int, dp_rank: int) -> int:
     max_tokens_across_dp_cpu = torch.max(num_tokens_tensor).item()
     return max_tokens_across_dp_cpu - num_tokens
 
+class ExecuteModelState(NamedTuple):
+    """Ephemeral state cached by execute_model() (which now returns None)
+    and consumed by sample_tokens()."""
+
+    scheduler_output: "SchedulerOutput"
+    #logits: torch.Tensor
+    #spec_decode_metadata: SpecDecodeMetadata | None
+    #spec_decode_common_attn_metadata: CommonAttentionMetadata | None
+    #hidden_states: torch.Tensor
+    sample_hidden_states: torch.Tensor
+    aux_hidden_states: list[torch.Tensor] | None
+    #kv_connector_output: KVConnectorOutput | None
+    prefill_data: Any | None
+    structured_output: bool
+    logits_device: torch.Tensor | None
+    batch_changed: bool
+    prefill_sampled_token_ids: list[torch.Tensor]
+    prefill_sampled_requests: list[str]
+    decode_sampled_token_ids: list[torch.Tensor]
+    invalid_req_indices: list[int]
+    decode_data: Any | None
+    warmup_mode: bool
+    decode_sampled_requests: list[str]
+
 class HPUModelRunner(KVConnectorModelRunnerMixin):
 
@@ -911,6 +936,9 @@ def __init__(
             dp_size=self.parallel_config.data_parallel_size,
             dp_rank=self.parallel_config.data_parallel_rank)
 
+        # Ephemeral state transferred between execute_model() and sample_tokens().
+        self.execute_model_state: ExecuteModelState | None = None
+
         assert not (self.unified_attn and not self.use_contiguous_pa), 'Unified attn requires contiguous_pa!'
         assert not (self.unified_attn and not self.use_merged_prefill), 'Unified attn requires merged_prefill!'
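# --- Editor's note (illustrative sketch, not part of the diff) ---------------
# A minimal, runnable illustration of the caching pattern that ExecuteModelState
# introduces above: phase one runs the model and stashes whatever sampling will
# need on the runner, phase two unpacks it, clears it, and samples.
# The Toy* names below are hypothetical stand-ins, not vLLM/vllm-gaudi APIs.
from typing import NamedTuple


class ToyState(NamedTuple):
    token_ids: list[int]
    warmup_mode: bool


class ToyRunner:

    def __init__(self) -> None:
        # Ephemeral state handed from execute_model() to sample_tokens().
        self.execute_model_state: ToyState | None = None

    def execute_model(self, token_ids: list[int], warmup_mode: bool = False) -> None:
        # Phase 1: "run the model", cache the intermediate state, return None.
        self.execute_model_state = ToyState(token_ids, warmup_mode)

    def sample_tokens(self) -> list[int]:
        # Phase 2: consume and immediately clear the cached state.
        if self.execute_model_state is None:
            return []  # nothing to do, e.g. a non-final pipeline-parallel rank
        token_ids, warmup_mode = self.execute_model_state
        self.execute_model_state = None
        return [] if warmup_mode else [t + 1 for t in token_ids]


if __name__ == "__main__":
    runner = ToyRunner()
    runner.execute_model([1, 2, 3])
    print(runner.sample_tokens())  # -> [2, 3, 4]
# ------------------------------------------------------------------------------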
@@ -2583,9 +2611,10 @@ def _is_quant_with_inc(self): def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", + grammar_output: GrammarOutput, logits: torch.Tensor, ): - grammar_bitmask = scheduler_output.grammar_bitmask + grammar_bitmask = grammar_output.grammar_bitmask if grammar_bitmask is None: return @@ -2604,7 +2633,7 @@ def apply_grammar_bitmask( for req_id, batch_index in seq: logit_index = batch_index + cumulative_offset cumulative_offset += len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: + if req_id in grammar_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] @@ -2613,7 +2642,7 @@ def apply_grammar_bitmask( sorted_bitmask = np.zeros_like(grammar_bitmask, shape=(logits.shape[0], grammar_bitmask.shape[1])) cumulative_index = 0 - for req_id in scheduler_output.structured_output_request_ids: + for req_id in grammar_output.structured_output_request_ids: logit_index = struct_out_req_batch_indices[req_id] num_spec_tokens = len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) for i in range(1 + num_spec_tokens): @@ -3032,10 +3061,16 @@ def execute_model( # postprocessing. Should it be done for all requests? structured_output = False spec_decode_num_tokens = None - if scheduler_output.grammar_bitmask is not None: - logits_prompt = [] - logits_decode = [] - structured_output = True + + logits_prompt = [] + logits_decode = [] + #structured_output = True + + # Prepare prompts/decodes info + pd_info = self._get_prompts_and_decodes(scheduler_output) + num_decodes = len(pd_info.decode_req_ids) + num_prefills = len(pd_info.prompt_req_ids) + num_reqs = num_decodes + num_prefills if self.use_async_scheduling: invalid_req_indices = [] ######################### PREFILLS ######################### @@ -3110,6 +3145,10 @@ def execute_model( if self.use_aux_hidden_state_outputs: aux_hidden_states_prefills.append(aux_hidden_states) sample_hidden_states_prefills.append(sample_hidden_states) + + ''' + + # Skip separate sampling for structured output if structured_output: logits_prompt.append(logits_device) @@ -3125,6 +3164,7 @@ def execute_model( logits_requests) prefill_sampled_token_ids.append(sampler_output.sampled_token_ids.flatten()) prefill_sampled_requests.extend(logits_requests) + ''' if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model_generic' event self.profiler.end() @@ -3182,7 +3222,7 @@ def execute_model( warmup_mode=warmup_mode) htorch.core.mark_step() - if structured_output: + '''if structured_output: logits_decode.append(logits_device[:num_decodes]) decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes]) else: @@ -3219,7 +3259,7 @@ def execute_model( sampled_token_ids.to("hpu", non_blocking=True) decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes]) ##### Sampling End ##### - +''' if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event self.profiler.end() @@ -3247,13 +3287,463 @@ def execute_model( warmup_mode=warmup_mode) htorch.core.mark_step() + '''if structured_output: + # Scheduler places cached before prompt + logits_combined = logits_decode + logits_prompt + logits = torch.cat(logits_combined, dim=0) + # Apply structured output bitmasks if present + #if scheduler_output.structured_output_request_ids: + # self.apply_grammar_bitmask(scheduler_output, logits) + sampler_output, _sampling_metadata = 
self._run_sampling(batch_changed, logits, + pd_info.prompt_req_ids + pd_info.decode_req_ids, + logits.shape[0]) + # Deal with the case of incomplete prompt + for i in range(logits.shape[0] - num_decodes): + prefill_sampled_token_ids.append(sampler_output.sampled_token_ids[num_decodes + i].flatten()) + decode_sampled_token_ids.append(sampler_output.sampled_token_ids[:num_decodes].flatten()) + elif self.use_async_scheduling: + # For async scheduling: keep tokens on HPU and avoid CPU sync + # Concatenate decode and prefill tokens on HPU + if decode_sampled_token_ids or prefill_sampled_token_ids: + decode_sampled_token_ids = [tensor[:num_decodes] for tensor in decode_sampled_token_ids] + # Note: this will cause an issue with the current spec decode impl, as they are on different devices + sampled_token_ids = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).view(-1, 1) + else: + sampled_token_ids = torch.empty((0, 1), dtype=torch.int32, device=self.device) + + # Copy some objects so they don't get modified after returning. + # This is important when using async scheduling. + req_ids_output_copy = self.input_batch.req_ids.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() + + max_req_index = max(self.input_batch.req_id_to_index.values()) + postprocessed_sampled_token_ids: list[list[int]] = [[] for _ in range(max_req_index + 1)] + if self.use_async_scheduling: + self.input_batch.prev_sampled_token_ids = sampled_token_ids.flatten() + # self.input_batch.prev_sampled_token_ids_invalid_indices + invalid_req_indices_set = set(invalid_req_indices) + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set + self.input_batch.prev_req_id_to_index = { + req_id: i + for i, req_id in enumerate(self.input_batch.req_ids) if i not in invalid_req_indices_set + } + # For the output, postprocessed_sampled_token_ids will be filled during serialization + else: + prefill_sampled_token_ids_device = prefill_sampled_token_ids + # From this point onward, all operations are done on CPU. + # We already have tokens. Let's copy the data to + # CPU as is, and then discard padded tokens. + with self.profiler.record_event('internal', "sampler_postprocessing"): + prefill_sampled_token_ids = [tensor.cpu() for tensor in prefill_sampled_token_ids] + if spec_decode_num_tokens is not None: + decode_sampled_token_ids = [tensor.cpu() for tensor in decode_sampled_token_ids] + else: + decode_sampled_token_ids = [tensor.cpu()[:num_decodes] for tensor in decode_sampled_token_ids] + if decode_sampled_token_ids + prefill_sampled_token_ids: + sampled_token_ids_list = torch.cat(decode_sampled_token_ids + prefill_sampled_token_ids).tolist() + else: + sampled_token_ids_list = [] + sampled_token_requests = \ + decode_sampled_requests + prefill_sampled_requests + max_req_index = max(self.input_batch.req_id_to_index.values()) + # NOTE(Chendi): in post-processing, spec_decode might + # return more than 1 token during decode. 
+ start_idx = 0 + for i, req_id in enumerate(sampled_token_requests): + num_tokens = spec_decode_num_tokens[ + i] if spec_decode_num_tokens is not None and i < num_decodes else 1 + postprocessed_sampled_token_ids[ + self.input_batch.req_id_to_index[req_id]] += sampled_token_ids_list[start_idx:start_idx + + num_tokens] + start_idx += num_tokens + + ################## RETURN ################## + # NOTE(kzawora): idk what happens if part of batch doesn't have logprobs + + ######### UPDATE REQUEST STATE WITH GENERATED TOKENS ######### + for req_id in self.input_batch.req_ids[:num_reqs]: + req_state = self.requests[req_id] + i = self.input_batch.req_id_to_index[req_id] + seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) + token_ids = postprocessed_sampled_token_ids[i] + num_tokens = len(token_ids) + self.input_batch.token_ids_cpu[i, seq_len:seq_len + num_tokens] = token_ids + self.input_batch.num_tokens[i] += len(token_ids) + req_state.output_token_ids.extend(token_ids) + # NOTE(chendi): enable cache based on PR(#20291) + # Cache the sampled tokens in the model runner, so that the scheduler + # doesn't need to send them back. + # NOTE(woosuk): As an exception, when using PP, the scheduler sends + # the sampled tokens back, because there's no direct communication + # between the first-stage worker and the last-stage worker. + for req_idx, sampled_ids in enumerate(postprocessed_sampled_token_ids[:num_reqs]): + if not sampled_ids: + continue + + start_idx = self.input_batch.num_tokens_no_spec[req_idx] + end_idx = start_idx + len(sampled_ids) + # NOTE(adobrzyn): assert for full max prompt length including + # max_model_len and one token that's going to be generated + # especially needed for biggest prompt in warm-up phase + full_max_prompt = self.max_model_len + 1 + assert end_idx <= full_max_prompt, ("Sampled token IDs exceed the max model length. " + f"Total number of tokens: {end_idx} > max_model_len: " + f"{full_max_prompt}") + + self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.num_tokens_no_spec[req_idx] = end_idx + self.input_batch.num_tokens[req_idx] = end_idx + req_id = self.input_batch.req_ids[req_idx] + req_state = self.requests[req_id] + req_state.output_token_ids.extend(sampled_ids) + + ################## Spec Decode ################## + # Now, we will call drafter to propose draft token ids + if self.speculative_config: + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, postprocessed_sampled_token_ids, prefill_sampled_token_ids_device, + decode_sampled_token_ids_device, sampling_metadata, non_flattened_hidden_states, sample_hidden_states, + aux_hidden_states, non_flattened_hidden_states_prefills, sample_hidden_states_prefills, + aux_hidden_states_prefills, num_decodes, prefill_data if num_prefills > 0 else None, + decode_data if num_decodes > 0 else None) + ################## Spec Decode end ################## + ''' + ''' + # Create output. 
+ all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids + # prompt_logprobs_dict: dict[ + # str, Optional[LogprobsTensors]] = self._get_prompt_logprobs_dict( + # prefill_hidden_states_device, scheduler_output) + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids + logprobs = None + + if self.use_async_scheduling: + model_runner_output = ModelRunnerOutput( + req_ids=req_ids_output_copy, # CHECK + req_id_to_index=req_id_to_index_output_copy, + sampled_token_ids=postprocessed_sampled_token_ids, + logprobs=logprobs, + prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + pooler_output=[], + kv_connector_output=KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + )) + return AsyncHPUModelRunnerOutput( + model_runner_output=model_runner_output, + sampled_token_ids=sampled_token_ids, + invalid_req_indices=invalid_req_indices, + async_output_copy_stream=self.async_output_copy_stream, + ) + model_runner_output = ModelRunnerOutput( + req_ids=all_req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=postprocessed_sampled_token_ids, + logprobs=logprobs, + prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + pooler_output=[], + kv_connector_output=KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + )) + if has_kv_transfer_group(): + get_kv_transfer_group().clear_connector_metadata() + + return model_runner_output''' + self.execute_model_state = ExecuteModelState( + scheduler_output, + #logits, + #spec_decode_metadata, + #spec_decode_common_attn_metadata, + #hidden_states, + sample_hidden_states, + aux_hidden_states, + #kv_connector_output, + prefill_data, + structured_output, + logits_device, + batch_changed, + prefill_sampled_token_ids, + prefill_sampled_requests, + decode_sampled_token_ids, + invalid_req_indices, + decode_data, + warmup_mode, + decode_sampled_requests + ) + return None + + @torch.inference_mode + def sample_tokens( + self, grammar_output: "GrammarOutput | None" + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + if self.execute_model_state is None: + # Nothing to do (PP non-final rank case), output isn't used. + return None # noqa + + # Unpack ephemeral state. + ( + scheduler_output, + #logits, + #spec_decode_metadata, + #spec_decode_common_attn_metadata, + #hidden_states, + sample_hidden_states, + aux_hidden_states, + #kv_connector_output, + prefill_data, + structured_output, + logits_device, + batch_changed, + prefill_sampled_token_ids, + prefill_sampled_requests, + decode_sampled_token_ids, + invalid_req_indices, + decode_data, + warmup_mode, + decode_sampled_requests + ) = self.execute_model_state + # Clear ephemeral state. + self.execute_model_state = None + logits_prompt = [] + logits_decode = [] + + # Apply structured output bitmasks if present. 
+        # (The bitmask itself is applied in the structured_output branch below,
+        # once the prompt and decode logits have been gathered.)
+        if grammar_output is not None:
+            structured_output = True
+
+        # Prepare prompts/decodes info
+        pd_info = self._get_prompts_and_decodes(scheduler_output)
+        num_decodes = len(pd_info.decode_req_ids)
+        num_prefills = len(pd_info.prompt_req_ids)
+        num_reqs = num_decodes + num_prefills
+        ######################### PREFILLS #########################
+        if num_prefills > 0:
+            htorch.core.mark_step()
+            for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices,
+                      logits_requests) in enumerate(zip(*shallow_tuple(prefill_data))):
+
+                inputs_embeds = None
+                model_mm_kwargs = None
+                '''if self.supports_mm_inputs:
+                    # Run the multimodal encoder if any.
+                    with self.profiler.record_event('internal', 'prepare_input_encoders'):
+                        self._execute_mm_encoder(scheduler_output, req_id)
+                        htorch.core.mark_step()
+
+                    mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output,
+                                                                        req_id,
+                                                                        total_num_scheduled_tokens=token_ids.shape[-1])
+                    htorch.core.mark_step()
+
+                    # TODO: Only get embeddings for valid token_ids. Ignore token_ids[] # noqa E501
+                    # This may require moving multimodal input preps into _prepare_inputs, # noqa E501
+                    # to avoid padding issues.
+                    inputs_embeds = self.model.get_input_embeddings(
+                        token_ids,
+                        multimodal_embeddings=mm_embeds,
+                        is_multimodal=is_mm_embed,
+                    )
+
+                    model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
+                    model_mm_kwargs = MultiModalKwargs.as_kwargs(
+                        model_mm_kwargs,
+                        device=self.device,
+                    )
+
+                lora_mask, lora_logits_mask = self._configure_lora(token_ids, self.requests, req_id, True)
+
+                self.event_start = self.profiler.get_timestamp_us()
+                self.profiler.start("internal", "prefill")
+
+                # NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
+                # If logits_indices is smaller than req_id, the last request is a chunked prompt request that
+                # hasn't finished in this step. We add the last token position to logits_indices to ensure
+                # the last token of the chunk is sampled. This sampled token will be discarded later
+                if logits_indices.shape[0] < len(req_id):
+                    if structured_output or self.use_async_scheduling:
+                        # When there are multiple requests in the batch (e.g.
self.use_merged_prefill=True), + # the last token position is the sum of all prompt lengths - 1 + # This logic also holds when there is only one request in the batch + logits_indices_append = torch.tensor([torch.sum(prompt_len) - 1], + device=token_ids.device, + dtype=torch.int32) + logits_indices = torch.cat([logits_indices, logits_indices_append]) + if self.use_async_scheduling: + # Discard partial prefill logit for async scheduling + # Depends on 1 decode token/batch + prefill_start_idx = num_decodes + invalid_req_indices.append(prefill_start_idx + idx) + htorch.core.mark_step() + non_flattened_hidden_states, aux_hidden_states, \ + sample_hidden_states, logits_device = \ + self._execute_model_generic( + token_ids, position_ids, attn_metadata, logits_indices, + self.kv_caches, + lora_logits_mask, + lora_mask, + inputs_embeds=inputs_embeds, + model_mm_kwargs=model_mm_kwargs, + warmup_mode=warmup_mode,) + htorch.core.mark_step() + non_flattened_hidden_states_prefills.append(non_flattened_hidden_states) + if self.use_aux_hidden_state_outputs: + aux_hidden_states_prefills.append(aux_hidden_states) + sample_hidden_states_prefills.append(sample_hidden_states) + ''' + + + + # Skip separate sampling for structured output + if structured_output: + logits_prompt.append(logits_device) + prefill_sampled_requests.extend(logits_requests) + else: + # If there are no logits, there is nothing to sample. + # This can happen with chunked prefill when a chunk does + # not complete the prompt and no logits are generated. + if logits_device.numel() > 0: + with self.profiler.record_event('internal', "sampler"): + sampler_output, sampling_metadata = self._run_sampling(batch_changed, logits_device, req_id, + logits_device.shape[0], + logits_requests) + prefill_sampled_token_ids.append(sampler_output.sampled_token_ids.flatten()) + prefill_sampled_requests.extend(logits_requests) + '''if self.is_driver_worker and self.profiler.enabled: + # Stop recording 'execute_model_generic' event + self.profiler.end() + event_end = self.profiler.get_timestamp_us() + counters = self.profiler_counter_helper.get_counter_dict(cache_config=self.cache_config, + duration=event_end - self.event_start, + seq_len=self._seq_len(attn_metadata), + batch_size_padded=token_ids.size(0), + real_batch_size=len(req_id), + prompt_batch_idx=idx, + is_prompt=True) + self.profiler.record_counter(self.event_start, counters) + if not warmup_mode: + self.maybe_wait_for_kv_save() + ''' + finished_sending, finished_recving = (self.get_finished_kv_transfers(scheduler_output)) + ''' + if self.is_driver_worker and self.profiler.enabled: + self.profiler_counter_helper.reset_prompt_seq_stats() + + if num_pad_prefill_batch_across_dp > 0: + for idx, (req_id, prompt_len, token_ids, position_ids, attn_metadata, logits_indices, + logits_requests) in enumerate(zip(*shallow_tuple(dummy_prefill_input_data_batches_across_dp))): + htorch.core.mark_step() + _, _, _, dummy_logits_device = \ + self._execute_model_generic( + token_ids, + position_ids, + attn_metadata, + logits_indices, + self.kv_caches, + None, + None, + warmup_mode=warmup_mode) + htorch.core.mark_step() +''' + ######################### DECODES ######################### + # Decodes run as one single batch with [padded_decode_bs, 1] + if num_decodes > 0: + assert decode_data is not None + lora_mask, lora_logits_mask = self._configure_lora(decode_data.token_ids, self.requests, + pd_info.decode_req_ids, False) + self.event_start = self.profiler.get_timestamp_us() + self.profiler.start("internal", "decode") + 
+            htorch.core.mark_step()
+            non_flattened_hidden_states, aux_hidden_states, \
+            sample_hidden_states, logits_device = \
+                self._execute_model_generic(
+                    decode_data.token_ids,
+                    decode_data.position_ids,
+                    decode_data.attn_metadata,
+                    decode_data.logits_indices,
+                    self.kv_caches,
+                    lora_logits_mask,
+                    lora_mask,
+                    warmup_mode=warmup_mode)
+            htorch.core.mark_step()
+
+            if structured_output:
+                logits_decode.append(logits_device[:num_decodes])
+                decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes])
+            else:
+                with self.profiler.record_event('internal', "sampler"):
+                    ##### Sampling Start #####
+                    spec_decode_metadata = decode_data.spec_decode_metadata
+                    sampler_output, sampling_metadata = self._run_sampling(
+                        batch_changed, logits_device
+                        if spec_decode_metadata is None else logits_device[spec_decode_metadata.bonus_logits_indices],
+                        pd_info.decode_req_ids, logits_device.shape[0])
+
+                    if spec_decode_metadata is None:
+                        decode_sampled_token_ids.append(sampler_output.sampled_token_ids.flatten())
+                    else:
+                        # Handling spec decode sampling.
+                        sampler_output = self.rejection_sampler(
+                            spec_decode_metadata,
+                            None,  # draft_probs
+                            logits_device,
+                            sampling_metadata,
+                        )
+                        sampled_token_ids = sampler_output.sampled_token_ids
+                        decode_sampled_token_ids = \
+                            self.rejection_sampler.parse_output(
+                                sampled_token_ids,
+                                self.input_batch.vocab_size,
+                            )
+                        # convert decode_sampled_token_ids as list of tensor
+                        spec_decode_num_tokens = [len(v) for v in decode_sampled_token_ids]
+                        decode_sampled_token_ids = [
+                            torch.tensor(v, device="cpu").int() for v in decode_sampled_token_ids
+                        ]
+                        decode_sampled_token_ids_device = \
+                            sampled_token_ids.to("hpu", non_blocking=True)
+                    decode_sampled_requests.extend(self.input_batch.req_ids[:num_decodes])
+                    ##### Sampling End #####
+
+            if self.is_driver_worker and self.profiler.enabled:
+                # Stop recording 'execute_model' event
+                self.profiler.end()
+                event_end = self.profiler.get_timestamp_us()
+                counters = self.profiler_counter_helper.get_counter_dict(
+                    cache_config=self.cache_config,
+                    duration=event_end - self.event_start,
+                    seq_len=self._seq_len(decode_data.attn_metadata),
+                    batch_size_padded= \
+                        decode_data.token_ids.size(0),  # type: ignore
+                    real_batch_size=decode_data.num_decodes,
+                    prompt_batch_idx=None,
+                    is_prompt=False)
+                self.profiler.record_counter(self.event_start, counters)
+
+        '''elif dummy_decode_input_data_across_dp is not None:
+            htorch.core.mark_step()
+            _, _, _, dummy_logits_device = self._execute_model_generic(dummy_decode_input_data_across_dp.token_ids,
+                                                                       dummy_decode_input_data_across_dp.position_ids,
+                                                                       dummy_decode_input_data_across_dp.attn_metadata,
+                                                                       dummy_decode_input_data_across_dp.logits_indices,
+                                                                       self.kv_caches,
+                                                                       None,
+                                                                       None,
+                                                                       warmup_mode=warmup_mode)
+            htorch.core.mark_step()
+        '''
         if structured_output:
             # Scheduler places cached before prompt
             logits_combined = logits_decode + logits_prompt
             logits = torch.cat(logits_combined, dim=0)
             # Apply structured output bitmasks if present
-            if scheduler_output.structured_output_request_ids:
-                self.apply_grammar_bitmask(scheduler_output, logits)
+            if grammar_output is not None:
+                self.apply_grammar_bitmask(scheduler_output, grammar_output, logits)
             sampler_output, _sampling_metadata = self._run_sampling(batch_changed, logits,
                                                                     pd_info.prompt_req_ids + pd_info.decode_req_ids,
                                                                     logits.shape[0])
@@ -3413,6 +3903,8 @@ def execute_model(
 
         return model_runner_output
 
+
+
     def load_model(self) -> None:
         import habana_frameworks.torch.core as htcore
         if self.model_config.quantization == 'inc' or \
@@ -4029,8 +4521,8 @@ def _execute_dummy_scenario(self, requests, scheduled_tokens):
             num_common_prefix_blocks=0,
             finished_req_ids=set(),
             free_encoder_mm_hashes=[],
-            structured_output_request_ids={},
-            grammar_bitmask=None,
+            #structured_output_request_ids={},
+            #grammar_bitmask=None,
         )
         cleanup = SchedulerOutput(
             scheduled_new_reqs=[],
@@ -4042,8 +4534,8 @@ def _execute_dummy_scenario(self, requests, scheduled_tokens):
             num_common_prefix_blocks=0,
             finished_req_ids=set(req.req_id for req in requests),
             free_encoder_mm_hashes=[],
-            structured_output_request_ids={},
-            grammar_bitmask=None,
+            #structured_output_request_ids={},
+            #grammar_bitmask=None,
         )
         self.execute_model(sched_output, warmup_mode=True)
         self.execute_model(cleanup, warmup_mode=True)
diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py
index 3caa79b79..69a170754 100644
--- a/vllm_gaudi/v1/worker/hpu_worker.py
+++ b/vllm_gaudi/v1/worker/hpu_worker.py
@@ -5,7 +5,7 @@
 import os
 import queue
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import torch
 import torch.distributed
@@ -22,7 +22,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec)
-from vllm.v1.outputs import (DraftTokenIds, AsyncModelRunnerOutput, ModelRunnerOutput)
+from vllm.v1.outputs import (DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.worker.utils import bind_kv_cache
 from vllm_gaudi.utils import is_fake_hpu
 from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner, bool_helper
@@ -33,7 +33,7 @@
 logger = init_logger()
 
 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
+    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 
 
 def setup_step_profiler(steps):
@@ -251,11 +251,14 @@ def compile_or_warm_up_model(self) -> None:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
 
+    def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
+        return self.model_runner.sample_tokens(grammar_output)
+
     @torch.inference_mode()
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput]:
+    ) -> ModelRunnerOutput | None:
         if self.step_debug:
             self.step_debug(f'step={self.step}')
         if self.step_profiler and self.step == self.profile_steps[0]:
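A rough sketch of the call order this split assumes on the worker side: execute_model() now only runs the forward pass and caches its results on the model runner (returning None), and sample_tokens() is called afterwards with the grammar output, if any, to apply the bitmask and produce the ModelRunnerOutput. The driver names below (step(), get_grammar_output()) are illustrative stand-ins, not the actual vLLM engine loop.

# Hypothetical engine step, assuming `worker` behaves like HPUWorker and
# `scheduler` produces SchedulerOutput / GrammarOutput objects.
def step(scheduler, worker):
    scheduler_output = scheduler.schedule()
    # Phase 1: forward pass only; the intermediate state is cached on the
    # model runner and execute_model() returns None.
    worker.execute_model(scheduler_output)
    # Phase 2: the grammar bitmask (present only when structured-output
    # requests are scheduled) is applied to the cached logits before sampling.
    grammar_output = scheduler.get_grammar_output()  # may be None
    return worker.sample_tokens(grammar_output)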