CPU/CUDA: Gemma 2 FlashAttention support (#8542)

* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check
This commit is contained in:
Johannes Gäßler 2024-08-24 21:34:59 +02:00 committed by GitHub
parent 8f824ffe8e
commit e11bd856d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 319 additions and 79 deletions

View file

@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum);
}
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \