diff --git a/examples/lmcache/README.md b/examples/lmcache/README.md
new file mode 100644
index 000000000..a0c526221
--- /dev/null
+++ b/examples/lmcache/README.md
@@ -0,0 +1,69 @@
+# LMCache Examples
+Please note: the HPU integration for LMCache has not been upstreamed yet. Once it lands, the following test cases can be used.
+
+This folder demonstrates how to use LMCache for disaggregated prefill and KV cache sharing.
+
+The test scripts depend on the [vllm/benchmarks](https://github.com/vllm-project/vllm/tree/main/benchmarks) scripts.
+Please download them and set their path in the `disagg_example_gaudi_lm*.sh` scripts.
+
+## 1. Disaggregated Prefill in vLLM v1
+
+This example demonstrates how to run LMCache with disaggregated prefill using an LMCache (`lm`) or Redis remote server on a single node.
+
+### Prerequisites
+- At least 2 HPU cards
+- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct
+- LMCache with https://github.com/LMCache/LMCache/pull/1066 applied
+
+### Usage
+
+Change into the `disagg_prefill_lmcache_v1` folder and run
+
+```bash
+PT_HPU_GPU_MIGRATION=1 VLLM_USE_V1=1 VLLM_SKIP_WARMUP=True PT_HPU_ENABLE_LAZY_COLLECTIVES=true bash disagg_example_gaudi_lm.sh
+```
+
+to run disaggregated prefill and benchmark the performance.
+
+The LMCache server (`lm`) is the default remote backend; the backend, `tensor_parallel_size`, and model name are all configurable.
+
+For `tensor_parallel_size > 1`, run:
+```bash
+PT_HPU_GPU_MIGRATION=1 VLLM_USE_V1=1 VLLM_SKIP_WARMUP=True PT_HPU_ENABLE_LAZY_COLLECTIVES=true bash disagg_example_gaudi_lm_tp2.sh
+```
+
+### Components
+
+#### Server Scripts
+- `disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm.sh` - Launches the individual vLLM servers for the prefiller and decoder roles (the proxy is launched by the main script)
+- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between the prefiller and the decoder
+- `disagg_prefill_lmcache_v1/disagg_example_gaudi_lm.sh` - Main script to run the example through the lm/redis remote server
+
+#### Configuration
+- `disagg_prefill_lmcache_v1/configs/lmcache-config-lm.yaml` - Configuration for the prefiller/decoder servers when using the lm server
+
+#### Log Files
+The main script generates several log files:
+- `prefiller.log` - Logs from the prefill server
+- `decoder.log` - Logs from the decode server
+- `proxy.log` - Logs from the proxy server
+
+## 2. KV Cache Sharing
+
+The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
+
+### Usage
+
+```bash
+PT_HPU_GPU_MIGRATION=1 VLLM_USE_V1=1 VLLM_SKIP_WARMUP=True PT_HPU_ENABLE_LAZY_COLLECTIVES=true python kv_cache_sharing_lmcache_v1.py
+```
+
+The LMCache server (`lm`) is the default remote backend; the backend type and `tensor_parallel_size` can be changed via command-line flags.
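+For example, a sketch of using a local Redis server instead of the default LMCache server (this assumes a `redis-server` binary at the path configured in `run_redis_server`):
+
+```bash
+PT_HPU_GPU_MIGRATION=1 VLLM_USE_V1=1 VLLM_SKIP_WARMUP=True PT_HPU_ENABLE_LAZY_COLLECTIVES=true python kv_cache_sharing_lmcache_v1.py --remote_server redis --redis_port 6379
+```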
+ +For tp > 1 + +```bash +PT_HPU_GPU_MIGRATION=1 VLLM_USE_V1=1 VLLM_SKIP_WARMUP=True PT_HPU_ENABLE_LAZY_COLLECTIVES=true python kv_cache_sharing_lmcache_v1_tp2.py +``` diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-config-lm.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-config-lm.yaml new file mode 100644 index 000000000..1a16fa979 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-config-lm.yaml @@ -0,0 +1,6 @@ +local_cpu: False +max_local_cpu_size: 5.0 +#local_disk: +max_local_disk_size: 0 +remote_serde: naive +remote_url: "lm://localhost:8100" \ No newline at end of file diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml new file mode 100644 index 000000000..9ec6a41ce --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "RECEIVER" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "hpu" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml new file mode 100644 index 000000000..cdfc0e512 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml @@ -0,0 +1,13 @@ +local_cpu: False +max_local_cpu_size: 0 +#local_disk: +max_local_disk_size: 0 +remote_serde: NULL + +enable_nixl: True +nixl_role: "SENDER" +nixl_peer_host: "localhost" +nixl_peer_port: 55555 +nixl_buffer_size: 1073741824 # 1GB +nixl_buffer_device: "hpu" +nixl_enable_gc: True diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm.sh new file mode 100644 index 000000000..ed83ef452 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change." + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # can you check if the number of GPUs are >=2 via nvidia-smi? + num_gpus=$(hl-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 2 ]; then + echo "You need at least 2 GPUs to run disaggregated prefill." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python -c "import $1" > /dev/null 2>&1 + if [ $? -ne 0 ]; then + if [ "$1" == "nixl" ]; then + echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." + else + echo "$1 is not installed. Please install it via pip install $1." + fi + exit 1 + else + echo "$1 is installed." 
+ fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == “this whole process-group” + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + + echo "Waiting for server on port $port..." + + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server" + return 1 + fi + + sleep 1 + done +} + + +main() { + #check_hf_token + check_num_gpus + ensure_python_library_installed lmcache + #ensure_python_library_installed nixl + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching prefiller, decoder and proxy..." + echo "Please check prefiller.log, decoder.log and proxy.log for logs." + + echo "starting lmcache " + python -m lmcache.v1.server localhost 8100 2>&1 & + echo "start prefiller " + bash disagg_vllm_launcher_gaudi_lm.sh prefiller \ + > >(tee prefiller.log) 2>&1 & + prefiller_pid=$! + PIDS+=($prefiller_pid) + echo "start decoder " + bash disagg_vllm_launcher_gaudi_lm.sh decoder \ + > >(tee decoder.log) 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + + python3 disagg_proxy_server.py \ + --host localhost \ + --port 1000 \ + --prefiller-host localhost \ + --prefiller-port 1100 \ + --decoder-host localhost \ + --decoder-port 1200 \ + > >(tee proxy.log) 2>&1 & + proxy_pid=$! + PIDS+=($proxy_pid) + + wait_for_server 1100 + wait_for_server 1200 + wait_for_server 1000 + + echo "All servers are up. Starting benchmark..." + + # begin benchmark + cd ../../../benchmarks/ + MODEL="meta-llama/Llama-3.1-8B-Instruct" + python benchmark_serving.py --port 1000 --seed $(date +%s) \ + --model $MODEL \ + --dataset-name random --random-input-len 8000 --random-output-len 200 \ + --num-prompts 100 --burstiness 100 --request-rate 3.6 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup + +} + +main diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm_tp2.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm_tp2.sh new file mode 100644 index 000000000..1edec4909 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_example_gaudi_lm_tp2.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change." + + +PIDS=() + +# Switch to the directory of the current script +cd "$(dirname "${BASH_SOURCE[0]}")" + +check_hf_token() { + if [ -z "$HF_TOKEN" ]; then + echo "HF_TOKEN is not set. Please set it to your Hugging Face token." + exit 1 + fi + if [[ "$HF_TOKEN" != hf_* ]]; then + echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token." + exit 1 + fi + echo "HF_TOKEN is set and valid." +} + +check_num_gpus() { + # can you check if the number of GPUs are >=2 via nvidia-smi? + num_gpus=$(hl-smi --query-gpu=name --format=csv,noheader | wc -l) + if [ "$num_gpus" -lt 4 ]; then + echo "You need at least 4 GPUs to run disaggregated prefill TP2." + exit 1 + else + echo "Found $num_gpus GPUs." + fi +} + +ensure_python_library_installed() { + echo "Checking if $1 is installed..." + python -c "import $1" > /dev/null 2>&1 + if [ $? 
-ne 0 ]; then + if [ "$1" == "nixl" ]; then + echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." + else + echo "$1 is not installed. Please install it via pip install $1." + fi + exit 1 + else + echo "$1 is installed." + fi +} + +cleanup() { + echo "Stopping everything…" + trap - INT TERM # prevent re-entrancy + kill -- -$$ # negative PID == “this whole process-group” + wait # reap children so we don't leave zombies + exit 0 +} + +wait_for_server() { + local port=$1 + local timeout_seconds=1200 + local start_time=$(date +%s) + + echo "Waiting for server on port $port..." + + while true; do + if curl -s "localhost:${port}/v1/completions" > /dev/null; then + return 0 + fi + + local now=$(date +%s) + if (( now - start_time >= timeout_seconds )); then + echo "Timeout waiting for server" + return 1 + fi + + sleep 1 + done +} + + +main() { + #check_hf_token + check_num_gpus + ensure_python_library_installed lmcache + #ensure_python_library_installed nixl + ensure_python_library_installed pandas + ensure_python_library_installed datasets + ensure_python_library_installed vllm + + trap cleanup INT + trap cleanup USR1 + trap cleanup TERM + + echo "Launching prefiller, decoder and proxy..." + echo "Please check prefiller.log, decoder.log and proxy.log for logs." + + echo "starting lmcache " + python -m lmcache.v1.server localhost 8100 2>&1 & + echo "start prefiller " + bash disagg_vllm_launcher_gaudi_lm_tp2.sh prefiller \ + > >(tee prefiller.log) 2>&1 & + prefiller_pid=$! + PIDS+=($prefiller_pid) + echo "start decoder " + bash disagg_vllm_launcher_gaudi_lm_tp2.sh decoder \ + > >(tee decoder.log) 2>&1 & + decoder_pid=$! + PIDS+=($decoder_pid) + + python3 disagg_proxy_server.py \ + --host localhost \ + --port 1000 \ + --prefiller-host localhost \ + --prefiller-port 1100 \ + --decoder-host localhost \ + --decoder-port 1200 \ + > >(tee proxy.log) 2>&1 & + proxy_pid=$! + PIDS+=($proxy_pid) + + wait_for_server 1100 + wait_for_server 1200 + wait_for_server 1000 + + echo "All servers are up. Starting benchmark..." + + # begin benchmark + cd ../../../benchmarks/ + python benchmark_serving.py --port 1000 --seed $(date +%s) \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --dataset-name random --random-input-len 8000 --random-output-len 200 \ + --num-prompts 100 --burstiness 100 --request-rate 3.6 | tee benchmark.log + + echo "Benchmarking done. Cleaning up..." + + cleanup + +} + +main diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py new file mode 100644 index 000000000..8db93bc89 --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_proxy_server.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import time +from contextlib import asynccontextmanager + +import httpx +import numpy as np +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager to handle startup and shutdown events. 
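+    On startup, persistent httpx clients are created for the prefiller and decoder base URLs built from the command-line arguments; on shutdown, both clients are closed.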
+ """ + # Startup: Initialize clients + prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1' + decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1' + + app.state.prefill_client = httpx.AsyncClient(timeout=None, + base_url=prefiller_base_url) + app.state.decode_client = httpx.AsyncClient(timeout=None, + base_url=decoder_base_url) + + yield + + # Shutdown: Close clients + await app.state.prefill_client.aclose() + await app.state.decode_client.aclose() + + +# Update FastAPI app initialization to use lifespan +app = FastAPI(lifespan=lifespan) + + +class StatsCalculator: + + def __init__(self): + self._stats = [] + self._last_log_time = time.time() + + def add(self, value): + self._stats.append(value) + if time.time() - self._last_log_time > 5: + self._log_stats() + self._last_log_time = time.time() + + def _log_stats(self): + # Print average, median, and 99th percentile + np_arr = np.array(self._stats) + output_str = f"\nNum requests: {len(self._stats)}" + \ + "\nPrefill node TTFT stats:" + \ + f"\n - Average (ms): {np.mean(np_arr)}" + \ + f"\n - Median (ms): {np.median(np_arr)}" + \ + f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n" + print("===============================", output_str, + "===============================") + + +stats_calculator = StatsCalculator() +counter = 0 + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--prefiller-host", type=str, default="localhost") + parser.add_argument("--prefiller-port", type=int, default=8100) + parser.add_argument("--decoder-host", type=str, default="localhost") + parser.add_argument("--decoder-port", type=int, default=8200) + args = parser.parse_args() + return args + + +# Initialize variables to hold the persistent clients +app.state.prefill_client = None +app.state.decode_client = None + + +async def send_request_to_service(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Send a request to a service using a persistent client. + """ + req_data = req_data.copy() + req_data['max_tokens'] = 1 + if 'max_completion_tokens' in req_data: + req_data['max_completion_tokens'] = 1 + + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + response = await client.post(endpoint, json=req_data, headers=headers) + response.raise_for_status() + return response + + +async def stream_service_response(client: httpx.AsyncClient, endpoint: str, + req_data: dict): + """ + Asynchronously stream the response from a service using a persistent client. 
+ """ + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with client.stream("POST", endpoint, json=req_data, + headers=headers) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + +@app.post("/v1/completions") +async def handle_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, "/completions", + req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server" + " - completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +@app.post("/v1/chat/completions") +async def handle_chat_completions(request: Request): + global counter, stats_calculator + counter += 1 + + st = time.time() + try: + req_data = await request.json() + + # Send request to prefill service, ignore the response + await send_request_to_service(app.state.prefill_client, + "/chat/completions", req_data) + + et = time.time() + stats_calculator.add(et - st) + + # Stream response from decode service + async def generate_stream(): + async for chunk in stream_service_response(app.state.decode_client, + "/chat/completions", + req_data): + yield chunk + + return StreamingResponse(generate_stream(), + media_type="application/json") + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server " + " - chat completions endpoint") + print(e) + print("".join(traceback.format_exception(*exc_info))) + raise + + +if __name__ == '__main__': + global global_args + global_args = parse_args() + + import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm.sh new file mode 100644 index 000000000..6f6cec5ab --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml + + #UCX_TLS=tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + LMCACHE_REMOTE_SERDE=naive \ + RANK=0 \ + DECODER_RANK=1 \ + vllm serve $MODEL \ + --port 1100 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + +elif [[ $1 == 
"decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml + + #UCX_TLS=tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + LMCACHE_CHUNK_SIZE=256 \ + RANK=1 \ + DECODER_RANK=1 \ + vllm serve $MODEL \ + --port 1200 \ + --gpu_memory_utilization 0.80 \ + --disable-log-requests \ + --enforce-eager \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode" + exit 1 +fi diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm_tp2.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm_tp2.sh new file mode 100644 index 000000000..cf252d32c --- /dev/null +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher_gaudi_lm_tp2.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 [model]" + exit 1 +fi + +if [[ $# -eq 1 ]]; then + echo "Using default model: meta-llama/Llama-3.1-8B-Instruct" + MODEL="meta-llama/Llama-3.1-8B-Instruct" +else + echo "Using model: $2" + MODEL=$2 +fi + + +if [[ $1 == "prefiller" ]]; then + # Prefiller listens on port 8100 + prefill_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml + + UCX_TLS=tcp \ + LMCACHE_CONFIG_FILE=$prefill_config_file \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + LMCACHE_REMOTE_SERDE=naive \ + LMCACHE_CHUNK_SIZE=256 \ + vllm serve $MODEL \ + --port 1100 \ + --gpu_memory_utilization 0.8 \ + --disable-log-requests \ + --tensor_parallel_size 2 \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}' + + +elif [[ $1 == "decoder" ]]; then + # Decoder listens on port 8200 + decode_config_file=$SCRIPT_DIR/configs/lmcache-config-lm.yaml + + UCX_TLS=tcp \ + LMCACHE_CONFIG_FILE=$decode_config_file \ + VLLM_ENABLE_V1_MULTIPROCESSING=1 \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + LMCACHE_REMOTE_SERDE=naive \ + LMCACHE_CHUNK_SIZE=256 \ + vllm serve $MODEL \ + --port 1200 \ + --gpu_memory_utilization 0.8 \ + --disable-log-requests \ + --tensor_parallel_size 2 \ + --kv-transfer-config \ + '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}' + + +else + echo "Invalid role: $1" + echo "Should be either prefill, decode" + exit 1 +fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py new file mode 100644 index 000000000..f89df0bff --- /dev/null +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of remote KV cache sharing +with LMCache. +We will launch 2 vllm instances, and launch an additional LMCache server. +KV cache is transferred in the following manner: +(1) vLLM instance 1 -> LMCache server (KV cache store). +(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). +Note that lmcache needs to be installed to run this example. +Learn more about LMCache in https://github.com/LMCache/LMCache. 
+""" + +import argparse +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.integration.vllm.utils import ENGINE_NAME +from lmcache.v1.cache_engine import LMCacheEngineBuilder + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server + +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" +# GAUDI-NIC + +MODEL = "mistralai/Mistral-7B-Instruct-v0.2" +# prompts = [ +# "Hello, how are you?" * 1000, +# ] +prompts = [ + "San Francisco is a", +] + + +def run_store(store_done, prompts, tp_size): + # We use GPU 0 for KV cache store process. + os.environ["RANK"] = "0" + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig( + kv_connector="LMCacheConnectorV1", kv_role="kv_producer" + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM( + model=MODEL, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + tensor_parallel_size=tp_size, + enforce_eager=True, + ) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Producer Generated text: {generated_text!r}") + print("KV cache store is finished.") + store_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_retrieve(store_done, prompts, tp_size, timeout=1): + # We use GPU 1 for KV cache retrieve process. + decoder_rank = "1" + os.environ["RANK"] = decoder_rank + + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=20) + ktc = KVTransferConfig( + kv_connector="LMCacheConnectorV1", kv_role="kv_consumer" + ) + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. 
+ llm = LLM( + model=MODEL, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + tensor_parallel_size=tp_size, + enforce_eager=True, + ) + + print("Waiting for KV cache store to finish...") + store_done.wait() + time.sleep(timeout) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Consumer Generated text: {generated_text!r}") + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_lmcache_server(port): + os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" + server_proc = subprocess.Popen( + ["python", "-m", "lmcache.v1.server", "localhost", str(port)] + ) + return server_proc + + +def run_redis_server(port): + os.environ["LMCACHE_REMOTE_URL"] = f"redis://localhost:{port}" + redis_server_path = ( + "/usr/bin/redis-server" # Update this to the correct path + ) + + try: + # Start the Redis server + process = subprocess.Popen( + [redis_server_path, "--port", str(port)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + print("Redis server started successfully!") + print(f"Process ID: {process.pid}") + except FileNotFoundError: + print( + "Error: Redis server executable not found. \ + Please check the path." + ) + except Exception as e: + print(f"An error occurred: {e}") + return process + + +def main(): + args = parse_args() + print(args) + + store_done = Event() + store_process = Process( + target=run_store, args=(store_done, prompts, args.tp_size) + ) + retrieve_process = Process( + target=run_retrieve, args=(store_done, prompts, args.tp_size) + ) + if args.remote_server == "lm": + remote_server_process = run_lmcache_server(args.lm_port) + elif args.remote_server == "redis": + remote_server_process = run_redis_server(args.redis_port) + else: + print("Not supported lmcache server type") + exit() + print("kvshare store start") + # Start KV cache store process + store_process.start() + + print("kvshare retrieve start") + # Start KV cache retrieve process + retrieve_process.start() + print("kvshare retrieve done") + store_process.join() + retrieve_process.join() + # Clean up the processes + retrieve_process.terminate() + remote_server_process.terminate() + remote_server_process.wait() + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--remote_server", + type=str, + default="lm", + help="remote lmcache server type. 'lm' or 'redis'", + ) + parser.add_argument( + "--lm_port", type=int, default=8100, help="lm server port" + ) + parser.add_argument( + "--redis_port", type=int, default=6379, help="redis server port" + ) + parser.add_argument( + "--tp_size", type=int, default=1, help="tensor parallel size" + ) + + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1_tp2.py b/examples/lmcache/kv_cache_sharing_lmcache_v1_tp2.py new file mode 100644 index 000000000..d1d50e5c4 --- /dev/null +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1_tp2.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This file demonstrates the example usage of remote KV cache sharing +with LMCache. +We will launch 2 vllm instances, and launch an additional LMCache server. +KV cache is transferred in the following manner: +(1) vLLM instance 1 -> LMCache server (KV cache store). +(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve). + +Note that lmcache needs to be installed to run this example. 
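+Both instances run with tensor_parallel_size=2, so four HPU cards in total are typically needed (two per instance).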
+Learn more about LMCache in https://github.com/LMCache/LMCache. +""" +import os +import subprocess +import time +from multiprocessing import Event, Process + +from lmcache.integration.vllm.utils import ENGINE_NAME +from lmcache.v1.cache_engine import LMCacheEngineBuilder + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + +# LMCache-related environment variables +# The port to start LMCache server +port = 8100 +# LMCache is set to use 256 tokens per chunk +os.environ["LMCACHE_CHUNK_SIZE"] = "256" +# Disable local CPU backend in LMCache +os.environ["LMCACHE_LOCAL_CPU"] = "False" +# Set local CPU memory buffer limit to 5.0 GB +os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" +# Set the remote URL for LMCache server +os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" +# Set the serializer/deserializer between vllm and LMCache server +# `naive` indicates using raw bytes of the tensor without any compression +os.environ["LMCACHE_REMOTE_SERDE"] = "naive" +MODEL = "meta-llama/Llama-3.2-1B-Instruct" +#prompts = [ +# "Hello, how are you?" * 1000, +#] +prompts = [ + "San Francisco is a", +] + + +def run_store(store_done, prompts): + # We use GPU 0 for KV cache store process. + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) + + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_producer") + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # memory. Reduce the value if your GPU has less memory. + llm = LLM(model=MODEL, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + tensor_parallel_size=2, + enforce_eager=False) + + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Producer Generated text: {generated_text!r}") + print("KV cache store is finished.") + store_done.set() + + # Clean up lmcache backend + LMCacheEngineBuilder.destroy(ENGINE_NAME) + + +def run_retrieve(store_done, prompts, timeout=1): + # We use GPU 1 for KV cache retrieve process. + sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=20) + # sampling_params = SamplingParams(temperature=0, max_tokens=100) + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", kv_role="kv_consumer") + # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB + # of memory. Reduce the value if your GPU has less memory. 
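+    # Note: the consumer uses the same model, tensor_parallel_size (2), and LMCACHE_CHUNK_SIZE
+    # as the producer so that the stored KV cache layout matches on retrieval.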
+    llm = LLM(model=MODEL,
+              kv_transfer_config=ktc,
+              max_model_len=8000,
+              gpu_memory_utilization=0.8,
+              tensor_parallel_size=2,
+              enforce_eager=False)
+
+    print("Waiting for KV cache store to finish...")
+    store_done.wait()
+    time.sleep(timeout)
+
+    outputs = llm.generate(prompts, sampling_params)
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print(f"Consumer Generated text: {generated_text!r}")
+
+    # Clean up lmcache backend
+    LMCacheEngineBuilder.destroy(ENGINE_NAME)
+
+
+def run_lmcache_server(port):
+    server_proc = subprocess.Popen(
+        ["python", "-m", "lmcache.v1.server", "localhost",
+         str(port)])
+    return server_proc
+
+
+def main():
+    store_done = Event()
+    store_process = Process(target=run_store, args=(store_done, prompts))
+    retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
+
+    lmcache_server_process = run_lmcache_server(port)
+    print("kvshare store start")
+    # Start KV cache store process
+    store_process.start()
+
+    print("kvshare retrieve start")
+    # Start KV cache retrieve process
+    retrieve_process.start()
+    print("kvshare retrieve done")
+    store_process.join()
+    retrieve_process.join()
+    # Clean up the processes
+    retrieve_process.terminate()
+    lmcache_server_process.terminate()
+    lmcache_server_process.wait()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 8c16cdda8..fbc9fe0d4 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -84,8 +84,12 @@ from vllm_gaudi.extension.ops import LoraMask as LoraMask
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
+from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import NixlConnectorMetadata
 from vllm.v1.core.sched.output import GrammarOutput
+from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorMetadata
+
+
 if TYPE_CHECKING:
     import xgrammar as xgr
     import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
@@ -1446,12 +1450,15 @@ def _get_prompts_and_decodes(
         requests_type = {}
         if scheduler_output.kv_connector_metadata:
-            for req in scheduler_output.kv_connector_metadata.reqs_to_save:
-                requests_type[req] = 'prefill'
-            for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
-                requests_type[req] = 'decode'
-            requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
-                scheduler_output.kv_connector_metadata.reqs_to_recv
+            if isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata):
+                for req in scheduler_output.kv_connector_metadata.reqs_to_save:
+                    requests_type[req] = 'prefill'
+                for req in scheduler_output.kv_connector_metadata.reqs_to_recv:
+                    requests_type[req] = 'decode'
+                requests = scheduler_output.kv_connector_metadata.reqs_to_save | \
+                    scheduler_output.kv_connector_metadata.reqs_to_recv
+            else:
+                requests = scheduler_output.kv_connector_metadata.requests
         else:
             requests = None
@@ -3144,7 +3151,11 @@ def sample_tokens(self, grammar_output: "GrammarOutput | None") -> ModelRunnerOu
                                          prompt_batch_idx=idx,
                                          is_prompt=True)
             self.profiler.record_counter(self.event_start, counters)
-        if not warmup_mode:
+
+        if not warmup_mode and \
+            (isinstance(scheduler_output.kv_connector_metadata, NixlConnectorMetadata) or \
+            isinstance(scheduler_output.kv_connector_metadata, LMCacheConnectorMetadata)):
+            logger.debug("calling maybe_wait_for_kv_save")
             self.maybe_wait_for_kv_save()
 
         if self.is_driver_worker and self.profiler.enabled:
diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py
index bc288826a..610b1f8af 100644
--- a/vllm_gaudi/v1/worker/hpu_worker.py
+++ b/vllm_gaudi/v1/worker/hpu_worker.py
@@ -74,6 +74,7 @@ def __init__(
         self.local_rank = local_rank
         self.rank = rank
+        self.parallel_config.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker