ggml : full ALiBi support (#7192)

* ggml : full ALiBi support

* ggml : update ggml_soft_max_ext() CUDA, SYCL

* ggml : ggml_flash_attn_ext() support ALiBi (CPU)

* ggml : ggml_flash_attn_ext() support ALiBi (Metal)

* ggml : fix warning

* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* ggml : fix assert message

* vulkan : add dev notes

* ggml : require mask when using ALiBi

ggml-ci

* convert : fix convert for refact models
This commit is contained in:
Georgi Gerganov 2024-05-11 10:32:41 +03:00 committed by GitHub
parent e849648888
commit 9cb317f77e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 350 additions and 825 deletions

View file

@ -356,7 +356,6 @@ template<typename T>
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@ -378,10 +377,9 @@ kernel void kernel_soft_max(
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 0.0f;
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
@ -397,7 +395,7 @@ kernel void kernel_soft_max(
float lmax = -INFINITY;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
@ -422,7 +420,7 @@ kernel void kernel_soft_max(
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
@ -461,7 +459,6 @@ template<typename T>
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
@ -483,10 +480,9 @@ kernel void kernel_soft_max_4(
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
float slope = 0.0f;
float slope = 1.0f;
if (max_bias > 0.0f) {
const int64_t h = i02;
@ -501,7 +497,7 @@ kernel void kernel_soft_max_4(
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -527,7 +523,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
}
}
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant float & m0,
constant float & m1,
constant int & n_heads_log2_floor,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
const int64_t k = i3*ne3 + i2;
float m_k;
if (k < n_heads_log2_floor) {
m_k = pow(m0, k + 1);
} else {
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
}
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
const float src_v = *(device float *)(src_row + i00*nb00);
device float * dst_v = (device float *)(dst_row + i00*nb0);
*dst_v = i00 * m_k + src_v;
}
}
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@ -2116,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
@ -2154,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
@ -2257,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
// prepare diagonal scale matrix
simdgroup_float8x8 mscale(scale);
// prepare diagonal slope matrix
simdgroup_float8x8 mslope(1.0f);
// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
mslope = simdgroup_float8x8(pow(base, exph));
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@ -2279,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
// mqk = mqk*scale + mask
// mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@ -2472,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
@ -2497,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@ -2603,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
float4 mm = (float4) mp4[ic/4 + cc];
mqk = mqk*scale + mm;
mqk = mqk*scale + mm*slope;
ss4[cc] = mqk;
}
@ -2840,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
// TODO: is there a better way to handle -INFINITY?
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
}
}