From 506ae6a713dc955eabd7eab63f8115ec1c0418ef Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 8 Nov 2025 15:51:49 +0800 Subject: [PATCH 01/22] support megatron MTP --- .../Megatron-SWIFT/Command-line-parameters.md | 5 + .../Megatron-SWIFT/Command-line-parameters.md | 5 + swift/megatron/argument/megatron_args.py | 4 + swift/megatron/model/gpt_model.py | 187 ++++++++++++++---- 4 files changed, 168 insertions(+), 33 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 9a1c90658c..fea09478e7 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -214,6 +214,11 @@ - qk_head_dim: QK 投影中 head 的维度。 `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`。默认为None,自动从config.json读取。 - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 +**MTP参数** +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. +- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. + + **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 - 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参数训练和LoRA训练,但会产生不同的效果。若是全参数训练,将freeze_llm设置为True会将LLM部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在LLM部分添加LoRA模块。该参数默认为False。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 7062d13d23..ba7067adf8 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -227,6 +227,11 @@ For guidance on selecting parallelization strategies, please refer to the [Train - qk_head_dim: Dimension of the head in the QK projection. `q_head_dim = qk_head_dim + qk_pos_emb_head_dim`. Default is None and will be automatically read from config.json. - qk_pos_emb_head_dim: Dimension of the position embedding in the QK projection. Default is None and will be automatically read from config.json. + +**MTP Parameters** +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 +- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 + **Tuner Parameters**: - train_type: Options are `'lora'` and `'full'`. Default is `'full'`. diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index d2b3c96217..f6611b02b9 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -313,6 +313,10 @@ class MegatronArguments(ExtraMegatronArguments): qk_head_dim: Optional[int] = None qk_pos_emb_head_dim: Optional[int] = None + # mtp + mtp_num_layers: Optional[int] = None + mtp_loss_scaling_factor: float = 0.1 + # fp8 fp8_format: Literal['e4m3', 'hybrid'] = None fp8_recipe: Literal['tensorwise', 'delayed', 'mxfp8', 'blockwise'] = 'delayed' diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b3aa2b4f8d..ea07cdbba4 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -4,16 +4,17 @@ from typing import Any, Dict, Literal, Optional, Tuple import torch -from megatron.core import InferenceParams from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TELinear +from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel as McoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.training import get_args from swift.utils import get_logger @@ -145,30 +146,20 @@ def apply_rotary_pos_emb(*args, **kwargs): finally: attention.apply_rotary_pos_emb = origin_apply_rotary_pos_emb - # Code borrowed from NVIDIA/Megatron-LM - def forward( + def _preprocess( self, input_ids: torch.Tensor, position_ids: torch.Tensor, - attention_mask: torch.Tensor = None, decoder_input: torch.Tensor = None, - labels: torch.Tensor = None, - inference_params: InferenceParams = None, + inference_context: BaseInferenceContext = None, packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - **kwargs, - ) -> torch.Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units + ): + """Preprocesses inputs for the transformer decoder. - Args: - runtime_gather_output (bool): Gather output at runtime. Default None means - `parallel_output` arg in the constructor will be used. + Applies embeddings to input tokens, or uses `decoder_input` from a previous + pipeline stage. Also sets up rotary positional embeddings. """ + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. @@ -185,20 +176,23 @@ def forward( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None rotary_pos_sin = None if self.position_embedding_type in {'rope', 'mrope'}: - if not self.training and self.config.flash_decode and inference_params: + if not self.training and self.config.flash_decode and inference_context: + assert (inference_context.is_static_batching() + ), 'GPTModel currently only supports static inference batching.' # Flash decoding uses precomputed cos and sin for RoPE rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( - inference_params.max_sequence_length, - self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), + inference_context.max_sequence_length, + self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), ) else: - rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_params, self.decoder, decoder_input, - self.config, packed_seq_params) + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, + decoder_input, self.config, packed_seq_params) if self.hf_rope_scaling is not None: attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) if attention_scaling is not None: @@ -214,20 +208,63 @@ def forward( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None - and inference_params): + and inference_context and inference_context.is_static_batching() and not self.training): + current_batch_size = input_ids.shape[0] sequence_len_offset = torch.tensor( - [inference_params.sequence_len_offset] * inference_params.current_batch_size, + [inference_context.sequence_len_offset] * current_batch_size, dtype=torch.int32, device=rotary_pos_cos.device, # Co-locate this with the rotary tensors ) else: sequence_len_offset = None - # Run decoder. - with self._patch_apply_rotary_pos_emb(): - hidden_states = self.decoder( - hidden_states=decoder_input, + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if (inference_context is not None and not self.training and not has_config_logger_enabled(self.config)): + decoder_input = WrappedTensor(decoder_input) + + return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. + """ + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, @@ -235,17 +272,28 @@ def forward( rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + runtime_gather_output=runtime_gather_output, + compute_language_model_loss=self.compute_language_model_loss, **(extra_block_kwargs or {}), - **kwargs, ) if not self.post_process: return hidden_states + if (not self.training and inference_context is not None + and inference_context.materialize_only_last_token_logits): + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden + # state ([B, H]) → unsqueeze back to [1, B, H] + # (so that the output layer, which expects S×B×H, receives only the final token) + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() args = get_args() if args.task_type == 'causal_lm': logits, _ = self.output_layer( @@ -271,5 +319,78 @@ def forward( return loss + # Code borrowed from NVIDIA/Megatron-LM + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + decoder_input: torch.Tensor = None, + labels: torch.Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + )) + # Run decoder. + with self._patch_apply_rotary_pos_emb(): + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + **kwargs, + ) + + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + ) + def get_input_tensor(self): return self.decoder.input_tensor From 6738b66561a0ee177c384c0187438404005e57f6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 10 Nov 2025 21:59:15 +0800 Subject: [PATCH 02/22] update --- swift/megatron/model/gpt_model.py | 79 +------------------------------ 1 file changed, 1 insertion(+), 78 deletions(-) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 9bc438413d..0aaa563277 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -288,12 +288,8 @@ def forward( rotary_pos_sin=rotary_pos_sin, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, - embedding=self.embedding, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - compute_language_model_loss=self.compute_language_model_loss, **(extra_block_kwargs or {}), + **kwargs, ) args = get_args() @@ -317,78 +313,5 @@ def forward( inference_context=inference_context, ) - # Code borrowed from NVIDIA/Megatron-LM - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - decoder_input: torch.Tensor = None, - labels: torch.Tensor = None, - inference_context: BaseInferenceContext = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - - Args: - runtime_gather_output (bool): Gather output at runtime. Default None means - `parallel_output` arg in the constructor will be used. - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( - self._preprocess( - input_ids=input_ids, - position_ids=position_ids, - decoder_input=decoder_input, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - )) - # Run decoder. - with self._patch_apply_rotary_pos_emb(): - hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - **(extra_block_kwargs or {}), - **kwargs, - ) - - return self._postprocess( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - ) - def get_input_tensor(self): return self.decoder.input_tensor From b460d2847f4ef7464eeaf0fb6189c58a9c755a4c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 18 Nov 2025 16:42:25 +0800 Subject: [PATCH 03/22] update --- .../Megatron-SWIFT/Command-line-parameters.md | 5 +-- .../Megatron-SWIFT/Command-line-parameters.md | 4 +- swift/megatron/model/gpt_bridge.py | 44 ++++++++++++++++++- swift/megatron/trainers/base.py | 4 +- 4 files changed, 50 insertions(+), 7 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 25ddd68fbc..28e6c0eeca 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -219,9 +219,8 @@ - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 **MTP参数** -- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. -- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. - +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 +- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 84bc68050b..2a8674b0ee 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -233,8 +233,8 @@ For guidance on selecting parallelization strategies, please refer to the [Train **MTP Parameters** -- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 -- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. +- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. **Tuner Parameters**: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 40aff04568..48c5f8e2ab 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -26,6 +26,7 @@ # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: hf_layers_prefix = 'model.layers' + hf_mtp_prefix = 'model.layers' hf_embed_key = 'model.embed_tokens.weight' hf_final_layernorm_key = 'model.norm.weight' hf_lm_head_key = 'lm_head.weight' @@ -79,7 +80,9 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: # mla 'linear_q_proj', 'linear_q_up_proj', - 'linear_kv_up_proj' + 'linear_kv_up_proj', + # mtp + 'eh_proj', } if self.args.task_type == 'causal_lm': dim0_keys.add('output_layer') @@ -1018,6 +1021,23 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} + + if not to_mcore or is_pp_last_stage and self.args.mtp_num_layers: + layer_idx = 0 + while layer_idx < self.args.mtp_num_layers: + mtp_layer = mg_model.mtp.layers[layer_idx] if hasattr(mg_model, 'mtp') else None + if self.hf_mtp_prefix == self.hf_layers_prefix: + hf_layer_idx = layer_idx + self.args.num_layers + else: + hf_layer_idx = layer_idx + res = self._convert_mtp_layer(mtp_layer, hf_state_dict, f'{self.hf_mtp_prefix}.', hf_layer_idx, + to_mcore) + layer_idx += 1 + if to_mcore: + yield + else: + yield from list(self._add_prefix(res, hf_prefix).items()) + hf_state_dict = {} if not to_mcore or is_pp_last_stage: hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: @@ -1026,6 +1046,28 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + def _convert_mtp_layer(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + if not to_mcore: + # TODO: 'embed_tokens.weight', 'shared_head.head.weight' + pass + for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: + self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) + if layer_idx >= len(self.hf_layers): + layer_idx = -1 + hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False, adapter_name: str = 'default'): self._is_peft_format = is_peft_format self._adapter_name = adapter_name diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 999745acbf..4eebeb2eeb 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -785,6 +785,7 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear track_names.append('load_balancing_loss') if args.moe_z_loss_coeff is not None: track_names.append('z_loss') + track_moe_kwargs = {'mtp_num_layers': args.mtp_num_layers} if self.mcore_013 else {} track_moe_metrics( loss_scale=moe_loss_scale, iteration=iteration, @@ -795,7 +796,8 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear force_initialize=True, track_names=track_names, num_layers=args.num_layers, - moe_layer_freq=args.moe_layer_freq) + moe_layer_freq=args.moe_layer_freq, + **track_moe_kwargs) if args.mtp_num_layers is not None: mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) From cfa8057adcec52bb825685594b4a8f8d756966fe Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 10:56:27 +0800 Subject: [PATCH 04/22] update --- swift/megatron/model/gpt_bridge.py | 30 +++++++++++++++--------------- swift/megatron/model/gpt_model.py | 1 + swift/megatron/trainers/utils.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 48c5f8e2ab..b789268c91 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1023,14 +1023,10 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = {} if not to_mcore or is_pp_last_stage and self.args.mtp_num_layers: + lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model layer_idx = 0 while layer_idx < self.args.mtp_num_layers: - mtp_layer = mg_model.mtp.layers[layer_idx] if hasattr(mg_model, 'mtp') else None - if self.hf_mtp_prefix == self.hf_layers_prefix: - hf_layer_idx = layer_idx + self.args.num_layers - else: - hf_layer_idx = layer_idx - res = self._convert_mtp_layer(mtp_layer, hf_state_dict, f'{self.hf_mtp_prefix}.', hf_layer_idx, + res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) layer_idx += 1 if to_mcore: @@ -1046,21 +1042,25 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) - def _convert_mtp_layer(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): - hf_prefix = f'{hf_prefix}{layer_idx}.' + def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None + if self.hf_mtp_prefix == self.hf_layers_prefix: + hf_layer_idx = layer_idx + self.args.num_layers + else: + hf_layer_idx = layer_idx + hf_prefix = f'{hf_prefix}{hf_layer_idx}.' if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - if not to_mcore: - # TODO: 'embed_tokens.weight', 'shared_head.head.weight' - pass + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) + self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) - if layer_idx >= len(self.hf_layers): - layer_idx = -1 - hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, layer_idx, to_mcore)) + if hf_layer_idx >= len(self.hf_layers): + hf_layer_idx = -1 + hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) if to_mcore: hf_state_dict = {} diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b529a73337..19d14a7b47 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -305,6 +305,7 @@ def forward( args = get_args() labels = labels if args.task_type == 'causal_lm' else None if mcore_013: + # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 return self._postprocess( hidden_states=hidden_states, input_ids=input_ids, diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 594561cdd8..88586edb1e 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -57,7 +57,7 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): else: is_pp_first_stage = mpu.is_pipeline_first_stage() is_pp_last_stage = mpu.is_pipeline_last_stage() - if not is_pp_first_stage: + if not args.mtp_num_layers and not is_pp_first_stage: batch['input_ids'] = None if not is_pp_last_stage: batch['labels'] = None From 2bfeed50bfb3495916ca3f3b8f6075c1ec8d6c9f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 11:05:14 +0800 Subject: [PATCH 05/22] update --- swift/megatron/model/gpt_bridge.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 63a59430e7..63bd030acb 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1042,8 +1042,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model layer_idx = 0 while layer_idx < self.args.mtp_num_layers: - res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, - to_mcore) + res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore) layer_idx += 1 if to_mcore: yield @@ -1069,15 +1068,16 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', to_mcore) + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', + to_mcore) self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: - self._set_state_dict(mg_layer, key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) if hf_layer_idx >= len(self.hf_layers): hf_layer_idx = -1 - hf_state_dict.update(self._set_layer_attn(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mg_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - self._set_state_dict(mg_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) if to_mcore: hf_state_dict = {} else: From 4d60be4b8aea5e29e61b5d78fa3f410ca6afe1be Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 15:15:56 +0800 Subject: [PATCH 06/22] update --- swift/llm/argument/infer_args.py | 10 ++ swift/llm/infer/infer_engine/infer_engine.py | 3 +- swift/llm/infer/infer_engine/sglang_engine.py | 9 ++ swift/megatron/init.py | 97 +++++++++++++++++++ 4 files changed, 118 insertions(+), 1 deletion(-) diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index cbf746d4dc..e0cce69ffb 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -61,6 +61,12 @@ class SglangArguments: sglang_kv_cache_dtype: str = 'auto' sglang_enable_dp_attention: bool = False sglang_disable_custom_all_reduce: bool = True + # speculative decoding + # e.g. EAGLE, EAGLE3, NEXTN + sglang_speculative_algorithm: Optional[str] = None + sglang_speculative_num_steps: Optional[int] = None + sglang_speculative_eagle_topk: Optional[int] = None + sglang_speculative_num_draft_tokens: Optional[int] = None def get_sglang_engine_kwargs(self): kwargs = { @@ -76,6 +82,10 @@ def get_sglang_engine_kwargs(self): 'kv_cache_dtype': self.sglang_kv_cache_dtype, 'enable_dp_attention': self.sglang_enable_dp_attention, 'disable_custom_all_reduce': self.sglang_disable_custom_all_reduce, + 'speculative_algorithm': self.sglang_speculative_algorithm, + 'speculative_num_steps': self.sglang_speculative_num_steps, + 'speculative_eagle_topk': self.sglang_speculative_eagle_topk, + 'speculative_num_draft_tokens': self.sglang_speculative_num_draft_tokens, } if self.task_type == 'embedding': kwargs['task_type'] = 'embedding' diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 4e1c903d17..86c9583e40 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -32,6 +32,7 @@ def _post_init(self, template=None): self.max_model_len = self.model_info.max_model_len self.task_type = self.model_info.task_type self.config = self.model_info.config + self.max_tokens_offset = 0 if template is None: ckpt_dir = get_ckpt_dir(self.model_dir, getattr(self, 'adapters', None)) logger.info('Create the default_template for the infer_engine') @@ -220,7 +221,7 @@ def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str max_model_len = 8192 logger.warning( 'The current model is unable to retrieve `max_model_len`. It is set to the default value of 8192.') - max_max_tokens = max_model_len - num_tokens + max_max_tokens = max_model_len - num_tokens + self.max_tokens_offset if max_tokens is None: request_config.max_tokens = max_max_tokens elif max_max_tokens < request_config.max_tokens: diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index dd78d0b651..13049d0ada 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -48,6 +48,10 @@ def __init__( kv_cache_dtype: str = 'auto', enable_dp_attention: bool = False, disable_custom_all_reduce: bool = True, + speculative_algorithm: Optional[str] = None, + speculative_num_steps: Optional[int] = None, + speculative_eagle_topk: Optional[int] = None, + speculative_num_draft_tokens: Optional[int] = None, log_level='error', engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, @@ -88,6 +92,10 @@ def __init__( kv_cache_dtype=kv_cache_dtype, enable_dp_attention=enable_dp_attention, disable_custom_all_reduce=disable_custom_all_reduce, + speculative_algorithm=speculative_algorithm, + speculative_num_steps=speculative_num_steps, + speculative_eagle_topk=speculative_eagle_topk, + speculative_num_draft_tokens=speculative_num_draft_tokens, log_level=log_level, skip_tokenizer_init=True, trust_remote_code=True, @@ -98,6 +106,7 @@ def __init__( self.server_args.is_embedding = True self.engine = sgl.Engine(server_args=self.server_args) self._load_generation_config() + self.max_tokens_offset = -speculative_num_draft_tokens def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d2b7b7cb95..588f3066ab 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -388,6 +388,102 @@ def build_tokenizer(args): global_vars.build_tokenizer = build_tokenizer +def _patch_mtp(): + from megatron.core import InferenceParams + from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionLayer + from megatron.core.packed_seq_params import PackedSeqParams + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor = None, + context_mask: torch.Tensor = None, + rotary_pos_emb: torch.Tensor = None, + rotary_pos_cos: torch.Tensor = None, + rotary_pos_sin: torch.Tensor = None, + attention_bias: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + sequence_len_offset: torch.Tensor = None, + embedding=None, + ): + """ + Execute the forward pass through the Multi-Token Prediction (MTP) layer. + + Args: + input_ids (Tensor): Input token IDs . + position_ids (Tensor): Positional IDs of the input tokens. + hidden_states (Tensor): Hidden states tensor of shape [s, b, h] where s is the + sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention, if applicable. + context_mask (Tensor, optional): Mask for cross-attention context, if applicable. + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Tensor, optional): Cosine component of rotary positional embeddings. + rotary_pos_sin (Tensor, optional): Sine component of rotary positional embeddings. + sequence_len_offset (Tensor, optional): Offset for sequence length, if applicable. + embedding (Callable): The embedding module from gpt model to compute the decoder input. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + # TODO: Multimodal compatible; MTP initialization + # TODO: packed_seq_params offset + assert context is None, f'multi token prediction + cross attention is not yet supported.' + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, + ) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + apply_rope_fusion = self.config.apply_rope_fusion + self.config.apply_rope_fusion = False + if packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if self.config.recompute_granularity == 'full' and self.training: + hidden_states = self._checkpointed_forward( + self._proj_and_transformer_layer, + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + else: + hidden_states = self._proj_and_transformer_layer( + hidden_states=hidden_states, + decoder_input=decoder_input, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + self.config.apply_rope_fusion = apply_rope_fusion + return hidden_states, input_ids, position_ids + + MultiTokenPredictionLayer.forward = forward + + def _patch_peft_ModulesToSaveWrapper(): if version.parse(peft.__version__) >= version.parse('0.16'): from peft.utils import other as peft_module @@ -686,6 +782,7 @@ def _patch_megatron(): _patch_build_train_valid_test_datasets() _patch_mrope() _patch_megatron_tokenizer() + _patch_mtp() logging.root.setLevel(logging_level) # revert logger level from swift.megatron import tuners # patch lora try: From 8d4fe50cd4ec56f3cbd7e6ca98490797b18ab049 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 17:17:16 +0800 Subject: [PATCH 07/22] update --- swift/megatron/init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 588f3066ab..d2797628b2 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -434,7 +434,7 @@ def forward( """ # TODO: Multimodal compatible; MTP initialization # TODO: packed_seq_params offset - assert context is None, f'multi token prediction + cross attention is not yet supported.' + assert context is None, 'multi token prediction + cross attention is not yet supported.' input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( input_ids=input_ids, position_ids=position_ids, From ed5646c5e767c62a553bb6ea7cb2e632c9d3800c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 19 Nov 2025 17:29:04 +0800 Subject: [PATCH 08/22] update --- docs/source/Instruction/Supported-models-and-datasets.md | 4 +++- docs/source_en/Instruction/Supported-models-and-datasets.md | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md index 1b4e35058f..70cb341543 100644 --- a/docs/source/Instruction/Supported-models-and-datasets.md +++ b/docs/source/Instruction/Supported-models-and-datasets.md @@ -1133,6 +1133,7 @@ |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1160,6 +1161,7 @@ |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1189,7 +1191,7 @@ |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 808cec5a85..6d5059dd93 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -1134,6 +1134,7 @@ The table below introduces information about the datasets integrated with ms-swi |-|default|huge dataset|-|pretrain, quality|[allenai/c4](https://huggingface.co/datasets/allenai/c4)| |[bespokelabs/Bespoke-Stratos-17k](https://modelscope.cn/datasets/bespokelabs/Bespoke-Stratos-17k)|default|16710|480.7±236.1, min=266, max=3556|chat, sft, cot, r1|[bespokelabs/Bespoke-Stratos-17k](https://huggingface.co/datasets/bespokelabs/Bespoke-Stratos-17k)| |-|default|huge dataset|-|pretrain, quality|[cerebras/SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B)| +|[clip-benchmark/wds_voc2007_multilabel](https://modelscope.cn/datasets/clip-benchmark/wds_voc2007_multilabel)|default|2501|112.0±0.0, min=112, max=112|multilabel, multi-modal|[clip-benchmark/wds_voc2007_multilabel](https://huggingface.co/datasets/clip-benchmark/wds_voc2007_multilabel)| |[codefuse-ai/CodeExercise-Python-27k](https://modelscope.cn/datasets/codefuse-ai/CodeExercise-Python-27k)|default|27224|337.3±154.2, min=90, max=2826|chat, coding, 🔥|-| |[codefuse-ai/Evol-instruction-66k](https://modelscope.cn/datasets/codefuse-ai/Evol-instruction-66k)|default|66862|440.1±208.4, min=46, max=2661|chat, coding, 🔥|-| |[damo/MSAgent-Bench](https://modelscope.cn/datasets/damo/MSAgent-Bench)|default
mini|638149|859.2±460.1, min=38, max=3479|chat, agent, multi-round|-| @@ -1161,6 +1162,7 @@ The table below introduces information about the datasets integrated with ms-swi |[modelscope/clue](https://modelscope.cn/datasets/modelscope/clue)|cmnli|391783|81.6±16.0, min=54, max=157|text-generation, classification|[clue](https://huggingface.co/datasets/clue)| |[modelscope/coco_2014_caption](https://modelscope.cn/datasets/modelscope/coco_2014_caption)|train
validation|454617|389.6±68.4, min=70, max=587|chat, multi-modal, vision, 🔥|-| |[modelscope/gsm8k](https://modelscope.cn/datasets/modelscope/gsm8k)|main|7473|88.6±21.6, min=41, max=241|qa, math|-| +|[open-r1/DAPO-Math-17k-Processed](https://modelscope.cn/datasets/open-r1/DAPO-Math-17k-Processed)|all|17398|122.3±65.2, min=41, max=1517|math, rlvr|[open-r1/DAPO-Math-17k-Processed](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed)| |[open-r1/verifiable-coding-problems-python](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python)|default|35735|559.0±255.2, min=74, max=6191|grpo, code|[open-r1/verifiable-coding-problems-python](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python)| |[open-r1/verifiable-coding-problems-python-10k](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k)|default|1800|581.6±233.4, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k)| |[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)|default|1574|575.7±234.3, min=136, max=2022|grpo, code|[open-r1/verifiable-coding-problems-python-10k_decontaminated](https://huggingface.co/datasets/open-r1/verifiable-coding-problems-python-10k_decontaminated)| @@ -1190,7 +1192,7 @@ The table below introduces information about the datasets integrated with ms-swi |[swift/RedPajama-Data-V2](https://modelscope.cn/datasets/swift/RedPajama-Data-V2)|default|huge dataset|-|pretrain, quality|[togethercomputer/RedPajama-Data-V2](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)| |[swift/ScienceQA](https://modelscope.cn/datasets/swift/ScienceQA)|default|16967|101.7±55.8, min=32, max=620|multi-modal, science, vqa, quality|[derek-thomas/ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA)| |[swift/SlimOrca](https://modelscope.cn/datasets/swift/SlimOrca)|default|517982|405.5±442.1, min=47, max=8312|quality, en|[Open-Orca/SlimOrca](https://huggingface.co/datasets/Open-Orca/SlimOrca)| -|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| +|[swift/TextCaps](https://modelscope.cn/datasets/swift/TextCaps)|default
emb
rerank|huge dataset|-|multi-modal, en, caption, quality|[HuggingFaceM4/TextCaps](https://huggingface.co/datasets/HuggingFaceM4/TextCaps)| |[swift/ToolBench](https://modelscope.cn/datasets/swift/ToolBench)|default|124345|2251.7±1039.8, min=641, max=9451|chat, agent, multi-round|-| |[swift/VQAv2](https://modelscope.cn/datasets/swift/VQAv2)|default|huge dataset|-|en, vqa, quality|[HuggingFaceM4/VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2)| |[swift/VideoChatGPT](https://modelscope.cn/datasets/swift/VideoChatGPT)|Generic
Temporal
Consistency|3206|87.4±48.3, min=31, max=398|chat, multi-modal, video, 🔥|[lmms-lab/VideoChatGPT](https://huggingface.co/datasets/lmms-lab/VideoChatGPT)| From 86b6979828fe7f1f865884f215c622265ed7e7dd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 23 Nov 2025 23:12:46 +0800 Subject: [PATCH 09/22] update --- .../Megatron-SWIFT/Command-line-parameters.md | 1 + .../Megatron-SWIFT/Command-line-parameters.md | 1 + examples/infer/sglang/mtp.sh | 13 ++++++++++ swift/llm/infer/infer_engine/sglang_engine.py | 3 ++- swift/megatron/init.py | 24 +++++++++++++------ swift/megatron/model/gpt_bridge.py | 9 +++++++ 6 files changed, 43 insertions(+), 8 deletions(-) create mode 100644 examples/infer/sglang/mtp.sh diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 759a33c70b..d691705bd2 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -220,6 +220,7 @@ **MTP参数** - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 + - 注意:mtp_num_layers的值,将不自动从config.json获取,需手动设置。你可以参考config.json中的`num_nextn_predict_layers`字段填写该值。使用mcore-bridge时,将优先从safetensors文件中加载MTP权重,若无法找到,则进行随机初始化。 - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 **Tuner参数**: diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 0167dbe822..283a76f041 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -234,6 +234,7 @@ For guidance on selecting parallelization strategies, please refer to the [Train **MTP Parameters** - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. + - Note: The value of mtp_num_layers will not be automatically retrieved from config.json and must be set manually. You can refer to the `num_nextn_predict_layers` field in config.json to fill in this value. When using mcore-bridge, MTP weights will be loaded from safetensors files first. If not found, random initialization will be performed. - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. **Tuner Parameters**: diff --git a/examples/infer/sglang/mtp.sh b/examples/infer/sglang/mtp.sh new file mode 100644 index 0000000000..8582f43d94 --- /dev/null +++ b/examples/infer/sglang/mtp.sh @@ -0,0 +1,13 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift infer \ + --model ZhipuAI/GLM-4.5-Air \ + --sglang_tp_size 4 \ + --infer_backend sglang \ + --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ + --sglang_context_length 8192 \ + --max_new_tokens 2048 \ + --sglang_mem_fraction_static 0.7 \ + --sglang_speculative_algorithm EAGLE \ + --sglang_speculative_eagle_topk 1 \ + --sglang_speculative_num_steps 3 \ + --sglang_speculative_num_draft_tokens 4 diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 13049d0ada..37de0f845e 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -106,7 +106,8 @@ def __init__( self.server_args.is_embedding = True self.engine = sgl.Engine(server_args=self.server_args) self._load_generation_config() - self.max_tokens_offset = -speculative_num_draft_tokens + if speculative_num_draft_tokens is not None: + self.max_tokens_offset = -speculative_num_draft_tokens def _load_generation_config(self) -> None: generation_config_path = os.path.join(self.model_dir, 'generation_config.json') diff --git a/swift/megatron/init.py b/swift/megatron/init.py index bc7f9240ac..0c74edd818 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -5,7 +5,8 @@ import subprocess import sys from contextlib import contextmanager -from copy import copy +from copy import copy, deepcopy +from functools import partial from typing import List, Optional, Tuple import peft @@ -444,12 +445,23 @@ def forward( packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' apply_rope_fusion = self.config.apply_rope_fusion self.config.apply_rope_fusion = False - if packed_seq and not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if packed_seq: + packed_seq_params = deepcopy(packed_seq_params) + tensor = packed_seq_params.cu_seqlens_q + cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]]) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + packed_seq_params.cu_seqlens_q = packed_seq_params.cu_seqlens_kv = cu_seqlens + packed_seq_params.max_seqlen_q = packed_seq_params.max_seqlen_kv = max_seqlen + if not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + rotary_pos_emb = rotary_pos_emb[position_ids[0]] if self.config.recompute_granularity == 'full' and self.training: hidden_states = self._checkpointed_forward( - self._proj_and_transformer_layer, + partial( + self._proj_and_transformer_layer, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ), hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, @@ -460,8 +472,6 @@ def forward( rotary_pos_sin=rotary_pos_sin, attention_bias=attention_bias, inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, ) else: hidden_states = self._proj_and_transformer_layer( diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 63bd030acb..45af09d889 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1066,6 +1066,15 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_prefix = f'{hf_prefix}{hf_layer_idx}.' if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + if len(hf_state_dict) == 0: + logger.info_if( + f'MTP Layer {mtp_layer.layer_number} safetensors weights not found, ' + 'this part will be randomly initialized.', + cond=is_last_rank()) + for param in mtp_layer.parameters(): + if param.ndim == 2: + mtp_layer.config.init_method(param.data) + return {} else: hf_state_dict = {} self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', From 351be95fdac0b7368941b5db4a4113219ff3326f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 23 Nov 2025 23:49:39 +0800 Subject: [PATCH 10/22] update --- examples/megatron/lora/glm4_5_106b.sh | 7 +++++-- examples/megatron/lora/qwen3_235b.sh | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/megatron/lora/glm4_5_106b.sh b/examples/megatron/lora/glm4_5_106b.sh index 69783aed4a..b7ecb23609 100644 --- a/examples/megatron/lora/glm4_5_106b.sh +++ b/examples/megatron/lora/glm4_5_106b.sh @@ -1,10 +1,13 @@ -# thinking -> non-thinking +# demo: thinking -> non-thinking # 4 * 70GiB; 40s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ megatron sft \ - --load GLM-4.5-Air-mcore \ + --model ZhipuAI/GLM-4.5-Air \ + --load_safetensors true \ + --save_safetensors true \ + --mtp_num_layers 1 \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \ --load_from_cache_file true \ --train_type lora \ diff --git a/examples/megatron/lora/qwen3_235b.sh b/examples/megatron/lora/qwen3_235b.sh index fc4bada207..01a48f3fb9 100644 --- a/examples/megatron/lora/qwen3_235b.sh +++ b/examples/megatron/lora/qwen3_235b.sh @@ -5,9 +5,12 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ megatron sft \ - --load Qwen3-235B-A22B-Instruct-2507-mcore \ + --model Qwen/Qwen3-235B-A22B-Instruct-2507 \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ 'swift/self-cognition#1000' \ + --load_safetensors true \ + --save_safetensors true \ + --merge_lora false \ --load_from_cache_file true \ --train_type lora \ --lora_rank 8 \ From 3cb78625c6d177321f81e6aff92000544e2a5adc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 01:24:18 +0800 Subject: [PATCH 11/22] update --- swift/megatron/model/gpt/qwen3_next.py | 8 ++++++++ swift/megatron/model/gpt_bridge.py | 25 +++++++++++++++---------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index e4603352d4..19e8e9c555 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -504,6 +504,7 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): class Qwen3NextBridge(GPTBridge): + hf_mtp_prefix = 'mtp.layers' def _set_state_dict(self, mg_module, @@ -537,6 +538,13 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo hf_state_dict = super()._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore) return hf_state_dict + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + hf_state_dict = self._remove_prefix(origin_hf_state_dict, 'mtp.') + for mg_key, key in zip(['enorm.weight', 'hnorm.weight', 'eh_proj.weight'], + ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): + self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) + register_megatron_model( MegatronModelMeta( diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 45af09d889..1484e4f51d 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1057,6 +1057,11 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore) yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) + def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: + self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) + self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None if self.hf_mtp_prefix == self.hf_layers_prefix: @@ -1065,6 +1070,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: hf_layer_idx = layer_idx hf_prefix = f'{hf_prefix}{hf_layer_idx}.' if to_mcore: + origin_hf_state_dict = hf_state_dict hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) if len(hf_state_dict) == 0: logger.info_if( @@ -1076,21 +1082,20 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: mtp_layer.config.init_method(param.data) return {} else: + origin_hf_state_dict = {} hf_state_dict = {} - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', - to_mcore) - self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) - for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: - self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) - if hf_layer_idx >= len(self.hf_layers): - hf_layer_idx = -1 - hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, hf_layer_idx, to_mcore)) - self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + # Weights for shared parts are not stored, refer to GLM4.6 + # self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', + # to_mcore) + # self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) + self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) + hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) if to_mcore: hf_state_dict = {} else: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + hf_state_dict.update(origin_hf_state_dict) return hf_state_dict def load_weights(self, mg_model, hf_model_dir: str, is_peft_format: bool = False, adapter_name: str = 'default'): From 9c621430318ec88235595857835a750ddf165ebe Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 10:38:16 +0800 Subject: [PATCH 12/22] update --- examples/models/qwen3_next/mtp.sh | 59 ++++++++++++++++++++++++++ swift/megatron/init.py | 3 +- swift/megatron/model/gpt/qwen3_next.py | 2 +- swift/megatron/model/gpt_bridge.py | 4 +- 4 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 examples/models/qwen3_next/mtp.sh diff --git a/examples/models/qwen3_next/mtp.sh b/examples/models/qwen3_next/mtp.sh new file mode 100644 index 0000000000..92b3612eb4 --- /dev/null +++ b/examples/models/qwen3_next/mtp.sh @@ -0,0 +1,59 @@ +# 8 * 60GiB, 10s/it + +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +megatron sft \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --mtp_num_layers 1 \ + --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ + 'swift/self-cognition#1000' \ + --load_from_cache_file true \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --expert_model_parallel_size 4 \ + --moe_permute_fusion true \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-6 \ + --micro_batch_size 2 \ + --global_batch_size 16 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --save megatron_output/Qwen3-Next-80B-A3B-Instruct \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 2048 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --model_author swift \ + --model_name swift-robot + + +# CUDA_VISIBLE_DEVICES=0,1,2,3 \ +# swift infer \ +# --model megatron_output/Qwen3-Next-80B-A3B-Instruct/vx-xxx/checkpoint-xxx \ +# --sglang_tp_size 4 \ +# --infer_backend sglang \ +# --sglang_context_length 8192 \ +# --max_new_tokens 2048 \ +# --sglang_mem_fraction_static 0.7 \ +# --sglang_speculative_algorithm NEXTN \ +# --sglang_speculative_eagle_topk 1 \ +# --sglang_speculative_num_steps 3 \ +# --sglang_speculative_num_draft_tokens 4 diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 0c74edd818..2ceb713057 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -433,8 +433,7 @@ def forward( Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape [s, b, h], and optionally the updated context tensor if cross-attention is used. """ - # TODO: Multimodal compatible; MTP initialization - # TODO: packed_seq_params offset + # TODO: Multimodal compatible assert context is None, 'multi token prediction + cross attention is not yet supported.' input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( input_ids=input_ids, diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 19e8e9c555..5ce1fc0a72 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -515,7 +515,7 @@ def _set_state_dict(self, *, offset: float = 0, is_expert: bool = False): - if 'layernorm' in mg_key or 'layer_norm_weight' in mg_key: + if 'layernorm' in mg_key or 'layer_norm_weight' in mg_key or 'enorm' in mg_key or 'hnorm' in mg_key: offset = 1 if to_mcore else -1 return super()._set_state_dict( mg_module, diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 1484e4f51d..8a38688091 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1058,9 +1058,11 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): + hf_state_dict = hf_state_dict if to_mcore else {} for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) + return {} if to_mcore else hf_state_dict def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None @@ -1088,7 +1090,7 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: # self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', # to_mcore) # self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) - self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) + hf_state_dict.update(self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)) hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) if to_mcore: From d1e7ce81a0dd2d3d03024f03fb175122c8a9a010 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 10:40:06 +0800 Subject: [PATCH 13/22] update --- swift/megatron/model/gpt/qwen3_next.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index 5ce1fc0a72..df87e0e453 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -544,6 +544,7 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) + return {} if to_mcore else self._add_prefix(hf_state_dict, 'mtp.') register_megatron_model( From 65d0c66ab9e19e6ddd0a5e27d67bbec1164e76d9 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 14:12:53 +0800 Subject: [PATCH 14/22] update --- .../Instruction/Command-line-parameters.md | 1 + .../Instruction/Command-line-parameters.md | 1 + swift/llm/argument/infer_args.py | 2 +- .../infer/infer_engine/grpo_vllm_engine.py | 2 ++ swift/llm/infer/infer_engine/vllm_engine.py | 5 ++- swift/megatron/init.py | 13 ++------ swift/megatron/model/gpt/qwen3_next.py | 3 +- swift/megatron/model/gpt_bridge.py | 9 +++--- swift/megatron/model/gpt_model.py | 4 ++- swift/megatron/model/mm_gpt_model.py | 4 +++ swift/megatron/train/__init__.py | 26 +++++++++++++-- swift/megatron/trainers/__init__.py | 32 ++++++++++++++++--- swift/megatron/trainers/base.py | 7 ++++ swift/trainers/arguments.py | 3 ++ 14 files changed, 85 insertions(+), 27 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index af857745ca..791a44a4dc 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -397,6 +397,7 @@ Vera使用`target_modules`、`target_regex`、`modules_to_save`三个参数, - vllm_disable_custom_all_reduce: 禁用自定义的 all-reduce 内核,回退到 NCCL。为了稳定性,默认为`True`。 - vllm_enforce_eager: vllm使用pytorch eager模式还是建立cuda graph,默认为`False`。设置为True可以节约显存,但会影响效率。 - vllm_mm_processor_cache_gb: 多模态处理器缓存大小(GiB),用于缓存已处理的多模态输入(如图像、视频)避免重复处理。默认为`4`。设置为`0`可禁用缓存但会降低性能(不推荐)。仅对多模态模型生效。 +- vllm_speculative_config: 推测解码配置,传入json字符串。默认为None。 - vllm_disable_cascade_attn: 是否强制关闭V1引擎的cascade attention实现以防止潜在数值误差,默认为False,由vLLM内部逻辑决定是否使用。 - 🔥vllm_limit_mm_per_prompt: 控制vllm使用多图,默认为`None`。例如传入`--vllm_limit_mm_per_prompt '{"image": 5, "video": 2}'`。 - vllm_max_lora_rank: 默认为`16`。vllm对于lora支持的参数。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 243c4d5d0f..b7737b54bf 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -404,6 +404,7 @@ Parameter meanings can be found in the [vllm documentation](https://docs.vllm.ai - vllm_disable_custom_all_reduce: Disables the custom all-reduce kernel and falls back to NCCL. For stability, the default is `True`. - vllm_enforce_eager: Determines whether vllm uses PyTorch eager mode or constructs a CUDA graph, default is `False`. Setting it to True can save memory but may affect efficiency. - vllm_mm_processor_cache_gb: The size (in GiB) of the multimodal processor cache, used to store processed multimodal inputs (e.g., images, videos) to avoid redundant processing. Default is 4. Setting it to 0 disables the cache but may degrade performance (not recommended). This option takes effect only for multimodal models. +- vllm_speculative_config: Speculative decoding configuration, passed as a JSON string. Default: None. - vllm_disable_cascade_attn: Whether to forcibly disable the V1 engine’s cascade-attention implementation to avoid potential numerical issues. Defaults to False; vLLM’s internal heuristics determine whether cascade attention is actually used. - 🔥vllm_limit_mm_per_prompt: Controls the use of multiple media in vllm, default is `None`. For example, you can pass in `--vllm_limit_mm_per_prompt '{"image": 5, "video": 2}'`. - vllm_max_lora_rank: Default is `16`. This is the parameter supported by vllm for lora. diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index e0cce69ffb..519c46f971 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -102,7 +102,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg ckpt_dir (Optional[str]): Directory to the checkpoint. Default is None. infer_backend (Literal): Backend to use for inference. Default is 'pt'. Allowed values are 'vllm', 'pt', 'lmdeploy'. - result_path (Optional[str]): Directory to store inference results. Default is None. + result_path (Optional[str]): Path to store inference results. Default is None. max_batch_size (int): Maximum batch size for the pt engine. Default is 1. val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None. reranker_use_activation (bool): reranker use activation after calculating. Default is True. diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 18b626a505..cb0d2d56df 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -46,6 +46,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -80,6 +81,7 @@ def __init__( disable_cascade_attn=disable_cascade_attn, load_format=load_format, mm_processor_cache_gb=mm_processor_cache_gb, + speculative_config=speculative_config, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 2dc37c0ee0..e5c927cbb4 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -70,6 +70,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -131,6 +132,7 @@ def __init__( task=task_type, disable_cascade_attn=disable_cascade_attn, mm_processor_cache_gb=mm_processor_cache_gb, + speculative_config=speculative_config, **engine_kwargs, ) context = nullcontext() @@ -172,6 +174,7 @@ def _prepare_engine_kwargs( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + speculative_config: Optional[Union[str, dict]] = None, **engine_kwargs, ) -> None: if task == 'embedding': @@ -202,7 +205,7 @@ def _prepare_engine_kwargs( 'The current version of vLLM does not support `limit_mm_per_prompt`. Please upgrade vLLM.') for key in [ 'enable_expert_parallel', 'enable_sleep_mode', 'disable_cascade_attn', 'load_format', - 'mm_processor_cache_gb' + 'mm_processor_cache_gb', 'speculative_config' ]: if key in parameters: if locals()[key] is not None: diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 2ceb713057..c7aa42dfdb 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -444,16 +444,9 @@ def forward( packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' apply_rope_fusion = self.config.apply_rope_fusion self.config.apply_rope_fusion = False - if packed_seq: - packed_seq_params = deepcopy(packed_seq_params) - tensor = packed_seq_params.cu_seqlens_q - cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]]) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - packed_seq_params.cu_seqlens_q = packed_seq_params.cu_seqlens_kv = cu_seqlens - packed_seq_params.max_seqlen_q = packed_seq_params.max_seqlen_kv = max_seqlen - if not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + rotary_pos_emb = rotary_pos_emb[position_ids[0]] if self.config.recompute_granularity == 'full' and self.training: hidden_states = self._checkpointed_forward( partial( diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index df87e0e453..5950df2722 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -544,7 +544,8 @@ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state ['pre_fc_norm_embedding.weight', 'pre_fc_norm_hidden.weight', 'fc.weight']): self._set_state_dict(mtp_layer, mg_key, hf_state_dict, key, to_mcore) self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'norm.weight', to_mcore) - return {} if to_mcore else self._add_prefix(hf_state_dict, 'mtp.') + if not to_mcore: + origin_hf_state_dict.update(self._add_prefix(hf_state_dict, 'mtp.')) register_megatron_model( diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 8a38688091..f9da632687 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -1058,11 +1058,9 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict): - hf_state_dict = hf_state_dict if to_mcore else {} for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']: self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore) self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore) - return {} if to_mcore else hf_state_dict def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None @@ -1090,9 +1088,10 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: # self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight', # to_mcore) # self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight', to_mcore) - hf_state_dict.update(self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)) - hf_state_dict.update(self._set_layer_attn(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) - hf_state_dict.update(self._set_layer_mlp(mtp_layer.transformer_layer, hf_state_dict, -1, to_mcore)) + self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict) + transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer + hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore)) if to_mcore: hf_state_dict = {} else: diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 19d14a7b47..e368c233fa 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -264,6 +264,8 @@ def forward( *, inference_params: Optional[BaseInferenceContext] = None, loss_mask: Optional[torch.Tensor] = None, + # Mask labels to be compatible with thd & MTP + mtp_labels: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """Forward function of the GPT Model This function passes the input tensors @@ -310,7 +312,7 @@ def forward( hidden_states=hidden_states, input_ids=input_ids, position_ids=position_ids, - labels=labels, + labels=mtp_labels, rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 8be3c36744..bee70d1f82 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager +from typing import Optional import megatron.core import torch @@ -87,6 +88,8 @@ def forward( labels: torch.Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + *, + mtp_labels: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: if decoder_input is not None: @@ -108,6 +111,7 @@ def forward( labels=labels, inference_params=inference_params, packed_seq_params=packed_seq_params, + mtp_labels=mtp_labels, **kwargs, ) diff --git a/swift/megatron/train/__init__.py b/swift/megatron/train/__init__.py index 1b091bd4a3..537a1489de 100644 --- a/swift/megatron/train/__init__.py +++ b/swift/megatron/train/__init__.py @@ -1,4 +1,24 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .pt import megatron_pt_main -from .rlhf import megatron_rlhf_main -from .sft import megatron_sft_main +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .pt import megatron_pt_main + from .rlhf import megatron_rlhf_main + from .sft import megatron_sft_main +else: + _import_structure = { + 'pt': ['megatron_pt_main'], + 'rlhf': ['megatron_rlhf_main'], + 'sft': ['megatron_sft_main'], + } + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index 80cf16fe22..1f5ce04967 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -1,6 +1,28 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .dpo_trainer import MegatronDPOTrainer -from .grpo_trainer import MegatronGRPOTrainer -from .kto_trainer import MegatronKTOTrainer -from .reward_trainer import MegatronRewardTrainer -from .trainer import MegatronTrainer +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .dpo_trainer import MegatronDPOTrainer + from .grpo_trainer import MegatronGRPOTrainer + from .kto_trainer import MegatronKTOTrainer + from .reward_trainer import MegatronRewardTrainer + from .trainer import MegatronTrainer +else: + _import_structure = { + 'dpo_trainer': ['MegatronDPOTrainer'], + 'grpo_trainer': ['MegatronGRPOTrainer'], + 'kto_trainer': ['MegatronKTOTrainer'], + 'reward_trainer': ['MegatronRewardTrainer'], + 'trainer': ['MegatronTrainer'], + } + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 11febfccd5..85903c6089 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -1029,6 +1029,13 @@ def _prepare_batch(self, data, vp_stage, num_samples=None): if args.padding_free and text_position_ids is not None: batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) batch['packed_seq_params'].num_samples = num_samples + if args.mtp_num_layers and batch.get('labels') is not None: + cu_seqlens = batch['packed_seq_params'].cu_seqlens_q.clone() + mtp_labels = batch['labels'].clone() + for _ in range(args.mtp_num_layers): + mtp_labels[:, cu_seqlens[cu_seqlens < mtp_labels.shape[1]]] = -100 + cu_seqlens = cu_seqlens + 1 + batch['mtp_labels'] = mtp_labels # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 42f6afdcdd..7874a60784 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -206,12 +206,14 @@ class VllmArguments: vllm_reasoning_parser: Optional[str] = None vllm_disable_cascade_attn: bool = False vllm_mm_processor_cache_gb: Optional[float] = None + vllm_speculative_config: Optional[Union[dict, str]] = None vllm_engine_kwargs: Optional[Union[dict, str]] = None # rollout vllm_data_parallel_size: int = 1 def __post_init__(self): self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) + self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config) self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs) def get_vllm_engine_kwargs(self): @@ -237,6 +239,7 @@ def get_vllm_engine_kwargs(self): 'reasoning_parser': self.vllm_reasoning_parser, 'disable_cascade_attn': self.vllm_disable_cascade_attn, 'mm_processor_cache_gb': self.vllm_mm_processor_cache_gb, + 'speculative_config': self.vllm_speculative_config, 'num_labels': self.num_labels, 'engine_kwargs': self.vllm_engine_kwargs, } From 3220a8f7043dddb80413d3905b3528d7ee5c109d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 14:57:36 +0800 Subject: [PATCH 15/22] update --- examples/infer/vllm/mtp.sh | 10 ++++++++++ examples/models/qwen3_next/mtp.sh | 16 +++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) create mode 100644 examples/infer/vllm/mtp.sh diff --git a/examples/infer/vllm/mtp.sh b/examples/infer/vllm/mtp.sh new file mode 100644 index 0000000000..7d3b6a58a9 --- /dev/null +++ b/examples/infer/vllm/mtp.sh @@ -0,0 +1,10 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +swift infer \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --vllm_tensor_parallel_size 4 \ + --infer_backend vllm \ + --vllm_max_model_len 8192 \ + --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ + --vllm_speculative_config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \ + --vllm_gpu_memory_utilization 0.9 \ + --max_new_tokens 2048 diff --git a/examples/models/qwen3_next/mtp.sh b/examples/models/qwen3_next/mtp.sh index 92b3612eb4..a310f1a6d5 100644 --- a/examples/models/qwen3_next/mtp.sh +++ b/examples/models/qwen3_next/mtp.sh @@ -48,12 +48,10 @@ megatron sft \ # CUDA_VISIBLE_DEVICES=0,1,2,3 \ # swift infer \ # --model megatron_output/Qwen3-Next-80B-A3B-Instruct/vx-xxx/checkpoint-xxx \ -# --sglang_tp_size 4 \ -# --infer_backend sglang \ -# --sglang_context_length 8192 \ -# --max_new_tokens 2048 \ -# --sglang_mem_fraction_static 0.7 \ -# --sglang_speculative_algorithm NEXTN \ -# --sglang_speculative_eagle_topk 1 \ -# --sglang_speculative_num_steps 3 \ -# --sglang_speculative_num_draft_tokens 4 +# --vllm_tensor_parallel_size 4 \ +# --infer_backend vllm \ +# --vllm_max_model_len 8192 \ +# --val_dataset AI-ModelScope/alpaca-gpt4-data-zh#100 \ +# --vllm_gpu_memory_utilization 0.9 \ +# --vllm_speculative_config '{"method":"qwen3_next_mtp","num_speculative_tokens":2}' \ +# --max_new_tokens 2048 From 89e6627284866c21619e34cde379e47fbdc9b687 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 15:05:35 +0800 Subject: [PATCH 16/22] update --- swift/trainers/arguments.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 7874a60784..a178d86841 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -212,8 +212,10 @@ class VllmArguments: vllm_data_parallel_size: int = 1 def __post_init__(self): - self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) - self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config) + if self.vllm_limit_mm_per_prompt is not None: + self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt) + if self.vllm_speculative_config is not None: + self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config) self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs) def get_vllm_engine_kwargs(self): From 0df0c2f9296e92a47023fdc974c127f95cefbd1b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 15:46:05 +0800 Subject: [PATCH 17/22] update --- swift/megatron/model/mm_gpt_model.py | 2 ++ swift/megatron/trainers/base.py | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index bee70d1f82..25a8244d01 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -40,6 +40,8 @@ def __init__(self, args = get_args() self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type) self.visual = None + if args.mtp_num_layers: + raise ValueError('MTP currently does not support multimodal models.') if pre_process and self.megatron_model_meta.visual_cls is not None: self.visual = self.megatron_model_meta.visual_cls(config) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 85903c6089..3cd1b7d6ed 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -22,7 +22,7 @@ from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, - get_wandb_writer, is_last_rank, one_logger_utils, pretrain, print_rank_0, + get_wandb_writer, initialize, is_last_rank, one_logger_utils, pretrain, print_rank_0, print_rank_last, training) from megatron.training.checkpointing import load_checkpoint from megatron.training.theoretical_memory_usage import report_theoretical_memory @@ -81,6 +81,7 @@ def bridge(self): @contextmanager def _get_iters(self, train_dataset, val_dataset): origin_initialize_megatron = training.initialize_megatron + origin_validate_args = initialize.validate_args def initialize_megatron(*_args, **kwargs): res = origin_initialize_megatron(*_args, **kwargs) @@ -109,11 +110,16 @@ def initialize_megatron(*_args, **kwargs): logger.info(f'Setting args.eval_iters: {args.eval_iters}') return res + def validate_args(args, *_args, **kwargs): + return origin_validate_args(args, *_args, **kwargs) + training.initialize_megatron = initialize_megatron + initialize.validate_args = validate_args try: yield finally: training.initialize_megatron = origin_initialize_megatron + initialize.validate_args = origin_validate_args def new_cyclic_iter(self, iterable): args = get_args() From 78bad41f21fb62fbdb9dcf5408ae4abbd1ad8f97 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 16:33:50 +0800 Subject: [PATCH 18/22] update --- swift/megatron/init.py | 5 +---- swift/megatron/model/gpt_model.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index c7aa42dfdb..a5f8c6b234 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -442,9 +442,7 @@ def forward( hidden_states=hidden_states, ) packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - apply_rope_fusion = self.config.apply_rope_fusion - self.config.apply_rope_fusion = False - if packed_seq and not self.config.apply_rope_fusion: + if packed_seq: assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' rotary_pos_emb = rotary_pos_emb[position_ids[0]] if self.config.recompute_granularity == 'full' and self.training: @@ -480,7 +478,6 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) - self.config.apply_rope_fusion = apply_rope_fusion return hidden_states, input_ids, position_ids MultiTokenPredictionLayer.forward = forward diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index e368c233fa..1fce9b7bbf 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections import OrderedDict from contextlib import contextmanager +from copy import deepcopy from typing import Any, Dict, Literal, Optional, Tuple import megatron.core @@ -140,8 +141,17 @@ def __init__( if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion: config.apply_rope_fusion = False - logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' + if self.attention_scaling != 1: + warning_string = 'attention_scaling' + else: + warning_string = 'mrope' + logger.warning(f'`apply_rope_fusion` does not support `{warning_string}`. ' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') + if getattr(self, 'mtp', None) is not None: + for layer in self.mtp.layers: + attention = layer.transformer_layer.self_attention + attention.config = deepcopy(attention.config) + attention.config.apply_rope_fusion = False @contextmanager def _patch_apply_rotary_pos_emb(self): From a2ef030a4839a1cea350209e6abf605e30cb49b7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 16:35:55 +0800 Subject: [PATCH 19/22] update --- swift/megatron/init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index a5f8c6b234..3a95952260 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -443,6 +443,7 @@ def forward( ) packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if packed_seq: + assert not self.transformer_layer.self_attention.config.apply_rope_fusion assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' rotary_pos_emb = rotary_pos_emb[position_ids[0]] if self.config.recompute_granularity == 'full' and self.training: From 85e8429e0af2afcf336761be00263feb9533a858 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 17:52:41 +0800 Subject: [PATCH 20/22] update --- docs/source/Customization/Custom-dataset.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 2 +- .../source_en/Customization/Custom-dataset.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 2 +- swift/megatron/model/gpt_model.py | 187 ++++++++++++++---- swift/megatron/model/model_provider.py | 8 +- 6 files changed, 155 insertions(+), 48 deletions(-) diff --git a/docs/source/Customization/Custom-dataset.md b/docs/source/Customization/Custom-dataset.md index 84fd506231..09b66666eb 100644 --- a/docs/source/Customization/Custom-dataset.md +++ b/docs/source/Customization/Custom-dataset.md @@ -29,7 +29,7 @@ query-response格式: ```jsonl {"system": "", "query": "", "response": "", "history": [["", ""]]} ``` -注意:以下字段会自动转成对应的system、query、response字段。 +注意:以下字段会自动转成对应的system、query、response字段。(solution字段会保留) - system: 'system', 'system_prompt'. - query: 'query', 'prompt', 'input', 'instruction', 'question', 'problem'. - response: 'response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution', 'text', 'completion', 'content'. diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index d691705bd2..124c71cff1 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -219,7 +219,7 @@ - qk_pos_emb_head_dim: QK 投影中位置嵌入的维度。默认为None,自动从config.json读取。 **MTP参数** -- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。 +- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。(需要"megatron-core>=0.14") - 注意:mtp_num_layers的值,将不自动从config.json获取,需手动设置。你可以参考config.json中的`num_nextn_predict_layers`字段填写该值。使用mcore-bridge时,将优先从safetensors文件中加载MTP权重,若无法找到,则进行随机初始化。 - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。 diff --git a/docs/source_en/Customization/Custom-dataset.md b/docs/source_en/Customization/Custom-dataset.md index 69d49e4ed0..cae457d9eb 100644 --- a/docs/source_en/Customization/Custom-dataset.md +++ b/docs/source_en/Customization/Custom-dataset.md @@ -30,7 +30,7 @@ Query-Response format: ```jsonl {"system": "", "query": "", "response": "", "history": [["", ""]]} ``` -Note: The following fields will be automatically converted to the corresponding system, query, and response fields. +Note: The following fields will be automatically converted to the corresponding system, query, and response fields. (The 'solution' field will be retained) - system: 'system', 'system_prompt'. - query: 'query', 'prompt', 'input', 'instruction', 'question', 'problem'. - response: 'response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution', 'text', 'completion', 'content'. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 283a76f041..c9e11cd8b9 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -233,7 +233,7 @@ For guidance on selecting parallelization strategies, please refer to the [Train **MTP Parameters** -- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. +- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None. (requires "megatron-core>=0.14") - Note: The value of mtp_num_layers will not be automatically retrieved from config.json and must be set manually. You can refer to the `num_nextn_predict_layers` field in config.json to fill in this value. When using mcore-bridge, MTP weights will be loaded from safetensors files first. If not found, random initialization will be performed. - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1. diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 1fce9b7bbf..667ecc1738 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -6,6 +6,7 @@ import megatron.core import torch +from megatron.core import parallel_state from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.extensions.transformer_engine import TELinear @@ -14,6 +15,7 @@ from megatron.core.models.gpt import GPTModel as McoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params @@ -316,53 +318,160 @@ def forward( args = get_args() labels = labels if args.task_type == 'causal_lm' else None - if mcore_013: - # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 - return self._postprocess( - hidden_states=hidden_states, + # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661 + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + mtp_labels=mtp_labels, + ) + + def _postprocess( + self, + hidden_states, + input_ids, + position_ids, + labels, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + mtp_in_postprocess=None, + loss_mask=None, + decoder_input=None, + attention_mask=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, + mtp_labels=None, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + + Applies Multi-Token Prediction if enabled, generates output logits through + the output layer, and computes language model loss when labels are provided. + """ + in_inference_mode = inference_context is not None and not self.training + if in_inference_mode: + assert runtime_gather_output, 'Inference must always gather TP logits' + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if mtp_in_postprocess: + hidden_states = self.mtp( input_ids=input_ids, position_ids=position_ids, - labels=mtp_labels, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, + embedding=self.embedding, + **(extra_block_kwargs or {}), ) - else: - if not self.post_process: - return hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer( - hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) - if has_config_logger_enabled(self.config): - payload = OrderedDict({ - 'input_ids': input_ids, - 'position_ids': position_ids, - 'attention_mask': attention_mask, - 'decoder_input': decoder_input, - 'logits': logits, - }) - log_config_to_disk(self.config, payload, prefix='input_and_logits') - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - loss = self.compute_language_model_loss(labels, logits) - - return loss + + if not self.post_process: + return hidden_states + + if self.mtp_process: + mtp_labels = mtp_labels or labels.clone() + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output + mtp_logits, _ = self.output_layer( + hidden_states_list[mtp_layer_number + 1], + weight=output_weight, + runtime_gather_output=runtime_gather_output, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) + loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group) + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: + # TODO(shifangx): remove the use of parallel_state here + # after moving loss logging to loss_func in pretrain_gpt.py + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + mtp_layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), + ) + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) + else: + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) + sequence_parallel_override = False + if in_inference_mode and inference_context.materialize_only_last_token_logits: + if inference_context.is_static_batching(): + hidden_states = hidden_states[-1:, :, :] + else: + if self.output_layer.sequence_parallel: + # Perform the sequence parallel gather here instead of after the output layer + # because we need to slice the last token logits from the full view of the + # packed logits across all requests. + # TODO(ksanthanam): Make the equivalent change in the `MambaModel` code after + # merging in !3722. + hidden_states = gather_from_sequence_parallel_region(hidden_states, group=self.pg_collection.tp) + self.output_layer.sequence_parallel = False + sequence_parallel_override = True + + # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden + # state ([B, H]) → unsqueeze back to [1, B, H] + # (so that the output layer, which expects S×B×H, receives only the final token) + hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) + + logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) + + # Restore sequence parallel execution to the output layer if necessary. + if sequence_parallel_override: + assert (in_inference_mode and inference_context.is_dynamic_batching() + and inference_context.materialize_only_last_token_logits) + self.output_layer.sequence_parallel = True + + if has_config_logger_enabled(self.config): + payload = OrderedDict({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + }) + log_config_to_disk(self.config, payload, prefix='input_and_logits') + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss def get_input_tensor(self): return self.decoder.input_tensor diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 997f53a231..84798aec49 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -117,10 +117,7 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) else: if args.num_experts: - if mcore_013: - kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} - else: - kwargs = {} + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} if mcore_013 else {} # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) @@ -137,8 +134,9 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config) else: transformer_layer_spec_for_mtp = transformer_layer_spec + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} mtp_block_spec = get_gpt_mtp_block_spec( - config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, vp_stage=vp_stage) + config, transformer_layer_spec_for_mtp, use_transformer_engine=use_te, **kwargs) if args.use_shared_expert_gate and args.num_experts and args.moe_shared_expert_intermediate_size: # qwen2_moe From e67237fc5b68562ea645452cc8ace980aa7624b7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 17:56:06 +0800 Subject: [PATCH 21/22] fix --- swift/megatron/model/gpt_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 667ecc1738..f4c8f6ac39 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -395,7 +395,6 @@ def _postprocess( return hidden_states if self.mtp_process: - mtp_labels = mtp_labels or labels.clone() hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) hidden_states = hidden_states_list[0] if loss_mask is None: From b4555387cd0f399d1ae09c5dd6b8488a24ee3c9e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 24 Nov 2025 18:12:02 +0800 Subject: [PATCH 22/22] fix cp --- swift/megatron/trainers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index a9c956b45f..e5e376b014 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -110,7 +110,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: args = get_args() - keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] + keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale', 'mtp_labels'] if not args.is_multimodal: # Multimodal models will handle CP in input_embeds. keys.append('input_ids')