From 7ea1e821b7a9cd723514995a5331cd48d73c9600 Mon Sep 17 00:00:00 2001 From: liyang Date: Fri, 31 Oct 2025 16:28:58 +0800 Subject: [PATCH 1/4] address #16574; fold CLI into mtmd-cli; use ggml_rope_ext + bicubic;switch to 'jinaclip2'; fix converter constants --- common/arg.cpp | 4 +- convert_hf_to_gguf.py | 261 ++++++++++++++++++++++- gguf-py/gguf/constants.py | 1 + tools/mtmd/clip-impl.h | 10 +- tools/mtmd/clip.cpp | 423 +++++++++++++++++++++++++++++++++++-- tools/mtmd/clip.h | 1 + tools/mtmd/mtmd-cli.cpp | 113 +++++++++- tools/mtmd/mtmd-helper.cpp | 13 ++ tools/mtmd/mtmd-helper.h | 4 + tools/mtmd/mtmd.cpp | 107 ++++++++++ tools/mtmd/mtmd.h | 26 +++ 11 files changed, 939 insertions(+), 24 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe26e..05c15f0d625e5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2290,14 +2290,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD})); add_opt(common_arg( {"--embd-separator"}, "STRING", "separator of embeddings (default \\n) for example \"<#sep#>\"", diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2b08013e1e457..08d3a93d63e42 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5557,7 +5557,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, if lora_names := hparams.get("lora_adaptations"): self._lora_names = lora_names - self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3 + + try: + text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {} + pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower() + rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base")) + name_path = (hparams.get("_name_or_path") or "").lower() + is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path)) + is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx + if (is_v3) or self._lora_names: + self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3 + except Exception: + pass super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) self._xlmroberta_tokenizer_init() @@ -6779,6 +6790,254 @@ def set_vocab(self): raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') +@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel") +class JinaCLIPVisionModel(MmprojModel): + """JinaCLIP v2 Vision Encoder Model - handles vision component only""" + model_arch = gguf.MODEL_ARCH.MMPROJ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Load config for vision encoder + config_path = self.dir_model / "config.json" + if not config_path.exists(): + raise FileNotFoundError( + f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. " + "Please ensure the original model config is present; default hyperparameter fallbacks are not used." 
+ ) + with open(config_path, encoding="utf-8") as f: + self.vision_config = json.load(f) + + def set_vocab(self): + # Vision encoder doesn't need vocabulary + pass + + def set_gguf_parameters(self): + cfg = self.vision_config + + try: + width = int(cfg["width"]) # channel dim + head_width = int(cfg["head_width"]) # per-head dim + layers = int(cfg["layers"]) # block count + image_size = int(cfg["image_size"]) # input image size + patch_size = int(cfg["patch_size"]) # patch size + except KeyError as e: + raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}") + + if width % head_width != 0: + raise ValueError( + f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})" + ) + n_head = width // head_width + + if "mlp_ratio" in cfg: + n_ff = int(width * float(cfg["mlp_ratio"])) + elif bool(cfg.get("naive_swiglu", False)): + n_ff = int((width * 8) // 3) + else: + raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json") + + self.gguf_writer.add_clip_has_vision_encoder(True) + proj_dim = int(cfg.get("projection_dim", width)) + self.gguf_writer.add_vision_projection_dim(proj_dim) + + self.gguf_writer.add_vision_image_size(image_size) + self.gguf_writer.add_vision_patch_size(patch_size) + self.gguf_writer.add_vision_embedding_length(width) + self.gguf_writer.add_vision_block_count(layers) + self.gguf_writer.add_vision_head_count(n_head) + self.gguf_writer.add_vision_feed_forward_length(n_ff) + + self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5))) + + mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean")) + std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std")) + if mean is None or std is None: + raise KeyError( + "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')" + ) + self.gguf_writer.add_vision_image_mean(mean) + self.gguf_writer.add_vision_image_std(std) + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2) + self.gguf_writer.add_vision_use_silu(True) + + def _strip_vm_prefix(self, name: str) -> str: + return name[len('vision_model.'):] if name.startswith('vision_model.') else name + + def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None: + parts = rest.split('.') + # layer norms + if rest.startswith('norm1.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)] + if rest.startswith('norm2.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)] + if rest.startswith('attn.inner_attn_ln.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)] + + # fused qkv + if rest == 'attn.qkv.weight': + w = data_torch + wdim = w.shape[0] + if wdim % 3 != 0: + logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name) + d = wdim // 3 + q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :] + return [ + (f'v.blk.{layer}.attn_q.weight', q), + (f'v.blk.{layer}.attn_k.weight', k), + (f'v.blk.{layer}.attn_v.weight', v), + ] + if rest == 'attn.qkv.bias': + b = data_torch + bdim = b.shape[0] + if bdim % 3 != 0: + logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name) + d = bdim // 3 + qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:] + return [ + (f'v.blk.{layer}.attn_q.bias', qb), + 
(f'v.blk.{layer}.attn_k.bias', kb), + (f'v.blk.{layer}.attn_v.bias', vb), + ] + # separate q/v bias (some checkpoints) + if rest == 'attn.q_bias': + return [(f'v.blk.{layer}.attn_q.bias', data_torch)] + if rest == 'attn.v_bias': + return [(f'v.blk.{layer}.attn_v.bias', data_torch)] + + # separate projections + if rest.startswith('attn.q_proj.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)] + if rest.startswith('attn.k_proj.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)] + if rest.startswith('attn.v_proj.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)] + if rest.startswith('attn.proj.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)] + + # MLP + if rest.startswith('mlp.w1.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)] + if rest.startswith('mlp.w2.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)] + if rest.startswith('mlp.w3.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)] + if rest.startswith('mlp.ffn_ln.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)] + if rest.startswith('mlp.fc1.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)] + if rest.startswith('mlp.fc2.'): + suffix = parts[-1] + return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)] + return None + + def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: + """Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper.""" + # Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is + if name.startswith('v.') or name.startswith('mm.'): + return name + # Try the base mapping first + try: + return super().map_tensor_name(name, try_suffixes=try_suffixes) + except Exception: + # Fallback to legacy Jina-specific mapper for any remaining edge keys + if hasattr(self, "_map_jinaclip_tensor_name"): + mapped = self._map_jinaclip_tensor_name(name) # type: ignore[attr-defined] + if mapped: + return mapped + return name + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + yielded_any = False + try: + for name, tensor in super().get_tensors(): + yielded_any = True + yield name, tensor + except Exception as e: + logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e) + if yielded_any: + return + + candidates = [ + self.dir_model / "pytorch_model.bin", + self.dir_model / "vision_model_weights.bin", + ] + model_path = next((p for p in candidates if p.exists()), None) + if model_path is None: + raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}") + try: + state_dict = torch.load(model_path, map_location="cpu", weights_only=True) + except TypeError: + state_dict = torch.load(model_path, map_location="cpu") + + for name, tensor in state_dict.items(): + yield name, tensor + + def _should_be_f32(self, gguf_name: str) -> bool: + patterns = ( + ".ln1.weight", ".ln1.bias", + ".ln2.weight", ".ln2.bias", + ".attn_ln.weight", ".attn_ln.bias", + ".ffn_norm.weight", ".ffn_norm.bias", + "v.patch_embd.proj.bias", + ) + return any(p in gguf_name for p in patterns) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + src = name + if src.startswith('v.') or src.startswith('mm.'): + return [(src, data_torch)] + + # 
Drop 'vision_model.' prefix if present + src_no_vm = self._strip_vm_prefix(src) + + # Top-level direct mappings — use gguf constants directly for canonical names + if src_no_vm == 'cls_token': + base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS] + return [(base, data_torch)] + if src_no_vm.startswith('patch_embed.proj.'): + suffix = src_no_vm.split('.')[-1] + base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + return [(f'{base}.{suffix}', data_torch)] + if src_no_vm == 'pos_embed': + pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight' + return [(pos_name, data_torch)] + if src_no_vm.startswith('norm.'): + suffix = src_no_vm.split('.')[-1] + base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM] + return [(f'{base}.{suffix}', data_torch)] + + if src_no_vm.startswith('blocks.'): + parts = src_no_vm.split('.') + if len(parts) >= 3 and parts[1].isdigit(): + layer = int(parts[1]) + rest = '.'.join(parts[2:]) + mapped = self._map_block_tensor(layer, rest, data_torch, name) + if mapped is not None: + return mapped + + try: + return [(self.map_tensor_name(name), data_torch)] + except Exception: + logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name) + return [] + + @ModelBase.register("OpenELMForCausalLM") class OpenELMModel(TextModel): model_arch = gguf.MODEL_ARCH.OPENELM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1cd0efad4a8f1..c38d1d11898d1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3230,6 +3230,7 @@ class VisionProjectorType: QWEN3VL = "qwen3vl_merger" ULTRAVOX = "ultravox" INTERNVL = "internvl" + JINACLIP2 = "jinaclip2" QWEN2A = "qwen2a" # audio QWEN25O = "qwen2.5o" # omni VOXTRAL = "voxtral" diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index cd47865bf4a78..d76602b50b763 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -40,6 +40,7 @@ #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" +#define KEY_VISION_ROPE_THETA "clip.vision.rope_theta" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -69,14 +70,15 @@ #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" +#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" // inner attention LayerNorm #define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s" #define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s" -#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm -#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm +#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" +#define TN_LN_2 "%s.blk.%d.ln2.%s" #define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale #define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale #define TN_LN_PRE "%s.pre_ln.%s" @@ -151,6 +153,7 @@ enum projector_type { PROJECTOR_TYPE_QWEN2A, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, + PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2 PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_LIGHTONOCR, @@ -180,6 +183,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, 
"kimivl"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, + { PROJECTOR_TYPE_JINACLIP2, "jinaclip2"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, }; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index abdb778f7afb8..45df8d9f8c156 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -237,6 +237,9 @@ struct clip_layer { ggml_tensor * o_w = nullptr; ggml_tensor * o_b = nullptr; + ggml_tensor * attn_ln_w = nullptr; + ggml_tensor * attn_ln_b = nullptr; + ggml_tensor * k_norm = nullptr; ggml_tensor * q_norm = nullptr; @@ -251,6 +254,9 @@ struct clip_layer { ggml_tensor * ff_down_w = nullptr; ggml_tensor * ff_down_b = nullptr; + ggml_tensor * ffn_norm_w = nullptr; + ggml_tensor * ffn_norm_b = nullptr; + // layernorm 2 ggml_tensor * ln_2_w = nullptr; ggml_tensor * ln_2_b = nullptr; @@ -1788,6 +1794,157 @@ struct clip_graph { return gf; } + + ggml_cgraph * build_jina2() { + const int n_pos = n_patches + (model.class_embedding ? 1 : 0); + + GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported"); + + int max_feature_layer = n_layer; + + ggml_tensor * inp = build_inp(); + + if (ctx->proj_type() == PROJECTOR_TYPE_JINACLIP2) { + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_h); + ggml_set_input(pos_w); + ggml_build_forward_expand(gf, pos_h); + ggml_build_forward_expand(gf, pos_w); + + ggml_tensor * rope_c_first = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head/2); + ggml_tensor * rope_c_second = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_head/2); + ggml_set_name(rope_c_first, "rope_c_first"); + ggml_set_name(rope_c_second, "rope_c_second"); + ggml_set_input(rope_c_first); + ggml_set_input(rope_c_second); + ggml_build_forward_expand(gf, rope_c_first); + ggml_build_forward_expand(gf, rope_c_second); + + } + if (model.class_embedding) { + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + } + + ggml_tensor * positions = nullptr; + + positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions)); + + ggml_tensor * inpL = inp; + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1); + } + + // loop over layers + for (int il = 0; il < max_feature_layer; il++) { + auto & layer = model.layers[il]; + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il); + + { + ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b) { + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + } + + ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b) { + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + } + + ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b) { + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + } + Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, 1); + Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head, n_pos, 1); + Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head, n_pos, 1); + + ggml_tensor * Q_rope_in = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + ggml_tensor * K_rope_in = ggml_permute(ctx0, Kcur, 0, 2, 1, 3); + + // Apply 2D RoPE position encoding for JinaCLIP (skip CLS). 
+                ggml_tensor * Q_cls =
+                    ggml_view_3d(ctx0, Q_rope_in, d_head, 1, n_head, Q_rope_in->nb[1], Q_rope_in->nb[2], 0);
+                ggml_tensor * Q_patches = ggml_view_3d(ctx0, Q_rope_in, d_head, n_pos - 1, n_head, Q_rope_in->nb[1],
+                                                       Q_rope_in->nb[2], Q_rope_in->nb[1]);
+
+                // Split K: CLS token (pos 0) + patch tokens (pos 1+)
+                ggml_tensor * K_cls =
+                    ggml_view_3d(ctx0, K_rope_in, d_head, 1, n_head, K_rope_in->nb[1], K_rope_in->nb[2], 0);
+                ggml_tensor * K_patches = ggml_view_3d(ctx0, K_rope_in, d_head, n_pos - 1, n_head, K_rope_in->nb[1],
+                                                       K_rope_in->nb[2], K_rope_in->nb[1]);
+
+                int pt_seq_len = 16; // fallback pretrain length
+                if (hparams.patch_size > 0) {
+                    int cand = (int) llroundf(224.0f / (float) hparams.patch_size);
+                    if (cand > 0) {
+                        pt_seq_len = cand;
+                    }
+                }
+                const int hw_seq_len = static_cast<int>(sqrtf(n_pos - 1)); // image grid size (excluding CLS)
+                Q_patches = build_jinaclip_rope(ctx0, ctx, Q_patches, pt_seq_len, hw_seq_len, hparams.rope_theta, true,
+                                                true, il);
+                K_patches = build_jinaclip_rope(ctx0, ctx, K_patches, pt_seq_len, hw_seq_len, hparams.rope_theta, true,
+                                                false, il);
+                GGML_ASSERT(Q_cls->ne[0] == Q_patches->ne[0]);
+                GGML_ASSERT(Q_cls->ne[2] == Q_patches->ne[2]);
+                GGML_ASSERT(Q_cls->ne[3] == Q_patches->ne[3]);
+
+                // Recombine: CLS token + RoPE-processed patch tokens (seq dimension is 1 now)
+                ggml_tensor * Q_rope_out = ggml_concat(ctx0, Q_cls, Q_patches, 1);
+                ggml_tensor * K_rope_out = ggml_concat(ctx0, K_cls, K_patches, 1);
+
+                Qcur = ggml_permute(ctx0, Q_rope_out, 0, 2, 1, 3);
+                Kcur = ggml_permute(ctx0, K_rope_out, 0, 2, 1, 3);
+
+                cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, nullptr,
+                                 kq_scale, il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+            cb(inpL, "inp_after_attn", il);
+
+            inpL = cur;
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+            // ffn
+            cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, layer.ff_gate_w, layer.ff_gate_b, layer.ff_down_w,
+                            layer.ff_down_b, hparams.ffn_op, il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
+        }
+        // final
+
+        ggml_tensor * emb2d = ggml_reshape_2d(ctx0, inpL, inpL->ne[0], inpL->ne[1]);
+        ggml_tensor * cls = ggml_view_2d(ctx0, emb2d, emb2d->ne[0], /*rows=*/1, emb2d->nb[1], /*offset=*/0);
+        ggml_set_name(cls, "cls_view");
+        ggml_build_forward_expand(gf, cls);
+        return gf;
+    }
+
    // whisper encoder with custom projector
    ggml_cgraph * build_whisper_enc() {
        const int n_frames = img.nx;
@@ -2215,15 +2372,15 @@ struct clip_graph {
    }

    ggml_tensor * build_ffn(
-            ggml_tensor * cur,
-            ggml_tensor * up,
-            ggml_tensor * up_b,
-            ggml_tensor * gate,
-            ggml_tensor * gate_b,
-            ggml_tensor * down,
-            ggml_tensor * down_b,
-            ffn_op_type type_op,
-            int il) const {
+        ggml_tensor * cur,
+        ggml_tensor * up,
+        ggml_tensor * up_b,
+        ggml_tensor * gate,
+        ggml_tensor * gate_b,
+        ggml_tensor * down,
+        ggml_tensor * down_b,
+        ffn_op_type type_op,
+        int il) const {
        ggml_tensor * tmp = up ?
ggml_mul_mat(ctx0, up, cur) : cur; cb(tmp, "ffn_up", il); @@ -2281,6 +2438,18 @@ struct clip_graph { } break; } + ggml_tensor * ffn_norm_w = nullptr; + ggml_tensor * ffn_norm_b = nullptr; + if (il >= 0 && il < (int) model.layers.size()) { + auto & layer = model.layers[il]; + ffn_norm_w = layer.ffn_norm_w; + ffn_norm_b = layer.ffn_norm_b; + } + if (ffn_norm_w || ffn_norm_b) { + cur = build_norm(cur, ffn_norm_w, ffn_norm_b, NORM_TYPE_NORMAL, eps, il); + cb(cur, "ffn_norm", il); + } + if (down) { cur = ggml_mul_mat(ctx0, down, cur); } @@ -2291,20 +2460,21 @@ struct clip_graph { if (down_b) { cur = ggml_add(ctx0, cur, down_b); + cb(cur, "ffn_down_b", il); } return cur; } ggml_tensor * build_attn( - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_mask, - float kq_scale, - int il) const { + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_mask, + float kq_scale, + int il) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2350,6 +2520,26 @@ struct clip_graph { cb(cur, "kqv_out", il); + ggml_tensor * attn_ln_w = nullptr; + ggml_tensor * attn_ln_b = nullptr; + if (il >= 0 && il < (int) model.layers.size()) { + auto & layer = model.layers[il]; + attn_ln_w = layer.attn_ln_w; + attn_ln_b = layer.attn_ln_b; + } + if (attn_ln_w && attn_ln_b) { + ggml_tensor * attn_ln_w_f32 = attn_ln_w; + ggml_tensor * attn_ln_b_f32 = attn_ln_b; + + if (attn_ln_w->type == GGML_TYPE_F16) { + attn_ln_w_f32 = ggml_cast(ctx0, attn_ln_w, GGML_TYPE_F32); + } + if (attn_ln_b->type == GGML_TYPE_F16) { + attn_ln_b_f32 = ggml_cast(ctx0, attn_ln_b, GGML_TYPE_F32); + } + cur = build_norm(cur, attn_ln_w_f32, attn_ln_b_f32, NORM_TYPE_NORMAL, hparams.eps, il); + } + if (wo) { cur = ggml_mul_mat(ctx0, wo, cur); } @@ -2361,6 +2551,87 @@ struct clip_graph { return cur; } + + ggml_tensor * build_jinaclip_rope(ggml_context * ctx0, + clip_ctx * /*ctx*/, + ggml_tensor * cur, + const int pt_seq_len, + const int ft_seq_len, + const float freq_base, + const bool has_cls_token = true, + const bool /*if_query*/ = true, + const int /*layer_id*/ = -1) { + (void) pt_seq_len; + (void) ft_seq_len; + const int64_t n_dim = cur->ne[0]; + const int64_t n_pos_patches = cur->ne[1]; + const int64_t n_head = cur->ne[2]; + + GGML_ASSERT(n_dim % 2 == 0); + const int64_t half = n_dim/2; + + ggml_tensor * pos_h_full = ggml_graph_get_tensor(gf, "pos_h"); + ggml_tensor * pos_w_full = ggml_graph_get_tensor(gf, "pos_w"); + GGML_ASSERT(pos_h_full && pos_w_full); + + const int64_t offset = has_cls_token ? 
1 : 0; + ggml_tensor * pos_h = ggml_view_1d(ctx0, pos_h_full, n_pos_patches, offset * (int64_t)ggml_element_size(pos_h_full)); + ggml_tensor * pos_w = ggml_view_1d(ctx0, pos_w_full, n_pos_patches, offset * (int64_t)ggml_element_size(pos_w_full)); + ggml_tensor * pos_a = pos_h; + ggml_tensor * pos_b = pos_w; + + + ggml_tensor * first = ggml_view_3d(ctx0, cur, + half, n_head, n_pos_patches, + /*nb1 for head*/ cur->nb[2], + /*nb2 for seq */ cur->nb[1], + 0); + ggml_tensor * c_first = ggml_graph_get_tensor(gf, "rope_c_first"); + ggml_tensor * c_second = ggml_graph_get_tensor(gf, "rope_c_second"); + GGML_ASSERT(c_first && c_second); + + ggml_tensor * first_rot = ggml_rope_ext( + ctx0, + first, + pos_a, + c_first, + half, + 0, + 0, + freq_base, + 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + first = ggml_view_3d(ctx0, first_rot, + half, n_pos_patches, n_head, + /*nb1 for seq */ first_rot->nb[2], + /*nb2 for head*/ first_rot->nb[1], + 0); + + ggml_tensor * second_hs = ggml_view_3d(ctx0, cur, + half, n_head, n_pos_patches, + /*nb1 for head*/ cur->nb[2], + /*nb2 for seq */ cur->nb[1], + /*offset*/ half * ggml_element_size(cur)); + ggml_tensor * second_rot = ggml_rope_ext( + ctx0, + second_hs, + pos_b, + c_second, + half, + 0, + 0, + freq_base, + 1.0f, + 0.0f, 1.0f, 0.0f, 0.0f); + ggml_tensor * second = ggml_view_3d(ctx0, second_rot, + half, n_pos_patches, n_head, + second_rot->nb[2], + second_rot->nb[1], + 0); + ggml_tensor * result = ggml_concat(ctx0, first, second, 0); + return result; + } + // implementation of the 2D RoPE without adding a new op in ggml // this is not efficient (use double the memory), but works on all backends // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065 @@ -2511,6 +2782,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_whisper_enc(); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + res = graph.build_jina2(); + } break; case PROJECTOR_TYPE_KIMIVL: { res = graph.build_kimivl(); @@ -2838,6 +3113,11 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); set_llava_uhd_res_candidates(model, 3); } break; + case PROJECTOR_TYPE_JINACLIP2: + { + hparams.rope_theta = 10000.0f; + get_f32(KEY_VISION_ROPE_THETA, hparams.rope_theta, /*required=*/false); + } break; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_VOXTRAL: @@ -2971,6 +3251,10 @@ struct clip_model_loader { layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false); layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false); layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false); + + layer.attn_ln_w = get_tensor(string_format(TN_ATTN_LN, prefix, il, "weight"), false); + layer.attn_ln_b = get_tensor(string_format(TN_ATTN_LN, prefix, il, "bias"), false); + layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false); @@ -2982,6 +3266,8 @@ struct clip_model_loader { layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight")); layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false); + layer.ffn_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight"), false); + layer.ffn_norm_b = 
get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias"), false);

                // qwen3vl deepstack layer
                layer.deepstack_norm_w = get_tensor(string_format(TN_DEEPSTACK_NORM, il, "weight"), false);
@@ -3238,6 +3524,12 @@ struct clip_model_loader {
                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
                } break;
+            case PROJECTOR_TYPE_JINACLIP2:
+                {
+                    // JinaCLIP is a pure vision encoder without separate projection layers
+                    // It only uses patch embedding projections
+                    // No additional mm projection tensors are loaded for JinaCLIP2
+                } break;
            default:
                GGML_ASSERT(false && "unknown projector type");
        }
@@ -4344,6 +4636,44 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                res_imgs->grid_y = inst.grid_size.height;
            } break;
+        case PROJECTOR_TYPE_JINACLIP2:
+            {
+                clip_image_u8 processed_image;
+                int sz = params.image_size;
+
+                // 1) Preserve aspect ratio: resize so that the shorter side == sz (bicubic)
+                int in_w = img->nx;
+                int in_h = img->ny;
+                int out_w, out_h;
+                if (in_w <= 0 || in_h <= 0) {
+                    LOG_ERR("%s: invalid input image size %dx%d\n", __func__, in_w, in_h);
+                    return false;
+                }
+                if (in_w < in_h) {
+                    out_w = sz;
+                    out_h = std::max(1, (int) std::round((double) in_h * sz / in_w));
+                } else {
+                    out_h = sz;
+                    out_w = std::max(1, (int) std::round((double) in_w * sz / in_h));
+                }
+
+                clip_image_u8 resized_keep_ratio;
+                img_tool::resize(*img, resized_keep_ratio, clip_image_size{out_w, out_h}, img_tool::RESIZE_ALGO_BICUBIC);
+
+                // 2) Center-crop to sz x sz
+                int x0 = std::max(0, (resized_keep_ratio.nx - sz) / 2);
+                int y0 = std::max(0, (resized_keep_ratio.ny - sz) / 2);
+                int crop_w = std::min(sz, resized_keep_ratio.nx);
+                int crop_h = std::min(sz, resized_keep_ratio.ny);
+
+                img_tool::crop(resized_keep_ratio, processed_image, x0, y0, crop_w, crop_h);
+
+                // 3) Normalize
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(processed_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
+
        case PROJECTOR_TYPE_LFM2:
        case PROJECTOR_TYPE_KIMIVL:
            {
@@ -4490,6 +4820,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
            {
                // do nothing
            } break;
+        case PROJECTOR_TYPE_JINACLIP2:
+            {
+                n_patches = 1;
+            } break;
        case PROJECTOR_TYPE_LDP:
        case PROJECTOR_TYPE_LDPV2:
        case PROJECTOR_TYPE_GLM_EDGE:
@@ -4888,6 +5222,55 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                }
                set_input_i32("positions", positions);
            } break;
+        case PROJECTOR_TYPE_JINACLIP2:
+            {
+                std::vector<int32_t> positions(n_pos);
+                for (int i = 0; i < n_pos; i++) {
+                    positions[i] = i;
+                }
+                set_input_i32("positions", positions);
+
+                const int n_patches = model.class_embedding ? (n_pos - 1) : n_pos;
+                const int n_patches_per_col = image_size_width / patch_size;
+                std::vector<int32_t> pos_data(n_pos, 0);
+
+                for (int i = 0; i < n_patches; ++i) {
+                    int idx = model.class_embedding ? (i + 1) : i;
+                    pos_data[idx] = i / n_patches_per_col;
+                }
+                set_input_i32("pos_h", pos_data);
+                std::fill(pos_data.begin(), pos_data.end(), 0);
+
+                for (int i = 0; i < n_patches; ++i) {
+                    int idx = model.class_embedding ? (i + 1) : i;
+                    pos_data[idx] = i % n_patches_per_col;
+                }
+                set_input_i32("pos_w", pos_data);
+
+                int pt_seq_len = 16;
+                if (patch_size > 0) {
+                    int cand = (int) llroundf(224.0f / (float) patch_size);
+                    if (cand > 0) pt_seq_len = cand;
+                }
+                float s = (float) pt_seq_len / (float) n_patches_per_col;
+                int d_head_local = hparams.n_embd / hparams.n_head;
+                int half_local = d_head_local/2;
+                std::vector<float> rope_c_first(half_local);
+                std::vector<float> rope_c_second(half_local);
+                float odd = std::pow(hparams.rope_theta, (float)-2.0f / (float)d_head_local);
+                for (int k = 0; k < half_local; ++k) {
+                    rope_c_first[k] = 1.0f / s;
+                    rope_c_second[k] = 1.0f / (s * odd);
+                }
+
+                auto t1 = ggml_graph_get_tensor(gf, "rope_c_first");
+                auto t2 = ggml_graph_get_tensor(gf, "rope_c_second");
+                GGML_ASSERT(t1 && (t1->flags & GGML_TENSOR_FLAG_INPUT));
+                GGML_ASSERT(t2 && (t2->flags & GGML_TENSOR_FLAG_INPUT));
+                ggml_backend_tensor_set(t1, rope_c_first.data(), 0, ggml_nbytes(t1));
+                ggml_backend_tensor_set(t2, rope_c_second.data(), 0, ggml_nbytes(t2));
+
+            } break;
        case PROJECTOR_TYPE_MLP:
        case PROJECTOR_TYPE_MLP_NORM:
        case PROJECTOR_TYPE_LDP:
@@ -4998,6 +5381,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
        case PROJECTOR_TYPE_PIXTRAL:
        case PROJECTOR_TYPE_LIGHTONOCR:
            return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_JINACLIP2:
+            return ctx->model.hparams.projection_dim;
        case PROJECTOR_TYPE_MLP_NORM:
            return ctx->model.mm_3_b->ne[0];
        case PROJECTOR_TYPE_MINICPMV:
@@ -5059,6 +5444,10 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
    return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
}

+bool clip_is_jinaclip2(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_JINACLIP2;
+}
+
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
    return ctx->model.modality == CLIP_MODALITY_VISION;
}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index c1442afe6b252..69648a0b0047c 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -104,6 +104,7 @@ bool clip_is_glm(const struct clip_ctx * ctx);
bool clip_is_qwen2vl(const struct clip_ctx * ctx);
bool clip_is_llava(const struct clip_ctx * ctx);
bool clip_is_gemma3(const struct clip_ctx * ctx);
+bool clip_is_jinaclip2(const struct clip_ctx * ctx);

bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index bd20aad947e92..d7038e2548a71 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -39,7 +39,12 @@ static void show_additional_info(int /*argc*/, char ** argv) {
    LOG(
        "Experimental CLI for multimodal\n\n"
        "Usage: %s [options] -m --mmproj --image --audio