Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 159 additions & 15 deletions vllm_ascend/distributed/cpu_offload_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,29 @@
import torch
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MLAAttentionSpec)
MambaSpec, MLAAttentionSpec)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
from vllm.attention.layer import Attention
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
else:
from vllm.attention.layer import MLAAttention
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
Expand Down Expand Up @@ -432,24 +442,39 @@ def _save_listener(self):


# Copied from vllm_ascend/worker/model_runner_v1.py.
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context
def get_kv_cache_spec_v0110(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""

block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
ascend_config = get_ascend_config()
use_sfa = ascend_config.use_sfa
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, AscendMultiHeadLatentAttention):
continue
assert isinstance(attn_module, Attention)

# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if use_mla and not use_sfa:
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
Expand All @@ -458,14 +483,133 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype)

dtype=kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")

mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = vllm_config.model_config.max_model_len

page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)

# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0),
)

return kv_cache_spec


def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
if vllm_version_is("0.11.0"):
return get_kv_cache_spec_v0110(vllm_config)

block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")

elif isinstance(attn_module, MLAAttention):
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hack way to fix deepseek kvcache when
# using DSA. Fix the spec in vLLM is a finnal way.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)

mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = vllm_config.model_config.max_model_len

page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)

# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config else 0),
)

return kv_cache_spec
9 changes: 5 additions & 4 deletions vllm_ascend/distributed/cpu_offload_manager/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.utils import logger
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
Expand Down Expand Up @@ -145,14 +145,15 @@ def init_cpu_kv_caches(
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp
if layer.use_mla:
if use_mla:
tp_rank = 0
if (pp_rank, tp_rank) in self.shared_memory:
return self.shared_memory[(pp_rank, tp_rank)]
available_memory = self.available_memory
shared_memory_dict = {}
if layer.use_mla:
if use_mla:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
Expand All @@ -170,7 +171,7 @@ def init_cpu_kv_caches(
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
if layer.use_mla:
if use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
Expand Down
Loading