CUDA: deduplicate FlashAttention code (#7352)

commit 133d99c599 (parent cb42c29427)
Author: Johannes Gäßler
Date:   2024-05-18 12:36:25 +02:00 (committed by GitHub)
8 changed files with 316 additions and 654 deletions


@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
    // ALiBi disabled: a slope of 1.0f leaves the attention bias unscaled.
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    // The first n_head_log2 heads use base m0 with exponents 1, 2, 3, ...;
    // the remaining heads use base m1 with odd exponents 1, 3, 5, ...
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
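
For context, the m0, m1, and n_head_log2 arguments are derived on the host before kernel launch. Below is a minimal sketch of that derivation, assuming the conventions ALiBi models use in llama.cpp; the helper name alibi_params is hypothetical and chosen for illustration:

#include <cmath>
#include <cstdint>

// Hypothetical helper: derive the parameters consumed by get_alibi_slope().
// n_head is the total number of attention heads; max_bias is typically 8.0f
// for ALiBi models and 0.0f when ALiBi is disabled.
static void alibi_params(const float max_bias, const uint32_t n_head,
                         uint32_t & n_head_log2, float & m0, float & m1) {
    // Largest power of two not exceeding n_head.
    n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    // Geometric bases: the first n_head_log2 heads step by m0; any
    // remaining heads interpolate with the square root, m1 = sqrt(m0).
    m0 = powf(2.0f, -(max_bias)        / n_head_log2);
    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
}

With max_bias = 8.0f and n_head a power of two, this reduces to the slopes 2^(-8/n), 2^(-16/n), ... from the ALiBi paper.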
//////////////////////