@@ -678,17 +678,25 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
+    const bool is_mla = DV == 512; // TODO better parameterization
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    GGML_ASSERT(V || is_mla);
+
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT( K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
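The stride fallback in the second hunk suggests that on the `is_mla` path (hard-coded here as `DV == 512`), the kernel reads its value rows out of the K tensor: `V_data` becomes null while `nb21`..`nb23` borrow K's strides, which is consistent with each V row being the leading `DV` elements of the corresponding K row (DeepSeek-style MLA). A toy sketch under that assumption, with hypothetical sizes and a plain float buffer standing in for the K data, nothing below is llama.cpp API:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical MLA-style shapes: K rows of 576 floats, V assumed to be the leading 512 of each row.
    const size_t DK = 576, DV = 512, n_kv = 4;

    std::vector<float> K(n_kv * DK);
    for (size_t j = 0; j < n_kv; ++j) {
        for (size_t i = 0; i < DK; ++i) {
            K[j*DK + i] = 1000.0f*j + i; // make each element identifiable by (row, column)
        }
    }

    // As in the diff: V is absent, so the V row stride falls back to K's row stride (in bytes).
    const size_t nb11 = DK * sizeof(float); // K row stride
    const size_t nb21 = nb11;               // borrowed V row stride

    // Addressing "V" row j through the borrowed stride lands on the first DV elements of K row j.
    const char * K_bytes = (const char *) K.data();
    for (size_t j = 0; j < n_kv; ++j) {
        const float * V_row = (const float *) (K_bytes + j*nb21);
        printf("V[%zu][0] = %.0f, V[%zu][%zu] = %.0f\n", j, V_row[0], j, DV - 1, V_row[DV - 1]);
    }
    return 0;
}
```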