fix commented-out kernel variants

This commit is contained in:
Johannes Gäßler 2024-05-26 20:14:55 +02:00
parent 462add6a01
commit 3194a01058

View file

@@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst; const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[2]; const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(precision == GGML_PREC_DEFAULT);
// if (Q->ne[1] == 1) { if (Q->ne[1] == 1) {
// constexpr int cols_per_block = 1; constexpr int cols_per_block = 1;
// constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return; return;
// } }
// if (Q->ne[1] == 2) { if (Q->ne[1] == 2) {
// constexpr int cols_per_block = 2; constexpr int cols_per_block = 2;
// constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return; return;
// } }
// if (Q->ne[1] <= 4) { if (Q->ne[1] <= 4) {
// constexpr int cols_per_block = 4; constexpr int cols_per_block = 4;
// constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return; return;
// } }
// if (Q->ne[1] <= 8) { if (Q->ne[1] <= 8) {
// constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
// constexpr int parallel_blocks = 4; constexpr int parallel_blocks = 4;
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst); launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
// return; return;
// } }
constexpr int cols_per_block = 8; constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1; constexpr int parallel_blocks = 1;