@@ -835,6 +835,10 @@ void launch_fattn(
     GGML_ASSERT(Q->type   == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -859,10 +863,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         K_f16.alloc(ggml_nelements(K));
@@ -878,7 +882,8 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
+        // GGML_ASSERT(ggml_is_contiguous(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
         to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
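
For context: in ggml, nb[i] is the byte stride of dimension i, so the new asserts require nb[0] to equal the element size, i.e. the innermost dimension of Q, K and (when present) V must be stored contiguously, while the V handling now tolerates a missing V tensor by passing a null data pointer and reusing K's strides as placeholders. Below is a minimal sketch of the same logic; it uses a hypothetical tensor_view struct for illustration rather than the real ggml_tensor.

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for ggml_tensor (illustration only): nb[i] is the byte
// stride of dimension i, element_size plays the role of ggml_element_size().
struct tensor_view {
    size_t       element_size;
    size_t       nb[4];
    const char * data;
};

// Mirrors the added asserts: dim 0 of Q, K and (if present) V must be contiguous,
// i.e. its byte stride equals the element size.
void check_innermost_contiguous(const tensor_view & Q, const tensor_view & K, const tensor_view * V) {
    assert(      Q.nb[0] ==  Q.element_size);
    assert(      K.nb[0] ==  K.element_size);
    assert(!V || V->nb[0] == V->element_size);
}

// Mirrors the V fallback: with no separate V tensor, the data pointer becomes
// nullptr and K's higher-dimension strides stand in for nb21/nb22/nb23.
struct v_strides { const char * data; size_t nb1, nb2, nb3; };

v_strides resolve_V(const tensor_view & K, const tensor_view * V) {
    return V ? v_strides{V->data, V->nb[1], V->nb[2], V->nb[3]}
             : v_strides{nullptr, K.nb[1],  K.nb[2],  K.nb[3]};
}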