
Commit 10d9096

[bugfix] compat deepseek-v3 mcore 0.13.0 (#6510)
1 parent f6883d1 commit 10d9096

5 files changed: +23 −6 lines changed


docs/source/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion
@@ -185,10 +185,12 @@
 - moe_ffn_hidden_size: Hidden layer size of the feedforward network (ffn) for each expert. Default is None and will be automatically read from config.json. If not found and `num_experts` is not None, it will be set to ffn_hidden_size.
 - moe_shared_expert_intermediate_size: The total FFN hidden layer size for shared experts. If there are multiple shared experts, it should equal `num_shared_experts * ffn_size_of_each_shared_expert`. Default is None. Automatically read from config.json.
 - moe_router_topk: The number of experts each token is routed to. Default is None. Automatically read from config.json.
+- moe_router_num_groups: Number of groups to divide experts into for group-limited routing. See DeepSeek-V2 and DeepSeek-V3. Default is None. Automatically read from config.json.
+- moe_router_group_topk: Number of selected groups for group-limited routing. Default is None. Automatically read from config.json.
 - moe_router_pre_softmax: Enable pre-softmax routing for MoE, meaning that softmax is applied before top-k selection. Default is None. Automatically read from config.json.
 - 🔥moe_router_dtype: Data type used for routing computation and expert output weighted averaging. Options are 'none', 'fp32', and 'fp64'; this enhances numerical stability, especially when the number of experts is large. When used together with `moe_permute_fusion`, the performance impact is negligible. Default is 'fp32'. 'none' means the data type is not changed.
 - moe_router_score_function: Scoring function for MoE TopK routing. Can be "softmax" or "sigmoid". Default is None and is read from config.json.
-- moe_router_bias_update_rate: Update rate of expert bias in the auxiliary-loss-free load balancing strategy. Expert bias is updated based on the number of tokens each expert is assigned in the global batch: bias increases for experts assigned fewer tokens and decreases for those assigned more tokens. Default is 1e-3, the same value used in DeepSeek-V3.
+- moe_router_bias_update_rate: Update rate of expert bias in the auxiliary-loss-free load balancing strategy. Expert bias is updated based on the number of tokens each expert is assigned in the global batch: bias increases for experts assigned fewer tokens and decreases for those assigned more tokens. Default is None and is read from config.json.
 - moe_router_enable_expert_bias: TopK routing with dynamic expert bias in the auxiliary-loss-free load balancing strategy. Routing decisions are based on the sum of routing scores and expert bias. See details at: https://arxiv.org/abs/2408.15664. Default is None and is automatically read from config.json.
 - moe_router_topk_scaling_factor: Default is None. Read from config.json.
 - moe_router_load_balancing_type: Determines the router's load balancing strategy. Options are "aux_loss", "seq_aux_loss", "sinkhorn", and "none". Default is None and is read from config.json.

docs/source_en/Megatron-SWIFT/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion
@@ -197,10 +197,12 @@ For guidance on selecting parallelization strategies, please refer to the [Train
 - moe_ffn_hidden_size: Hidden layer size of the feedforward network (ffn) for each expert. Default is None and will be automatically read from config.json. If not found and `num_experts` is not None, it will be set to ffn_hidden_size.
 - moe_shared_expert_intermediate_size: The total FFN hidden layer size for shared experts. If there are multiple shared experts, it should equal `num_shared_experts * ffn_size_of_each_shared_expert`. Default is None. Automatically read from config.json.
 - moe_router_topk: The number of experts each token is routed to. Default is None. Automatically read from config.json.
+- moe_router_num_groups: Number of groups to divide experts into for group-limited routing. Refers to DeepSeek-V2 and DeepSeek-V3. Default is None. Automatically read from config.json.
+- moe_router_group_topk: Number of selected groups for group-limited routing. Default is None. Automatically read from config.json.
 - moe_router_pre_softmax: Enable pre-softmax routing for MoE, meaning that softmax will be applied before top-k selection. Default is None. Automatically read from config.json.
 - 🔥moe_router_dtype: Data type used for routing computation and expert output weighted averaging. Options are 'none', 'fp32', and 'fp64', which enhances numerical stability, especially when the number of experts is large. When used together with `moe_permute_fusion`, the performance impact is negligible. Default is 'fp32'. 'none' means no change to data type.
 - moe_router_score_function: Scoring function for MoE TopK routing. Can be "softmax" or "sigmoid". Default is None and is read from config.json.
-- moe_router_bias_update_rate: Update rate of expert bias in the auxiliary-loss-free load balancing strategy. Expert bias is updated based on the number of tokens each expert is assigned in the global batch: bias increases for experts assigned fewer tokens, and decreases for those assigned more tokens. Default is 1e-3, same as used in DeepSeekV3.
+- moe_router_bias_update_rate: Update rate of expert bias in the auxiliary-loss-free load balancing strategy. Expert bias is updated based on the number of tokens each expert is assigned in the global batch: bias increases for experts assigned fewer tokens, and decreases for those assigned more tokens. Default is None and is read from config.json.
 - moe_router_enable_expert_bias: TopK routing with dynamic expert bias in the auxiliary-loss-free load balancing strategy. Routing decisions are based on the sum of routing scores and expert bias. See details at: https://arxiv.org/abs/2408.15664. Default is None and is automatically read from config.json.
 - moe_router_topk_scaling_factor: Default is None. This parameter is read from config.json.
 - moe_router_load_balancing_type: Determines the router’s load balancing strategy. Options are "aux_loss", "seq_aux_loss", "sinkhorn", and "none". Default is None and is read from config.json.

swift/megatron/argument/megatron_args.py

Lines changed: 3 additions & 1 deletion
@@ -284,10 +284,12 @@ class MegatronArguments(ExtraMegatronArguments):
     moe_shared_expert_intermediate_size: Optional[int] = None
 
     moe_router_topk: Optional[int] = None
+    moe_router_num_groups: Optional[int] = None
+    moe_router_group_topk: Optional[int] = None
     moe_router_pre_softmax: Optional[bool] = None
     moe_router_dtype: Literal['none', 'fp32', 'fp64'] = 'fp32'
     moe_router_score_function: Literal['sigmoid', 'softmax'] = None
-    moe_router_bias_update_rate: float = 1e-3
+    moe_router_bias_update_rate: Optional[float] = None
     moe_router_enable_expert_bias: Optional[bool] = None
     moe_router_topk_scaling_factor: Optional[float] = None
     moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = None

swift/megatron/model/gpt_bridge.py

Lines changed: 10 additions & 2 deletions
@@ -2,11 +2,13 @@
 from copy import copy
 from typing import Optional
 
+import megatron.core
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from megatron.core import mpu
 from megatron.training import get_args
+from packaging import version
 from peft.utils import ModulesToSaveWrapper
 from tqdm import tqdm
 from transformers.modeling_utils import custom_object_save
@@ -41,6 +43,7 @@ def __init__(self, disable_tqmd: bool = False):
         self._init_meta_hf_model()
         self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix)
         self.module_mapping = {}
+        self.megatron_core_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0')
         megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type)
         if self.args.is_multimodal and megatron_model_meta.visual_cls is not None:
             self.module_mapping = megatron_model_meta.visual_cls.module_mapping
@@ -64,8 +67,7 @@ def _init_meta_hf_model(self):
         self.hf_model, self.processor = get_model_tokenizer(
             self.args.model_dir, model_type=self.args.hf_model_type, return_dummy_model=True)
 
-    @staticmethod
-    def _get_tp_split_dim(mg_key: Optional[str]) -> Optional[int]:
+    def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
         if mg_key is None:
             return
         # ColumnLinear
@@ -78,6 +80,9 @@ def _get_tp_split_dim(mg_key: Optional[str]) -> Optional[int]:
             'linear_q_up_proj',
             'linear_kv_up_proj'
         }
+        if not self.megatron_core_014:
+            # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72
+            dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'})
         # RowLinear
         dim1_keys = {'linear_proj', 'linear_fc2'}
         if 'lora_A' not in mg_key and 'lora_B' not in mg_key:
@@ -856,6 +861,9 @@ def _set_mla_attn_state(
                                  to_mcore)
         self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore)
         if self.args.qk_layernorm:
+            if self.args.q_lora_rank is not None:
+                self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict,
+                                     'q_a_layernorm.weight', to_mcore)
             self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight',
                                  to_mcore)
         if to_mcore:
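
Note: the new `megatron_core_014` flag gates the tensor-parallel split rules on the installed Megatron-Core version via `packaging.version`; comparing against the `0.14.0rc0` pre-release tag means release candidates of 0.14 already take the new path. A small self-contained illustration of the comparison (example values only):

```python
from packaging import version

MCORE_014 = version.parse('0.14.0rc0')

def is_mcore_014(installed: str) -> bool:
    # Pre-releases such as 0.14.0rc0 compare lower than the final 0.14.0,
    # so gating on '0.14.0rc0' also covers release candidates.
    return version.parse(installed) >= MCORE_014

print(is_mcore_014('0.13.1'))     # False -> keep the pre-0.14 dim-0 split for the *_down_proj keys
print(is_mcore_014('0.14.0rc1'))  # True
print(is_mcore_014('0.14.0'))     # True
```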

swift/megatron/utils/config.py

Lines changed: 4 additions & 1 deletion
@@ -24,13 +24,16 @@
     # moe
     'moe_ffn_hidden_size': ['moe_intermediate_size'],
     'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
-    'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk', 'moe_k'],
+    'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'],
+    'moe_router_num_groups': ['n_group'],
+    'moe_router_group_topk': ['topk_group'],
     'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts'],
     'moe_router_pre_softmax': ['norm_topk_prob'],
     # deepseek
     'q_lora_rank': ['q_lora_rank'],
     'kv_lora_rank': ['kv_lora_rank'],
     'moe_router_score_function': ['scoring_func'],
+    'moe_router_bias_update_rate': ['aux_loss_alpha'],
     'qk_head_dim': ['qk_nope_head_dim'],
     'qk_pos_emb_head_dim': ['qk_rope_head_dim'],
     'moe_router_topk_scaling_factor': ['routed_scaling_factor'],
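
Note: this table maps each Megatron argument to the candidate keys under which it may appear in a Hugging Face config.json, with the first match winning. Below is a sketch of how such a mapping could be applied to DeepSeek-V3-style routing fields; the `mapping` excerpt and the `resolve_from_config` helper are illustrative, not the actual ms-swift loader.

```python
# Excerpt of the Megatron-arg -> config.json candidate-key mapping from the diff above.
mapping = {
    'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'],
    'moe_router_num_groups': ['n_group'],
    'moe_router_group_topk': ['topk_group'],
    'moe_router_bias_update_rate': ['aux_loss_alpha'],
}

def resolve_from_config(hf_config: dict) -> dict:
    """Return the Megatron arguments whose candidate keys are present in config.json."""
    resolved = {}
    for megatron_key, hf_keys in mapping.items():
        for hf_key in hf_keys:
            if hf_key in hf_config:
                resolved[megatron_key] = hf_config[hf_key]
                break  # first match wins
    return resolved

# DeepSeek-V3-style routing fields as they appear in its config.json.
hf_config = {'num_experts_per_tok': 8, 'n_group': 8, 'topk_group': 4, 'aux_loss_alpha': 0.001}
print(resolve_from_config(hf_config))
# {'moe_router_topk': 8, 'moe_router_num_groups': 8, 'moe_router_group_topk': 4,
#  'moe_router_bias_update_rate': 0.001}
```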
