@@ -2679,13 +2679,18 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
26792679 bool use_quantized_src1 = false ;
26802680 int64_t src1_padded_num_cols = 0 , src1_padded_row_size = 0 , src1_quantized_size = 0 ;
26812681 if (ggml_is_quantized (src0_1->type ) && src0_1->type == src0_2->type && src1->ne [1 ] == 1 && src1->ne [3 ] == 1 ) {
2682- src1_padded_num_cols = GGML_PAD (src1->ne [0 ], MATRIX_ROW_PADDING);
2683- src1_padded_row_size = src1_padded_num_cols/ggml_blck_size (GGML_TYPE_Q8_1)*ggml_type_size (GGML_TYPE_Q8_1);
2684- src1_quantized_size = src1_padded_row_size*src1->ne [2 ] + get_mmq_x_max_host (ggml_cuda_info ().devices [ctx.device ].cc )*sizeof (block_q8_1_mmq);
2685- src1_quantized.alloc (src1_quantized_size);
2686- use_quantized_src1 = true ;
2682+ if (ggml_cuda_should_use_mmq (src0_1->type , ggml_cuda_info ().devices [ctx.device ].cc , src1->ne [2 ])) {
2683+ src1_padded_num_cols = GGML_PAD (src1->ne [0 ], MATRIX_ROW_PADDING);
2684+ src1_padded_row_size = src1_padded_num_cols/ggml_blck_size (GGML_TYPE_Q8_1)*ggml_type_size (GGML_TYPE_Q8_1);
2685+ src1_quantized_size = src1_padded_row_size*src1->ne [2 ] + get_mmq_x_max_host (ggml_cuda_info ().devices [ctx.device ].cc )*sizeof (block_q8_1_mmq);
2686+ src1_quantized.alloc (src1_quantized_size);
2687+ use_quantized_src1 = true ;
2688+ }
2689+ }
2690+ ggml_cuda_pool_alloc<char > src1_contiguous (ctx.pool ());
2691+ if (!use_quantized_src1) {
2692+ src1_contiguous.alloc (sizeof (float )*ggml_nelements (src1));
26872693 }
2688- ggml_cuda_pool_alloc<char > src1_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (src1));
26892694 ggml_cuda_pool_alloc<char > dst_up_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (dst));
26902695 ggml_cuda_pool_alloc<char > dst_gate_contiguous (ctx.pool (), sizeof (float )*ggml_nelements (dst));
26912696 ggml_cuda_pool_alloc<char > final_dst_contiguous (ctx.pool ());
@@ -2728,6 +2733,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
27282733 k_copy_src_to_contiguous<<<grid_dims, block_dims, 0 , stream>>> (
27292734 src1_original, src1_contiguous.get (), dev_row_mapping.get () + mapping_offset, ne10, ne11, nb11, nb12);
27302735 CUDA_CHECK (cudaGetLastError ());
2736+ src1_row.data = src1_contiguous.get ();
27312737 }
27322738
27332739 src0_1_row.data = src0_1_original + i02*nb02;
0 commit comments