[metal-kernel] add flash_attn_ext_scalar_f16 implementation
This commit is contained in:
parent 841713e1e4
commit 9e62e7e10e
1 changed file with 288 additions and 0 deletions

@@ -2799,6 +2799,294 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;

half dequantize_load_f16(device const half *xb, short il) {
    return xb[il];
}
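
// dequantize a single element from a q8_0 block: each block_q8_0 packs QK8_0 signed 8-bit quants qs[]
// together with one half-precision scale d, so a value is recovered as d * qs[i]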
half dequantize_load_q8_0(device const block_q8_0 *xb, short il) {
    device const block_q8_0 *xb_ = &xb[il / QK8_0];

    return xb_->d * xb_->qs[il % QK8_0];
}
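
// scalar (non-simdgroup-matrix) flash attention: one query row per threadgroup, each simdgroup walks the
// KV cache in blocks of C columns while accumulating the output O together with the running softmax
// maximum M and sum S (online softmax); the per-simdgroup partial results are merged in threadgroup memory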
template<typename block_q, half (*dequantize_load)(device const block_q* xb, short il), int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_scalar_f16(
        device const char * q,
        device const char * k,
        device const char * v,
        device const char * mask,
        device float * dst,
        constant int64_t & ne01,
        constant int64_t & ne02,
        constant int64_t & ne03,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant uint64_t & nb03,
        constant int64_t & ne11,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant uint64_t & nb13,
        constant uint64_t & nb21,
        constant uint64_t & nb22,
        constant uint64_t & nb23,
        constant uint64_t & nb31,
        constant int64_t & ne1,
        constant int64_t & ne2,
        constant float & scale,
        constant float & max_bias,
        constant float & m0,
        constant float & m1,
        constant uint32_t & n_head_log2,
        constant float & logit_softcap,
        threadgroup half * shared [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const short nsg = ntg.y; // number of simdgroups

    const short iq3 = tgpig[2];
    const short iq2 = tgpig[1];
    const short iq1 = tgpig[0];

    const short NW = N_SIMDWIDTH;
    const short SH = (C + Q); // shared memory per simdgroup in (half)

    const short T = D + 2*nsg*SH; // shared memory size per query in (half)

    float slope = 1.0f;

    // ALiBi
    if (max_bias > 0.0f) {
        const uint32_t h = iq2;

        const float base = h < n_head_log2 ? m0 : m1;
        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        slope = pow(base, exp);
    }
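
    // the resulting ALiBi slope is m0^(h+1) for heads below n_head_log2 and m1^(2*(h - n_head_log2) + 1)
    // for the remaining heads; it scales the mask values that are added to the attention logits below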

    threadgroup half  * sq = (threadgroup half  *) (shared +              0*D); // holds the query data
    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
    threadgroup half  * sr = (threadgroup half  *) (shared +    sgitg*D + 1*T); // scratch buffer for the results
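
    // threadgroup memory layout (in halves): D for the shared query, then 2*SH per simdgroup for the
    // softmax scratch (accessed as SH floats), and one D-sized result buffer per simdgroup starting at offset T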

    // per-lane accumulator for the output row (the O matrix from the paper); each lane holds D/NW elements
    half lo[D/NW];

    // load heads from Q to shared memory
    device const float * q_ = (device const float *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

    for (short i = tiisg; i < D; i += NW) {
        if (iq1 < ne01) {
            sq[i] = (half) q_[i];
        } else {
            sq[i] = 0.0h;
        }
    }

    // zero out lo
    for (short i = tiisg; i < D; i += NW) {
        lo[i/NW] = 0.0h;
    }

    // zero out shared memory SH
    for (short i = tiisg; i < SH; i += NW) {
        ss[i] = 0.0h;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    {
        float S = { 0.0h };
        float M = { -FLT_MAX/2 };

        // assume K and V are same shape
        const short ne22 = ne12;
        const short ne23 = ne13;

        // broadcast
        const short rk2 = ne02/ne12;
        const short rk3 = ne03/ne13;

        const short rv2 = ne02/ne22;
        const short rv3 = ne03/ne23;

        // k indices
        const short ik2 = iq2 / rk2;
        const short ik3 = iq3 / rk3;

        // v indices
        const short iv2 = iq2 / rv2;
        const short iv3 = iq3 / rv3;

        // load the queries from shared memory into local memory
        half mq[D];

        for (short ii = 0; ii < D; ii += NW) {
            short i = ii + tiisg;
            mq[i] = sq[i];
        }

        // pointer to the mask
        device const half * mp = (device const half *) (mask + iq1*nb31);

        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
            const int ic = ic0 + C*sgitg;
            if (ic >= ne11) {
                break;
            }

            // Q*K^T
            {
                // #pragma unroll
                for (short cc = 0; cc < C; ++cc) {
                    float mqk = 0.0;

                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + cc)*nb11 + ik2*nb12 + ik3*nb13));

                    #pragma unroll
                    for (short ii = 0; ii < D; ii += NW) {
                        const short i = ii + tiisg;
                        mqk += mq[i] * dequantize_load(pk, i);
                    }

                    // reduce the results from the threads in the simdgroup
                    mqk += simd_shuffle_down(mqk, 16);
                    mqk += simd_shuffle_down(mqk,  8);
                    mqk += simd_shuffle_down(mqk,  4);
                    mqk += simd_shuffle_down(mqk,  2);
                    mqk += simd_shuffle_down(mqk,  1);
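
                    // the five shuffles sum the partial dot products across the 32 simdgroup lanes,
                    // so lane 0 now holds the complete Q·K value for this KV column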

                    // mqk = mqk*scale + mask*slope
                    if (tiisg == 0) {
                        mqk *= scale;

                        if (logit_softcap != 0.0f) {
                            mqk = logit_softcap*precise::tanh(mqk);
                        }

                        if (mask != q) {
                            mqk += (mp[ic + cc])*slope;
                        }

                        ss[cc] = mqk;
                    }
                }
            }

            // online softmax
            {
                const short p = tiisg;

                const float m = M;
                const float s = ss[p];

                M = simd_max(max(M, s));

                const float ms = exp(m - M);
                const float vs = exp(s - M);

                S = S*ms + simd_sum(vs);
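
                // at this point M is the running maximum over all logits processed so far and S is the
                // softmax denominator rescaled to that maximum (the streaming/online softmax recurrence)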

                // the P matrix from the paper (Q rows, C columns)
                ss[p] = vs;

                // O = diag(ms)*O
                #pragma unroll
                for (short ii = 0; ii < D; ii += NW) {
                    const short i = ii + tiisg;
                    lo[i/NW] *= ms;
                }
            }

            // O = O + (Q*K^T)*V
            {
                // #pragma unroll
                for (short cc = 0; cc < C; ++cc) {
                    device const block_q * pv = (device const block_q *) ((device const char *) v + ((ic + cc)*nb21 + iv2*nb22 + iv3*nb23));

                    #pragma unroll
                    for (short ii = 0; ii < D; ii += NW) {
                        const short i = ii + tiisg;

                        lo[i/NW] += dequantize_load(pv, i) * ss[cc];
                    }
                }
            }
        }

        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        if (tiisg == 0) {
            ss[0] = S;
            ss[1] = M;
        }
    }

    // store results to shared memory
    for (short ii = 0; ii < D; ii += NW) {
        short i = ii + tiisg;
        sr[i] = lo[ii/NW];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel reduce
    for (short r = nsg/2; r > 0; r >>= 1) {
        if (sgitg < r) {
            const float S0 = ss[       0];
            const float S1 = ss[r*SH + 0];

            const float M0 = ss[       1];
            const float M1 = ss[r*SH + 1];

            const float M = max(M0, M1);

            const float ms0 = exp(M0 - M);
            const float ms1 = exp(M1 - M);

            const float S = S0*ms0 + S1*ms1;
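
            // merge the two partial (S, M, O) triples with the same rescaling used by the online softmax:
            // take the larger maximum and scale each partial sum (and, below, each output row) by exp(Mi - M)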

            if (tiisg == 0) {
                ss[0] = S;
                ss[1] = M;
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            for (short ii = 0; ii < D; ii += NW) {
                short i = ii + tiisg;
                sr[i] = sr[i]*ms0 + sr[i + r*D]*ms1;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // final rescale with 1/S and store to global memory
    if (sgitg == 0) {
        const float S = ss[0];

        for (short ii = 0; ii < D; ii += NW) {
            short i = ii + tiisg;
            dst[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D + i] = sr[i]/S;
        }
    }
}

template [[host_name("kernel_flash_attn_ext_scalar_f16_h32")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 32>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h64")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 64>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h96")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 96>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<half, dequantize_load_f16, 128>;

template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h32")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 32>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h64")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h96")]]  kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_scalar_f16<block_q8_0, dequantize_load_q8_0, 128>;

template<typename T0, typename T1>
kernel void kernel_cpy(
        device const void * src0,