diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py
index 3caa79b79..77f9a622a 100644
--- a/vllm_gaudi/v1/worker/hpu_worker.py
+++ b/vllm_gaudi/v1/worker/hpu_worker.py
@@ -18,7 +18,12 @@
 import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (
+    ensure_kv_transfer_initialized,
+    get_kv_transfer_group,
+    has_kv_transfer_group,
+)
+from vllm.distributed.parallel_state import get_tp_group
 from vllm.model_executor import set_random_seed
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec)
@@ -294,6 +299,21 @@ def profile(self, is_start: bool = True):
     def execute_dummy_batch(self) -> None:
         self.model_runner._dummy_run(1)
 
+    def get_kv_connector_handshake_metadata(self) -> dict | None:
+        """Get KV connector metadata from this worker if available."""
+
+        if not has_kv_transfer_group():
+            return None
+
+        connector = get_kv_transfer_group()
+        # Return None for connectors that don't need to exchange handshake
+        # metadata across workers.
+        if (metadata := connector.get_handshake_metadata()) is None:
+            return None
+
+        tp_rank = get_tp_group().rank_in_group
+        return {tp_rank: metadata}
+
 
 def init_worker_distributed_environment(
     vllm_config: VllmConfig,