@@ -1008,6 +1008,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
10081008 if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2" :
10091009 # ref: https://huggingface.co/THUDM/glm-4-9b-hf
10101010 res = "glm4"
1011+ if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902" :
1012+ # ref: https://huggingface.co/zai-org/GLM-4.5-Air
1013+ res = "glm4"
10111014 if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35" :
10121015 # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
10131016 res = "minerva-7b"
@@ -7026,6 +7029,139 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
70267029 return super ().modify_tensors (data_torch , name , bid )
70277030
70287031
7032+ @ModelBase .register ("Glm4MoeForCausalLM" )
7033+ class Glm4MoeModel (TextModel ):
7034+ model_arch = gguf .MODEL_ARCH .GLM4_MOE
7035+
7036+ def __init__ (self , * args , ** kwargs ):
7037+ super ().__init__ (* args , ** kwargs )
7038+ # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer)
7039+ self .block_count = self .hparams ["num_hidden_layers" ] + self .hparams .get ("num_nextn_predict_layers" , 0 )
7040+ self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .block_count )
7041+
7042+ def set_vocab (self ):
7043+ from transformers import AutoTokenizer
7044+
7045+ tokenizer = AutoTokenizer .from_pretrained (self .dir_model )
7046+ special_vocab = gguf .SpecialVocab (self .dir_model , load_merges = True )
7047+ tokens , toktypes , tokpre = self .get_vocab_base ()
7048+ self .gguf_writer .add_tokenizer_model ("gpt2" )
7049+ self .gguf_writer .add_tokenizer_pre (tokpre )
7050+ self .gguf_writer .add_token_list (tokens )
7051+ self .gguf_writer .add_token_types (toktypes )
7052+
7053+ # Special tokens
7054+ # Note: Using <|endoftext|> (151329) for eot causes endless generation
7055+ special_vocab ._set_special_token ("bos" , tokenizer .get_added_vocab ()["[gMASK]" ]) # 151331
7056+ special_vocab ._set_special_token ("eot" , tokenizer .get_added_vocab ()["<|user|>" ]) # 151336
7057+ special_vocab ._set_special_token ("unk" , tokenizer .get_added_vocab ()["<|endoftext|>" ]) # 151329
7058+ special_vocab ._set_special_token ("eom" , tokenizer .get_added_vocab ()["<|observation|>" ]) # 151338
7059+
7060+ # Patch broken chat template
7061+ if isinstance (special_vocab .chat_template , str ) and "visible_text(m.content).endswith" in special_vocab .chat_template :
7062+ special_vocab .chat_template = special_vocab .chat_template .replace (
7063+ """{{ visible_text(m.content) }}\n {{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""" ,
7064+ """{% set content = visible_text(m.content) %}{{ content }}\n {{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""" )
7065+
7066+ special_vocab .add_to_gguf (self .gguf_writer )
7067+
7068+ def set_gguf_parameters (self ):
7069+ super ().set_gguf_parameters ()
7070+ if (rope_dim := self .hparams .get ("head_dim" )) is None :
7071+ rope_dim = (
7072+ self .hparams ["hidden_size" ] // self .hparams ["num_attention_heads" ]
7073+ )
7074+ self .gguf_writer .add_rope_dimension_count (
7075+ int (rope_dim * self .hparams .get ("partial_rotary_factor" , 0.5 ))
7076+ )
7077+
7078+ # MoE parameters - Use only routed expert count (shared experts handled separately)
7079+ if (n_routed_experts := self .hparams .get ("n_routed_experts" )) is not None :
7080+ self .gguf_writer .add_expert_count (n_routed_experts )
7081+ if (moe_intermediate_size := self .hparams .get ("moe_intermediate_size" )) is not None :
7082+ self .gguf_writer .add_expert_feed_forward_length (moe_intermediate_size )
7083+ if (n_shared_experts := self .hparams .get ("n_shared_experts" )) is not None :
7084+ self .gguf_writer .add_expert_shared_count (n_shared_experts )
7085+ if (first_k_dense_replace := self .hparams .get ("first_k_dense_replace" )) is not None :
7086+ self .gguf_writer .add_leading_dense_block_count (first_k_dense_replace )
7087+
7088+ # Expert gating function (sigmoid for GLM4_MOE)
7089+ self .gguf_writer .add_expert_gating_func (gguf .ExpertGatingFuncType .SIGMOID )
7090+
7091+ # Routed scaling factor
7092+ if (routed_scaling_factor := self .hparams .get ("routed_scaling_factor" )) is not None :
7093+ self .gguf_writer .add_expert_weights_scale (routed_scaling_factor )
7094+
7095+ # Normalise topk probabilities
7096+ if (norm_topk_prob := self .hparams .get ("norm_topk_prob" )) is not None :
7097+ self .gguf_writer .add_expert_weights_norm (norm_topk_prob )
7098+
7099+ # NextN/MTP prediction layers
7100+ if (num_nextn_predict_layers := self .hparams .get ("num_nextn_predict_layers" )) is not None :
7101+ self .gguf_writer .add_nextn_predict_layers (num_nextn_predict_layers )
7102+
7103+ _experts : list [dict [str , Tensor ]] | None = None
7104+
7105+ def modify_tensors (
7106+ self , data_torch : Tensor , name : str , bid : int | None
7107+ ) -> Iterable [tuple [str , Tensor ]]:
7108+ if name .startswith ("model.visual." ): # ignore visual part
7109+ return []
7110+ elif name .startswith ("model.language_model." ):
7111+ name = name .replace ("language_model." , "" ) # for multimodal variants
7112+
7113+ # Handle main token embedding (but not layer-specific NextN embeddings)
7114+ if name == "model.embed_tokens.weight" and ".layers." not in name :
7115+ return [(self .map_tensor_name ("token_embd.weight" ), data_torch )]
7116+
7117+ # Handle routed experts
7118+ if name .find ("mlp.experts" ) != - 1 :
7119+ n_experts = self .hparams ["n_routed_experts" ]
7120+ assert bid is not None
7121+
7122+ if self ._experts is None :
7123+ self ._experts = [{} for _ in range (self .block_count )]
7124+
7125+ self ._experts [bid ][name ] = data_torch
7126+
7127+ if len (self ._experts [bid ]) >= n_experts * 3 :
7128+ tensors : list [tuple [str , Tensor ]] = []
7129+
7130+ # merge the experts into a single 3d tensor
7131+ for w_name in ["down_proj" , "gate_proj" , "up_proj" ]:
7132+ datas : list [Tensor ] = []
7133+
7134+ for xid in range (n_experts ):
7135+ ename = f"model.layers.{ bid } .mlp.experts.{ xid } .{ w_name } .weight"
7136+ datas .append (self ._experts [bid ][ename ])
7137+ del self ._experts [bid ][ename ]
7138+
7139+ data_torch = torch .stack (datas , dim = 0 )
7140+
7141+ merged_name = f"model.layers.{ bid } .mlp.experts.{ w_name } .weight"
7142+
7143+ new_name = self .map_tensor_name (merged_name )
7144+ tensors .append ((new_name , data_torch ))
7145+ return tensors
7146+ else :
7147+ return []
7148+
7149+ if name .endswith ("e_score_correction_bias" ):
7150+ name = name .replace ("e_score_correction_bias" , "e_score_correction.bias" )
7151+
7152+ new_name = self .map_tensor_name (name )
7153+
7154+ return [(new_name , data_torch )]
7155+
7156+ def prepare_tensors (self ):
7157+ super ().prepare_tensors ()
7158+ if self ._experts is not None :
7159+ # flatten `list[dict[str, Tensor]]` into `list[str]`
7160+ experts = [k for d in self ._experts for k in d .keys ()]
7161+ if len (experts ) > 0 :
7162+ raise ValueError (f"Unprocessed experts: { experts } " )
7163+
7164+
70297165@ModelBase .register ("GlmForCausalLM" , "ChatGLMModel" , "ChatGLMForConditionalGeneration" )
70307166class ChatGLMModel (TextModel ):
70317167 model_arch = gguf .MODEL_ARCH .CHATGLM
0 commit comments