
Commit e61d2c2

LDLINGLINGLING, liudan, hmellor, and luccafong authored and committed
Eagle3 that supports the Minicpm3 model (vllm-project#24243)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: liudan <adan@minicpm.com>
Co-authored-by: liudan <liudan@qq.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
1 parent f27d60b commit e61d2c2

File tree

2 files changed (+44, -9 lines)


vllm/config/speculative.py

Lines changed: 1 addition & 1 deletion
@@ -540,7 +540,7 @@ def _verify_args(self) -> Self:
             "speculative decoding is > 1, but got "
             f"{self.disable_by_batch_size=}")

-        eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
+        eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
         if self.method == "eagle3" and self.target_model_config and not any(
                 supported_model in
                 self.target_model_config.hf_text_config.model_type
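As a hedged usage sketch (not part of this commit): with "minicpm" added to eagle3_target_supported, the check above should now accept an Eagle3 setup whose target model type is minicpm. The sketch assumes the speculative_config dict entry point of vllm.LLM; the draft-model path is a hypothetical placeholder.

# Hedged usage sketch; model names/paths below are illustrative placeholders.
from vllm import LLM

llm = LLM(
    model="openbmb/MiniCPM3-4B",          # target model (illustrative)
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3",
        "model": "path/to/minicpm3-eagle3-draft",  # hypothetical draft weights
        "num_speculative_tokens": 3,
    },
)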

vllm/model_executor/models/minicpm.py

Lines changed: 43 additions & 8 deletions
@@ -55,7 +55,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -381,6 +381,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self._init_layers(prefix, config, cache_config, quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], self.config.hidden_size))
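An aside on the default above: tuple[int, ...]() constructs an empty tuple, so no auxiliary hidden states are collected until set_aux_hidden_state_layers() is called further down.

# Aside, not part of the diff: the typed constructor is just an empty tuple,
# so aux-hidden-state capture is disabled by default.
assert tuple[int, ...]() == ()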
@@ -408,7 +411,8 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -419,18 +423,29 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(
+                    hidden_states +
+                    residual if residual is not None else hidden_states)
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states = self.norm(hidden_states)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states

     def load_weights(self, weights: Iterable[tuple[str,
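For intuition, here is a small standalone sketch (not vLLM code; the Linear stand-ins, layer count, and shapes are invented) of the capture pattern in the loop above: the running hidden state, with the residual folded in when present, is snapshotted before each selected layer runs.

# Standalone illustration of the aux-hidden-state capture pattern above.
import torch

layers = [torch.nn.Linear(16, 16) for _ in range(6)]
aux_hidden_state_layers = (2, 4)          # indices whose inputs are captured

hidden_states = torch.randn(1, 16)
aux_hidden_states = []
for idx, layer in enumerate(layers):
    if idx in aux_hidden_state_layers:
        aux_hidden_states.append(hidden_states)  # snapshot before the layer runs
    hidden_states = layer(hidden_states)

print(len(aux_hidden_states))  # 2 captured tensors for the Eagle3 draft head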
@@ -502,7 +517,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -568,16 +583,36 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds) / self.scale_width
-        return hidden_states
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+
+        if isinstance(model_output, tuple) and len(model_output) == 2:
+            # Aux hidden states are present.
+            hidden_states, aux_hidden_states = model_output
+            hidden_states = hidden_states / self.scale_width
+            return hidden_states, aux_hidden_states
+        else:
+            # Only hidden states or IntermediateTensors
+            if isinstance(model_output, IntermediateTensors):
+                return model_output
+            else:
+                hidden_states = model_output / self.scale_width
+                return hidden_states

     def compute_logits(
         self,
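The two methods added above are what the SupportsEagle3 interface expects: one lets the Eagle3 proposer choose which layers to tap, and the other supplies a default of an early, a middle, and a late layer. A quick worked example of that default, using an assumed depth rather than MiniCPM3's actual layer count:

# Illustrative only: 40 decoder layers is an assumed depth, not MiniCPM3's.
num_layers = 40
print((2, num_layers // 2, num_layers - 3))  # -> (2, 20, 37)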
