ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

Author: Georgi Gerganov
Date:   2024-05-10 14:50:28 +03:00
Commit: 865af990cc (parent: f7055d31c5)
2 changed files with 66 additions and 14 deletions

ggml-cuda/fattn.cu

@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float  * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             sum2[j] = warp_reduce_sum(sum2[j]);
             half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-            sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
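
Note: the per-head slope computed in the new ALiBi blocks follows the standard
ALiBi scheme. A minimal host-side sketch of the same math (the helper name
alibi_slope is illustrative, not part of the commit):

#include <math.h>
#include <stdint.h>

// Sketch only: per-head ALiBi slope, mirroring the kernel code above.
// Heads below n_head_log2 (the largest power of two <= n_head) use base
// m0 = 2^(-max_bias/n_head_log2) with exponents 1..n_head_log2; the
// remaining heads use base m1 = 2^(-max_bias/2/n_head_log2) with odd
// exponents 1, 3, 5, ...
static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
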
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float  * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
                 }
 
                 KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                 }
 
                 KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
    const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
        <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
            (const char *) V->data,
            mask ? ((const char *) mask->data) : nullptr,
            parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
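
Note: both launch wrappers now read two consecutive floats from KQV->op_params
(scale at float index 0, max_bias at float index 1), which is also why the
int32 precision flag moves from op_params[1] to op_params[2] further down. A
hedged sketch of the producer-side packing this layout implies
(pack_fattn_params is a hypothetical helper, not from this commit):

#include <string.h>
#include <stdint.h>

// Hypothetical helper: store scale and max_bias in a ggml-style op_params
// buffer (an int32_t array reinterpreted as floats), matching the memcpy
// reads in the launch code above. max_bias == 0.0f disables ALiBi, so
// models without ALiBi are unaffected.
static void pack_fattn_params(int32_t * op_params, float scale, float max_bias) {
    memcpy((float *) op_params + 0, &scale,    sizeof(float));
    memcpy((float *) op_params + 1, &max_bias, sizeof(float));
}
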
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
    const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
    const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
        <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
            (const char *) V->data,
            mask ? ((const char *) mask->data) : nullptr,
            (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
    if (!fp16_mma_available(cc)) {
        GGML_ASSERT(precision == GGML_PREC_DEFAULT);

ggml-cuda/softmax.cu

@@ -30,9 +30,9 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
        const int h = rowx/nrows_y; // head index
 
        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        slope = powf(base, exp);
+        slope = powf(base, exph);
    }
 
    extern __shared__ float data_soft_max_f32[];
@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
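
Note: a quick standalone check (not part of the commit) that the two-base
scheme also covers non-power-of-two head counts, e.g. n_head = 12 with
max_bias = 8.0f: the first 8 heads get slopes 2^-1 .. 2^-8 and the last 4
interleave with 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5.

#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t n_head   = 12;
    const float    max_bias = 8.0f;

    // largest power of two <= n_head (8 for n_head == 12)
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head; ++h) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        printf("head %2u: slope = %f\n", (unsigned) h, powf(base, exph));
    }
    return 0;
}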