@@ -545,9 +545,9 @@ class CLIPEmbeddings : public GGMLBlock {
545545 int64_t vocab_size;
546546 int64_t num_positions;
547547
548- void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
549- enum ggml_type token_wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32;
550- enum ggml_type position_wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
548+ void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
549+ enum ggml_type token_wtype = GGML_TYPE_F32;
550+ enum ggml_type position_wtype = GGML_TYPE_F32;
551551
552552 params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
553553 params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
@@ -594,10 +594,10 @@ class CLIPVisionEmbeddings : public GGMLBlock {
594594 int64_t image_size;
595595 int64_t num_patches;
596596 int64_t num_positions;
597- void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
598- enum ggml_type patch_wtype = GGML_TYPE_F16; // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16;
599- enum ggml_type class_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32;
600- enum ggml_type position_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32;
597+ void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
598+ enum ggml_type patch_wtype = GGML_TYPE_F16;
599+ enum ggml_type class_wtype = GGML_TYPE_F32;
600+ enum ggml_type position_wtype = GGML_TYPE_F32;
601601
602602 params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim);
603603 params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim);
@@ -657,9 +657,9 @@ enum CLIPVersion {
657657
658658class CLIPTextModel : public GGMLBlock {
659659protected:
660- void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
660+ void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
661661 if (version == OPEN_CLIP_VIT_BIGG_14) {
662- enum ggml_type wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32;
662+ enum ggml_type wtype = GGML_TYPE_F32;
663663 params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size);
664664 }
665665 }
@@ -805,8 +805,8 @@ class CLIPProjection : public UnaryBlock {
805805 int64_t out_features;
806806 bool transpose_weight;
807807
808- void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
809- enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32;
808+ void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
809+ enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
810810 if (transpose_weight) {
811811 params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
812812 } else {
@@ -868,7 +868,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
868868 CLIPTextModel model;
869869
870870 CLIPTextModelRunner(ggml_backend_t backend,
871- std::map<std::string, enum ggml_type>& tensor_types,
871+ const String2GGMLType& tensor_types,
872872 const std::string prefix,
873873 CLIPVersion version = OPENAI_CLIP_VIT_L_14,
874874 bool with_final_ln = true,
0 commit comments