ggml : full ALiBi support (#7192)

* ggml : full ALiBi support

* ggml : update ggml_soft_max_ext() CUDA, SYCL

* ggml : ggml_flash_attn_ext() support ALiBi (CPU)

* ggml : ggml_flash_attn_ext() support ALiBi (Metal)

* ggml : fix warning

* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* ggml : fix assert message

* vulkan : add dev notes

* ggml : require mask when using ALiBi

ggml-ci

* convert : fix convert for refact models
This commit is contained in:
Georgi Gerganov 2024-05-11 10:32:41 +03:00 committed by GitHub
parent e849648888
commit 9cb317f77e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 350 additions and 825 deletions

View file

@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);
half slopeh = __float2half(1.0f);
// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slopeh = __float2half(powf(base, exph));
}
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
half slopeh = __float2half(1.0f);
half2 slope2 = make_half2(1.0f, 1.0f);
// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slopeh = __float2half(powf(base, exph));
slope2 = make_half2(slopeh, slopeh);
}
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int32_t precision = KQV->op_params[1];
const int32_t precision = KQV->op_params[2];
if (!fp16_mma_available(cc)) {
GGML_ASSERT(precision == GGML_PREC_DEFAULT);