CPU/CUDA: Gemma 2 FlashAttention support (#8542)
* CPU/CUDA: Gemma 2 FlashAttention support * apply logit_softcap to scale in kernel * disable logit softcapping tests on Metal * remove metal check
This commit is contained in:
parent
8f824ffe8e
commit
e11bd856d5
12 changed files with 319 additions and 79 deletions
|
@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -657,11 +658,17 @@ void launch_fattn(
|
|||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
@ -675,7 +682,7 @@ void launch_fattn(
|
|||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
|
@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
half sum;
|
||||
if (use_logit_softcap) {
|
||||
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
sum = logit_softcap * tanhf(tmp.x + tmp.y);
|
||||
} else {
|
||||
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
}
|
||||
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
||||
|
@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
|
@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
|||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
|
@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
}
|
||||
|
||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||
|
@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
}
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks>
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = 8;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
} break;
|
||||
default: {
|
||||
|
@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] <= 16) {
|
||||
constexpr int cols_per_block = 16;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 32) {
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 4;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 32;
|
||||
constexpr int parallel_blocks = 1;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
|
@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
for (int j = 0; j < ncols; ++j) {
|
||||
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||
|
||||
if (ncols == 1) {
|
||||
|
@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
|
@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
|||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * KQV = dst;
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
||||
|
@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
for (int j = 0; j < ncols; ++j) {
|
||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
}
|
||||
|
||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
}
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = D != 128;
|
||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||
|
@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
|||
|
||||
template <int D, ggml_type type_K, ggml_type type_V>
|
||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * Q = dst->src[0];
|
||||
ggml_tensor * K = dst->src[1];
|
||||
ggml_tensor * V = dst->src[2];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (Q->ne[1] == 1) {
|
||||
constexpr int cols_per_block = 1;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] == 2) {
|
||||
constexpr int cols_per_block = 2;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 4) {
|
||||
constexpr int cols_per_block = 4;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (Q->ne[1] <= 8) {
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 4;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 8;
|
||||
constexpr int parallel_blocks = 1;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int ne00,
|
||||
const int ne01,
|
||||
const int ne02,
|
||||
|
@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
|
@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
||||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
|
||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||
|
@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
||||
}
|
||||
}
|
||||
|
||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||
|
@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int k = k0 + threadIdx.x;
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||
|
||||
if (use_logit_softcap) {
|
||||
// There is no dedicated tangens hyperbolicus function for half2.
|
||||
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
||||
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
||||
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
||||
|
||||
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
||||
}
|
||||
}
|
||||
|
||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||
|
@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
|||
|
||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
constexpr int nwarps = 4;
|
||||
|
||||
|
@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (4*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 4;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
if (2*blocks_num_pb1 < 2*nsm) {
|
||||
constexpr int parallel_blocks = 2;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
return;
|
||||
}
|
||||
constexpr int parallel_blocks = 1;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<
|
||||
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||
}
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
|
||||
if (precision != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
|
@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[2];
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue