diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu
index fa6750364..71581509c 100644
--- a/ggml-cuda/fattn-vec-f16.cu
+++ b/ggml-cuda/fattn-vec-f16.cu
@@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
 void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
-    // if (Q->ne[1] == 1) {
-    //     constexpr int cols_per_block  = 1;
-    //     constexpr int parallel_blocks = 4;
-    //     launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
-    //     return;
-    // }
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        return;
+    }
 
-    // if (Q->ne[1] == 2) {
-    //     constexpr int cols_per_block  = 2;
-    //     constexpr int parallel_blocks = 4;
-    //     launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
-    //     return;
-    // }
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        return;
+    }
 
-    // if (Q->ne[1] <= 4) {
-    //     constexpr int cols_per_block  = 4;
-    //     constexpr int parallel_blocks = 4;
-    //     launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
-    //     return;
-    // }
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        return;
+    }
 
-    // if (Q->ne[1] <= 8) {
-    //     constexpr int cols_per_block  = 8;
-    //     constexpr int parallel_blocks = 4;
-    //     launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
-    //     return;
-    // }
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        return;
+    }
 
     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;