ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-10 14:50:28 +03:00
parent f7055d31c5
commit 865af990cc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 66 additions and 14 deletions

View file

@ -30,9 +30,9 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
const int h = rowx/nrows_y; // head index
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = powf(base, exp);
slope = powf(base, exph);
}
extern __shared__ float data_soft_max_f32[];
@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);