Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from vllm_gaudi.extension.ops import LoraMask as LoraMask
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnectorMetadata

if TYPE_CHECKING:
import xgrammar as xgr
Expand Down Expand Up @@ -1464,12 +1465,15 @@ def _get_prompts_and_decodes(

requests_type = {}
if scheduler_output.kv_connector_metadata:
for req in scheduler_output.kv_connector_metadata.reqs_to_save:
requests_type[req] = 'prefill'
for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
requests_type[req] = 'decode'
requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
scheduler_output.kv_connector_metadata.reqs_to_recv
if isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata):
for req in scheduler_output.kv_connector_metadata.reqs_to_save:
requests_type[req] = 'prefill'
for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
requests_type[req] = 'decode'
requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
scheduler_output.kv_connector_metadata.reqs_to_recv
else:
requests = scheduler_output.kv_connector_metadata.requests
else:
requests = None

Expand Down Expand Up @@ -3137,7 +3141,7 @@ def execute_model(
prompt_batch_idx=idx,
is_prompt=True)
self.profiler.record_counter(self.event_start, counters)
if not warmup_mode:
if not warmup_mode and isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata):
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (self.get_finished_kv_transfers(scheduler_output))

Expand Down
1 change: 1 addition & 0 deletions vllm_gaudi/v1/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(

self.local_rank = local_rank
self.rank = rank
self.parallel_config.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker

Expand Down