Skip to content
18 changes: 11 additions & 7 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from vllm_gaudi.extension.ops import LoraMask as LoraMask
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnectorMetadata

if TYPE_CHECKING:
import xgrammar as xgr
Expand Down Expand Up @@ -1464,12 +1465,15 @@ def _get_prompts_and_decodes(

requests_type = {}
if scheduler_output.kv_connector_metadata:
for req in scheduler_output.kv_connector_metadata.reqs_to_save:
requests_type[req] = 'prefill'
for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
requests_type[req] = 'decode'
requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
scheduler_output.kv_connector_metadata.reqs_to_recv
if isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata):
for req in scheduler_output.kv_connector_metadata.reqs_to_save:
requests_type[req] = 'prefill'
for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
requests_type[req] = 'decode'
requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
scheduler_output.kv_connector_metadata.reqs_to_recv
else:
requests = scheduler_output.kv_connector_metadata.requests
else:
requests = None

Expand Down Expand Up @@ -3137,7 +3141,7 @@ def execute_model(
prompt_batch_idx=idx,
is_prompt=True)
self.profiler.record_counter(self.event_start, counters)
if not warmup_mode:
if not warmup_mode and isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata):
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (self.get_finished_kv_transfers(scheduler_output))

Expand Down
1 change: 1 addition & 0 deletions vllm_gaudi/v1/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(

self.local_rank = local_rank
self.rank = rank
self.parallel_config.rank = rank
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the default value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the default was 1; we were having issues running tp>1, so it needed to be explicitly set.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this is a weird issue — can you send a PR upstream to fix that?
I am OK with the fix here.

Copy link
Contributor Author

@hsubramony hsubramony Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker

Expand Down