@@ -5479,7 +5479,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
 
         if lora_names := hparams.get("lora_adaptations"):
             self._lora_names = lora_names
-            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
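+        # heuristic: jina v2/v3 checkpoints that use rotary position embeddings
+        # (or declare a rotary base), and any checkpoint with LoRA adaptations,
+        # are converted with the JINA_BERT_V3 arch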
+        try:
+            text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
+            pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
+            rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
+            name_path = (hparams.get("_name_or_path") or "").lower()
+            is_vx = "jina" in name_path and ("v2" in name_path or "v3" in name_path)
+            is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
+            if is_v3 or self._lora_names:
+                self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+        except Exception:
+            pass
 
         super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()
@@ -6701,6 +6712,254 @@ def set_vocab(self):
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
 
 
+@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
+class JinaCLIPVisionModel(MmprojModel):
+    """JinaCLIP v2 Vision Encoder Model - handles vision component only"""
+    model_arch = gguf.MODEL_ARCH.MMPROJ
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Load config for vision encoder
+        config_path = self.dir_model / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(
+                f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
+                "Please ensure the original model config is present; default hyperparameter fallbacks are not used."
+            )
+        with open(config_path, encoding="utf-8") as f:
+            self.vision_config = json.load(f)
+
+    def set_vocab(self):
+        # Vision encoder doesn't need vocabulary
+        pass
+
+    def set_gguf_parameters(self):
+        cfg = self.vision_config
+
+        try:
+            width = int(cfg["width"])            # channel dim
+            head_width = int(cfg["head_width"])  # per-head dim
+            layers = int(cfg["layers"])          # block count
+            image_size = int(cfg["image_size"])  # input image size
+            patch_size = int(cfg["patch_size"])  # patch size
+        except KeyError as e:
+            raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
+
+        if width % head_width != 0:
+            raise ValueError(
+                f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
+            )
+        n_head = width // head_width
+
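+        # FFN size: either explicit via mlp_ratio, or the SwiGLU convention
+        # hidden = 2/3 * 4 * width, i.e. (width * 8) // 3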
+        if "mlp_ratio" in cfg:
+            n_ff = int(width * float(cfg["mlp_ratio"]))
+        elif bool(cfg.get("naive_swiglu", False)):
+            n_ff = int((width * 8) // 3)
+        else:
+            raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
+
+        self.gguf_writer.add_clip_has_vision_encoder(True)
+        proj_dim = int(cfg.get("projection_dim", width))
+        self.gguf_writer.add_vision_projection_dim(proj_dim)
+
+        self.gguf_writer.add_vision_image_size(image_size)
+        self.gguf_writer.add_vision_patch_size(patch_size)
+        self.gguf_writer.add_vision_embedding_length(width)
+        self.gguf_writer.add_vision_block_count(layers)
+        self.gguf_writer.add_vision_head_count(n_head)
+        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))
+
+        mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
+        std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
+        if mean is None or std is None:
+            raise KeyError(
+                "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
+            )
+        self.gguf_writer.add_vision_image_mean(mean)
+        self.gguf_writer.add_vision_image_std(std)
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
+        self.gguf_writer.add_vision_use_silu(True)
+
+    def _strip_vm_prefix(self, name: str) -> str:
+        return name[len('vision_model.'):] if name.startswith('vision_model.') else name
+
+    def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
+        parts = rest.split('.')
+        # layer norms
+        if rest.startswith('norm1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
+        if rest.startswith('norm2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
+        if rest.startswith('attn.inner_attn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
+
+        # fused qkv
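+        # the packed qkv tensor stacks q, k and v along dim 0;
+        # split it into equal thirds so each projection gets its own GGUF tensor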
+        if rest == 'attn.qkv.weight':
+            w = data_torch
+            wdim = w.shape[0]
+            if wdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
+            d = wdim // 3
+            q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
+            return [
+                (f'v.blk.{layer}.attn_q.weight', q),
+                (f'v.blk.{layer}.attn_k.weight', k),
+                (f'v.blk.{layer}.attn_v.weight', v),
+            ]
+        if rest == 'attn.qkv.bias':
+            b = data_torch
+            bdim = b.shape[0]
+            if bdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
+            d = bdim // 3
+            qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
+            return [
+                (f'v.blk.{layer}.attn_q.bias', qb),
+                (f'v.blk.{layer}.attn_k.bias', kb),
+                (f'v.blk.{layer}.attn_v.bias', vb),
+            ]
+        # separate q/v bias (some checkpoints)
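+        # e.g. EVA-style attention, which keeps bias terms only for q and v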
+        if rest == 'attn.q_bias':
+            return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
+        if rest == 'attn.v_bias':
+            return [(f'v.blk.{layer}.attn_v.bias', data_torch)]
+
+        # separate projections
+        if rest.startswith('attn.q_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
+        if rest.startswith('attn.k_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
+        if rest.startswith('attn.v_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
+        if rest.startswith('attn.proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]
+
+        # MLP
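+        # SwiGLU blocks use w1/w2/w3 (mapped to gate/up/down);
+        # plain MLP blocks use fc1/fc2 (mapped to up/down)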
+        if rest.startswith('mlp.w1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
+        if rest.startswith('mlp.w2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.w3.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        if rest.startswith('mlp.ffn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        return None
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        """Prefer the base table-driven mapping; keep Jina-specific targets if already mapped; fall back to the legacy mapper."""
+        # Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
+        if name.startswith('v.') or name.startswith('mm.'):
+            return name
+        # Try the base mapping first
+        try:
+            return super().map_tensor_name(name, try_suffixes=try_suffixes)
+        except Exception:
+            # Fall back to the legacy Jina-specific mapper for any remaining edge keys
+            if hasattr(self, "_map_jinaclip_tensor_name"):
+                mapped = self._map_jinaclip_tensor_name(name)  # type: ignore[attr-defined]
+                if mapped:
+                    return mapped
+            return name
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        yielded_any = False
+        try:
+            for name, tensor in super().get_tensors():
+                yielded_any = True
+                yield name, tensor
+        except Exception as e:
+            logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
+        if yielded_any:
+            return
+
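+        # fall back to reading the raw state dict from known checkpoint filenames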
+        candidates = [
+            self.dir_model / "pytorch_model.bin",
+            self.dir_model / "vision_model_weights.bin",
+        ]
+        model_path = next((p for p in candidates if p.exists()), None)
+        if model_path is None:
+            raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
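+        # torch.load's weights_only= kwarg is missing on older torch versions; retry without it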
+        try:
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+        except TypeError:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        for name, tensor in state_dict.items():
+            yield name, tensor
+
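+    # norm scales/biases (and the patch-embedding bias) are small and
+    # precision-sensitive, so they are kept in F32 when quantizing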
+    def _should_be_f32(self, gguf_name: str) -> bool:
+        patterns = (
+            ".ln1.weight", ".ln1.bias",
+            ".ln2.weight", ".ln2.bias",
+            ".attn_ln.weight", ".attn_ln.bias",
+            ".ffn_norm.weight", ".ffn_norm.bias",
+            "v.patch_embd.proj.bias",
+        )
+        return any(p in gguf_name for p in patterns)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        src = name
+        if src.startswith('v.') or src.startswith('mm.'):
+            return [(src, data_torch)]
+
+        # Drop 'vision_model.' prefix if present
+        src_no_vm = self._strip_vm_prefix(src)
+
+        # Top-level direct mappings: use gguf constants directly for canonical names
+        if src_no_vm == 'cls_token':
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
+            return [(base, data_torch)]
+        if src_no_vm.startswith('patch_embed.proj.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
+            return [(f'{base}.{suffix}', data_torch)]
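+        # the checkpoint stores a bare 'pos_embed' tensor; GGUF expects a '.weight' suffix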
+        if src_no_vm == 'pos_embed':
+            pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
+            return [(pos_name, data_torch)]
+        if src_no_vm.startswith('norm.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
+            return [(f'{base}.{suffix}', data_torch)]
+
+        if src_no_vm.startswith('blocks.'):
+            parts = src_no_vm.split('.')
+            if len(parts) >= 3 and parts[1].isdigit():
+                layer = int(parts[1])
+                rest = '.'.join(parts[2:])
+                mapped = self._map_block_tensor(layer, rest, data_torch, name)
+                if mapped is not None:
+                    return mapped
+
+        try:
+            return [(self.map_tensor_name(name), data_torch)]
+        except Exception:
+            logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
+            return []
+
+
 @ModelBase.register("OpenELMForCausalLM")
 class OpenELMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OPENELM