4 changes: 2 additions & 2 deletions common/arg.cpp
@@ -2290,14 +2290,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, int value) {
params.embd_normalize = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
[](common_params & params, const std::string & value) {
params.embd_out = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg(
{"--embd-separator"}, "STRING",
"separator of embeddings (default \\n) for example \"<#sep#>\"",
165 changes: 164 additions & 1 deletion convert_hf_to_gguf.py
@@ -1521,7 +1521,9 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
# Prefer explicit "layers" (e.g. JinaCLIP),
# keep legacy keys for other models.
n_block_keys = ["layers", "n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]

has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -5557,6 +5559,13 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,

if lora_names := hparams.get("lora_adaptations"):
self._lora_names = lora_names

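# Heuristic: use the JINA_BERT_V3 arch when the config reports rotary position
# embeddings (or a rotary_emb_base) on a Jina v2/v3 checkpoint, or when LoRA
# adaptations are present.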
pe_type = (hparams.get("position_embedding_type") or "").lower()
rope_base = hparams.get("rotary_emb_base")
name_path = (hparams.get("_name_or_path") or "").lower()
is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path))
is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
if is_v3 or self._lora_names:
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3

super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
@@ -6779,6 +6788,160 @@ def set_vocab(self):
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')


@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
class JinaCLIPVisionModel(MmprojModel):
"""JinaCLIP v2 Vision Encoder Model - handles vision component only"""
model_arch = gguf.MODEL_ARCH.MMPROJ

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Load config for vision encoder
config_path = self.dir_model / "config.json"
if not config_path.exists():
raise FileNotFoundError(
f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
"Please ensure the original model config is present; default hyperparameter fallbacks are not used."
)
with open(config_path, encoding="utf-8") as f:
self.vision_config = json.load(f)

def get_vision_config(self) -> dict[str, Any] | None:
# For JinaCLIPVisionModel, the top-level AutoConfig dict is already
# the vision-only configuration.
return self.global_config

def set_vocab(self):
# Vision encoder doesn't need vocabulary
pass

def set_gguf_parameters(self):
cfg = self.vision_config

try:
width = int(cfg["width"]) # channel dim
head_width = int(cfg["head_width"]) # per-head dim
layers = int(cfg["layers"]) # block count
image_size = int(cfg["image_size"]) # input image size
patch_size = int(cfg["patch_size"]) # patch size
except KeyError as e:
raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")

if width % head_width != 0:
raise ValueError(
f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
)
n_head = width // head_width

if "mlp_ratio" in cfg:
n_ff = int(width * float(cfg["mlp_ratio"]))
elif bool(cfg.get("naive_swiglu", False)):
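# naive SwiGLU: use the conventional ~8/3 * width hidden size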
n_ff = int((width * 8) // 3)
else:
raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")

self.gguf_writer.add_clip_has_vision_encoder(True)
proj_dim = int(cfg.get("projection_dim", width))
self.gguf_writer.add_vision_projection_dim(proj_dim)

self.gguf_writer.add_vision_image_size(image_size)
self.gguf_writer.add_vision_patch_size(patch_size)
self.gguf_writer.add_vision_embedding_length(width)
self.gguf_writer.add_vision_block_count(layers)
self.gguf_writer.add_vision_head_count(n_head)
self.gguf_writer.add_vision_feed_forward_length(n_ff)

self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))

mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
if mean is None or std is None:
raise KeyError(
"JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
)
self.gguf_writer.add_vision_image_mean(mean)
self.gguf_writer.add_vision_image_std(std)

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
self.gguf_writer.add_vision_use_silu(True)

def _strip_vm_prefix(self, name: str) -> str:
return name[len('vision_model.'):] if name.startswith('vision_model.') else name

def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
if name.startswith('v.') or name.startswith('mm.'):
return name
return super().map_tensor_name(name, try_suffixes=try_suffixes)

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
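# Prefer the base class weight-loading path; fall back to a manual
# torch.load only if it fails or yields nothing.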
yielded_any = False
try:
for name, tensor in super().get_tensors():
yielded_any = True
yield name, tensor
except Exception as e:
logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
if yielded_any:
return

candidates = [
self.dir_model / "pytorch_model.bin",
self.dir_model / "vision_model_weights.bin",
]
model_path = next((p for p in candidates if p.exists()), None)
if model_path is None:
raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
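# weights_only is not available on older torch versions; retry without it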
try:
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
except TypeError:
state_dict = torch.load(model_path, map_location="cpu")

for name, tensor in state_dict.items():
yield name, tensor

def _should_be_f32(self, gguf_name: str) -> bool:
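# LayerNorm weights/biases and the patch-embedding bias are kept in F32 for precision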
patterns = (
".ln1.weight", ".ln1.bias",
".ln2.weight", ".ln2.bias",
".attn_ln.weight", ".attn_ln.bias",
".ffn_norm.weight", ".ffn_norm.bias",
"v.patch_embd.proj.bias",
)
return any(p in gguf_name for p in patterns)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

src = name
if src.startswith('v.') or src.startswith('mm.'):
return [(src, data_torch)]

# Drop 'vision_model.' prefix if present
src_no_vm = self._strip_vm_prefix(src)

# Top-level direct mappings — use gguf constants directly for canonical names
if src_no_vm == 'cls_token':
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
return [(base, data_torch)]
if src_no_vm.startswith('patch_embed.proj.'):
suffix = src_no_vm.split('.')[-1]
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
return [(f'{base}.{suffix}', data_torch)]
if src_no_vm == 'pos_embed':
pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
return [(pos_name, data_torch)]
if src_no_vm.startswith('norm.'):
suffix = src_no_vm.split('.')[-1]
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
return [(f'{base}.{suffix}', data_torch)]

try:
return [(self.map_tensor_name(name), data_torch)]
except Exception:
logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
return []


@ModelBase.register("OpenELMForCausalLM")
class OpenELMModel(TextModel):
model_arch = gguf.MODEL_ARCH.OPENELM
13 changes: 13 additions & 0 deletions gguf-py/gguf/constants.py
@@ -634,9 +634,13 @@ class MODEL_TENSOR(IntEnum):
V_ENC_ATTN_O = auto()
V_ENC_ATTN_O_NORM = auto()
V_ENC_POST_ATTN_NORM = auto()
V_ENC_ATTN_LN = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
V_ENC_FFN_NORM = auto()
V_ENC_ATTN_Q_BIAS = auto()
V_ENC_ATTN_V_BIAS = auto()
V_LAYER_SCALE_1 = auto()
V_LAYER_SCALE_2 = auto()
V_PRE_NORM = auto()
@@ -1002,9 +1006,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_ATTN_LN: "v.blk.{bid}.attn_ln",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_ENC_FFN_NORM: "v.blk.{bid}.ffn_norm",
MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: "v.blk.{bid}.attn_q.bias",
MODEL_TENSOR.V_ENC_ATTN_V_BIAS: "v.blk.{bid}.attn_v.bias",
MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1",
MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
@@ -1080,9 +1088,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_ENC_ATTN_O,
MODEL_TENSOR.V_ENC_ATTN_O_NORM,
MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
MODEL_TENSOR.V_ENC_ATTN_LN,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_ENC_FFN_NORM,
MODEL_TENSOR.V_ENC_ATTN_Q_BIAS,
MODEL_TENSOR.V_ENC_ATTN_V_BIAS,
MODEL_TENSOR.V_LAYER_SCALE_1,
MODEL_TENSOR.V_LAYER_SCALE_2,
MODEL_TENSOR.V_PRE_NORM,
@@ -3230,6 +3242,7 @@ class VisionProjectorType:
QWEN3VL = "qwen3vl_merger"
ULTRAVOX = "ultravox"
INTERNVL = "internvl"
JINACLIP2 = "jinaclip2"
QWEN2A = "qwen2a" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
25 changes: 25 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1243,6 +1243,7 @@ class TensorNameMap:
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
"blocks.{bid}.attn.q_proj", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1260,6 +1261,7 @@
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
"blocks.{bid}.attn.k_proj", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1277,6 +1279,7 @@
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
"vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
"blocks.{bid}.attn.v_proj", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1291,6 +1294,7 @@
"visual.blocks.{bid}.norm1", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
"blocks.{bid}.norm1", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1306,6 +1310,7 @@
"visual.blocks.{bid}.attn.proj", # qwen2vl
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
"blocks.{bid}.attn.proj", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1320,6 +1325,11 @@
"visual.blocks.{bid}.norm2", # qwen2vl
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
"blocks.{bid}.norm2", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_ATTN_LN: (
"blocks.{bid}.attn.inner_attn_ln", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1335,12 +1345,14 @@
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
"blocks.{bid}.mlp.w2", # JinaCLIP v2 vision (up)
),

MODEL_TENSOR.V_ENC_FFN_GATE: (
"vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.feed_forward.w1", # pixtral
"visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
"blocks.{bid}.mlp.w1", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_FFN_DOWN: (
@@ -1356,6 +1368,11 @@
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"blocks.{bid}.mlp.w3", # JinaCLIP v2 vision (down)
),

MODEL_TENSOR.V_ENC_FFN_NORM: (
"blocks.{bid}.mlp.ffn_ln", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1368,6 +1385,14 @@
"model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1
),

MODEL_TENSOR.V_ENC_ATTN_Q_BIAS: (
"blocks.{bid}.attn.q_bias", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_ENC_ATTN_V_BIAS: (
"blocks.{bid}.attn.v_bias", # JinaCLIP v2 vision
),

MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral-hf
10 changes: 7 additions & 3 deletions tools/mtmd/clip-impl.h
@@ -40,6 +40,7 @@
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
#define KEY_VISION_ROPE_THETA "clip.vision.rope_theta"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -69,14 +70,15 @@
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" // inner attention LayerNorm
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
#define TN_LN_PRE "%s.pre_ln.%s"
@@ -151,6 +153,7 @@ enum projector_type {
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_LIGHTONOCR,
@@ -180,6 +183,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_JINACLIP2, "jinaclip2"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
};