diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 7ef9b73dd..4bf03a49f 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -598,27 +598,6 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } -// Aliases for FATTN_VEC_CASE macro: -static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0; -static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1; -static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0; -static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1; -static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0; -static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16; - -typedef half f16; -typedef float f32; - -#define FATTN_VEC_CASE(type_VKQ, D, type_K, type_V) \ - if (Q->ne[0] == (D) && K->type == type_K && V->type == type_V) { \ - constexpr int nwarps = (D)/WARP_SIZE; \ - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \ - (D), cols_per_block, parallel_blocks, \ - type_K, type_V>; \ - launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \ - return; \ - } \ - static void on_no_fattn_vec_case(const int D) { if (D == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");