remove obsolete code

This commit is contained in:
Johannes Gäßler 2024-05-31 20:12:00 +02:00
parent d8a0b87091
commit 05133280ab

View file

@ -598,27 +598,6 @@ static __global__ void flash_attn_combine_results(
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
// Aliases for FATTN_VEC_CASE macro:
static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0;
static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1;
static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0;
static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1;
static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0;
static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16;
typedef half f16;
typedef float f32;
#define FATTN_VEC_CASE(type_VKQ, D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == type_K && V->type == type_V) { \
constexpr int nwarps = (D)/WARP_SIZE; \
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \
(D), cols_per_block, parallel_blocks, \
type_K, type_V>; \
launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \
return; \
} \
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");