@@ -64,7 +64,7 @@ struct FuseModule : public GGMLBlock {
6464 auto prompt_embeds0 = ggml_cont (ctx, ggml_permute (ctx, prompt_embeds, 2 , 0 , 1 , 3 ));
6565 auto id_embeds0 = ggml_cont (ctx, ggml_permute (ctx, id_embeds, 2 , 0 , 1 , 3 ));
6666 // concat is along dim 2
67- auto stacked_id_embeds = ggml_concat (ctx, prompt_embeds0, id_embeds0);
67+ auto stacked_id_embeds = ggml_concat (ctx, prompt_embeds0, id_embeds0, 2 );
6868 stacked_id_embeds = ggml_cont (ctx, ggml_permute (ctx, stacked_id_embeds, 1 , 2 , 0 , 3 ));
6969
7070 // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds);
@@ -102,12 +102,12 @@ struct FuseModule : public GGMLBlock {
102102
103103 stacked_id_embeds = ggml_cont (ctx, ggml_permute (ctx, stacked_id_embeds, 0 , 2 , 1 , 3 ));
104104 if (left && right) {
105- stacked_id_embeds = ggml_concat (ctx, left, stacked_id_embeds);
106- stacked_id_embeds = ggml_concat (ctx, stacked_id_embeds, right);
105+ stacked_id_embeds = ggml_concat (ctx, left, stacked_id_embeds, 2 );
106+ stacked_id_embeds = ggml_concat (ctx, stacked_id_embeds, right, 2 );
107107 } else if (left) {
108- stacked_id_embeds = ggml_concat (ctx, left, stacked_id_embeds);
108+ stacked_id_embeds = ggml_concat (ctx, left, stacked_id_embeds, 2 );
109109 } else if (right) {
110- stacked_id_embeds = ggml_concat (ctx, stacked_id_embeds, right);
110+ stacked_id_embeds = ggml_concat (ctx, stacked_id_embeds, right, 2 );
111111 }
112112 stacked_id_embeds = ggml_cont (ctx, ggml_permute (ctx, stacked_id_embeds, 0 , 2 , 1 , 3 ));
113113 class_tokens_mask = ggml_cont (ctx, ggml_transpose (ctx, class_tokens_mask));
@@ -146,7 +146,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
146146 id_embeds = ggml_cont (ctx, ggml_permute (ctx, id_embeds, 2 , 0 , 1 , 3 ));
147147 id_embeds_2 = ggml_cont (ctx, ggml_permute (ctx, id_embeds_2, 2 , 0 , 1 , 3 ));
148148
149- id_embeds = ggml_concat (ctx, id_embeds, id_embeds_2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
149+ id_embeds = ggml_concat (ctx, id_embeds, id_embeds_2, 2 ); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
150150 id_embeds = ggml_cont (ctx, ggml_permute (ctx, id_embeds, 1 , 2 , 0 , 3 ));
151151
152152 struct ggml_tensor * updated_prompt_embeds = fuse_module->forward (ctx,
0 commit comments