 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -381,6 +381,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self._init_layers(prefix, config, cache_config, quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
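+        # Indices of decoder layers whose inputs are also captured as
+        # auxiliary hidden states for Eagle3; empty until a drafter sets it.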
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], self.config.hidden_size))
@@ -408,7 +411,8 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
@@ -419,18 +423,29 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                islice(self.layers, self.start_layer, self.end_layer)):
+            if idx in self.aux_hidden_state_layers:
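+                # Re-add the pending residual so the recorded tensor is the
+                # true input to this layer, then stash it for the Eagle3 head.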
+                aux_hidden_states.append(
+                    hidden_states +
+                    residual if residual is not None else hidden_states)
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 residual,
             )
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states = self.norm(hidden_states)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states

     def load_weights(self, weights: Iterable[tuple[str,
@@ -502,7 +517,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -568,16 +583,36 @@ def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
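+    # Default taps an early, a middle, and a late layer, matching the
+    # convention used by other Eagle3-enabled models in vLLM.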
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds) / self.scale_width
-        return hidden_states
+    ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
+                                                        list[torch.Tensor]]]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+
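+        # Apply MiniCPM's width scaling (scale_width) on every path that
+        # returns hidden states; IntermediateTensors pass through unscaled.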
+        if isinstance(model_output, tuple) and len(model_output) == 2:
+            # Aux hidden states are present.
+            hidden_states, aux_hidden_states = model_output
+            hidden_states = hidden_states / self.scale_width
+            return hidden_states, aux_hidden_states
+        else:
+            # Only hidden states or IntermediateTensors
+            if isinstance(model_output, IntermediateTensors):
+                return model_output
+            else:
+                hidden_states = model_output / self.scale_width
+                return hidden_states

     def compute_logits(
         self,
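
For context, a minimal sketch of how these hooks are exercised end to end. During drafter setup vLLM calls `set_aux_hidden_state_layers(...)` with the layers returned by `get_eagle3_aux_hidden_state_layers()`; the `speculative_config` keys below follow vLLM's existing Eagle3 support, and the draft checkpoint path is hypothetical:

```python
from vllm import LLM, SamplingParams

# Hypothetical paths: any MiniCPM checkpoint paired with an Eagle3 draft
# head trained against its auxiliary hidden states.
llm = LLM(
    model="openbmb/MiniCPM-2B-sft-bf16",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3",                        # dispatches via SupportsEagle3
        "model": "/path/to/minicpm-eagle3-draft",  # hypothetical drafter
        "num_speculative_tokens": 3,
    },
)

out = llm.generate(["The capital of France is"],
                   SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```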