ggml : ggml_flash_attn_ext() support ALiBi (Metal)
This commit is contained in:
parent
166e60bf9b
commit
97c27f59f6
4 changed files with 105 additions and 47 deletions
|
@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
|
|||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup half * shared,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// prepare diagonal scale matrix
|
||||
simdgroup_float8x8 mscale(scale);
|
||||
|
||||
// prepare diagonal slope matrix
|
||||
simdgroup_float8x8 mslope(1.0f);
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const short h = iq2;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
mslope = simdgroup_float8x8(pow(base, exph));
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
|
@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
// mqk = mqk*scale + mask*slope
|
||||
simdgroup_half8x8 mm;
|
||||
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
||||
simdgroup_multiply(mm, mslope, mm);
|
||||
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
||||
|
||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||
|
@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const short h = iq2;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
mqk += simd_shuffle_down(mqk, 2);
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
// mqk = mqk*scale + mask*slope
|
||||
if (tiisg == 0) {
|
||||
float4 mm = (float4) mp4[ic/4 + cc];
|
||||
mqk = mqk*scale + mm;
|
||||
mqk = mqk*scale + mm*slope;
|
||||
|
||||
ss4[cc] = mqk;
|
||||
}
|
||||
|
@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
|
|||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
dst_data[i00] = src[0];
|
||||
// TODO: is there a better way to handle -INFINITY?
|
||||
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue