22 changes: 21 additions & 1 deletion vllm_gaudi/v1/worker/hpu_worker.py
@@ -18,7 +18,12 @@
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
    has_kv_transfer_group,
)
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor import set_random_seed
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec)
@@ -294,6 +299,21 @@ def profile(self, is_start: bool = True):
    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1)

    def get_kv_connector_handshake_metadata(self) -> dict | None:
        """Get KV connector metadata from this worker if available."""

        if not has_kv_transfer_group():
            return None

        connector = get_kv_transfer_group()
        # Return None for connectors that don't need to exchange handshake
        # metadata across workers.
        if (metadata := connector.get_handshake_metadata()) is None:
            return None

        tp_rank = get_tp_group().rank_in_group
        return {tp_rank: metadata}


def init_worker_distributed_environment(
    vllm_config: VllmConfig,
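For context, a minimal sketch of how the per-worker return values of get_kv_connector_handshake_metadata could be combined on the caller side. The merge_handshake_metadata helper and the example inputs are hypothetical illustrations, not part of this change or of vLLM's executor API; the only assumption taken from the diff is that each worker contributes either None or a single-entry dict keyed by its TP rank.

# Hypothetical aggregation sketch (not part of this PR): each worker returns
# either None or {tp_rank: metadata}; merging them yields one mapping from
# TP rank to that rank's KV connector handshake metadata.
from typing import Any


def merge_handshake_metadata(
        per_worker: list[dict[int, Any] | None]) -> dict[int, Any]:
    merged: dict[int, Any] = {}
    for entry in per_worker:
        if entry is None:
            # Worker has no KV transfer group, or its connector does not
            # exchange handshake metadata.
            continue
        merged.update(entry)
    return merged


# Example: three TP workers, where rank 1 has nothing to hand over.
print(merge_handshake_metadata([{0: "meta-rank0"}, None, {2: "meta-rank2"}]))
# -> {0: 'meta-rank0', 2: 'meta-rank2'}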