ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API * ggml : fix GQA support in ggml_flash_attn_ext * ggml : online attention (CPU) * metal : initial implementation * metal : f16 precision * metal : reduce branches * metal : specialize for head size * wip : 8 rows per simd group * wip : 4 rows per simd group * wip : template for rows per warp * metal : parallelize across KV size * metal : parallel reduce across heads * metal : efficient flash_attn_f16 implementation * metal : avoid redundant loads of the attention * metal : scale and mask in matrix form * metal : fix comment * llama : avoid ggml_cast, use F32 query * metal : add parallel reduce version (disabled) * metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments * metal : add tests, fix scaling, support C > 32 * metal : improve precision * ggml : fix f16 mad * metal : minor * metal : support Q > 8 * tests : add ATTN tests * metal : disable buffer allocation logs * tests : more * metal : faster inner loop for C == 32 * metal : fix array initialization * tests : ifdef * ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext * ggml : fix ggml_soft_max mask requirement * cuda : fix soft_max to use correct mask size * cuda : add flash_attn kernel (wip) * metal : optimize softmax for C > 32 * metal : optimize softmax * tests : minor fix * cuda : avoid zeroing fragments * tests : update dims * cuda : fix __hisinf() result check * cuda : avoid warp_reduce for smax * cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) * cuda : make loops use the same loop values Thanks Johannes again for the tip * cuda : unroll some of the loops * cuda : avoid __hisinf branches * cuda : use half2 in softmax * cuda : switch to 1 warp for bs > 16 * cuda : speed-up reduce part of the kernel * cuda : unroll Q*K^T loop * cuda : fix -INF block check * cuda : simplify softmax * cuda : fix matrix names * cuda : minor * llama : adapt to F16 KQ_pos * llama : adapt new models to F16 KQ_mask * ggml : fix F16 store (ARM NEON) * llama : fix type of KQ_mask and KQ_pos * ggml : fix CPU soft_max * tests : add hs=256 * cuda : fix build * metal : improve perf via smaller int registers * cuda : adapt soft_max to F16 mask and pos * CUDA: faster FlashAttention, kernel for bs == 1 * 16 cols for Phi-2 * no vec for hs, no hs==256 ncols==32 for Volta * adjust kernel selection logic * 4 warps, 256 stride for all D * no ncols == 64 * Multiple parallel blocks for batch size 1 * fix compile warnings * fix excessive KQ_b loads * fix cmake build * fix KV cache padding, NaN from INFINITY (#6438) * llama : flash_attn cparam + fix defrag * server: support flash_attn param * server: bench: enable flash_attn param * CUDA: refactor host code, dyn. par. blocks * fix flash_attn_vec_f16 race condition * flush softmax exp below threshold to 0 * store temp KQ in registers * Calculate KQ as FP32 if KQV has GGML_PREC_F32 * Add __hgt2_mask implementation for CUDA 11 * fix KQ FP32 precision fpr parallel_blocks > 1 * llama-bench : add -fa,--flash-attn arg * metal : add BS=1 kernel for flash attention (#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel * metal : use F32 attention accumulators * batched-bench : add fattn arg * llama : simplify llama_build_kv_store ggml-ci * llama : adapt build_olmo to changes * ggml : fix arm fp16 store on windows * metal : clean-up * metal : clean-up kernel code * metal : minor * tests : remove benchmarks ggml-ci * ggml : fix avx512 const correctness ggml-ci * ggml : fix soft_max with bias on CPU ggml-ci * common : print --flash-attn in help * ggml : fix num dimensions in ggml_flash_attn_ext * llama : force disable flash attention for incompatible models * ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci * cuda : uint -> uint32_t * cuda : "constexpr dim3" -> "const dim3" ggml-ci * cuda : try to fix __hgt2_mask ggml-ci * ggml : add TODO's for F16/F32 mask/pos support in other backends * llama : replace bool need_kq_pos with use_alibi * llama : prep ALiBi support for BERT models ggml-ci * llama : fix n_batch requirements ggml-ci * cont * server : add help for --flash-attn arg * llama : disable FA for AMD * tests : remove TMP_ATTN_BENCH ggml-ci * llama : support save/load state with FA enabled ggml-ci * ci : add CUDA save-load-state tests ggml-ci * llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci * llama : fix copy-paste errors, add TODO * llama : disallow incompatible states * llama : update llama_state_get_size after v_trans field * metal : remove tmp log * llama : add static reminder for llama_state_get_size * metal : fix max nsg ggml-ci * ci : fix arg order ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
This commit is contained in:
parent
952d03dbea
commit
9c67c2773d
22 changed files with 2921 additions and 457 deletions
672
ggml-metal.metal
672
ggml-metal.metal
|
@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
|
|||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
|
@ -375,10 +376,10 @@ kernel void kernel_soft_max(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
|
@ -456,11 +457,12 @@ kernel void kernel_soft_max(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max_4(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
|
@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
|
|||
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
||||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
|
@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
|
|||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
|
|||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
|
|||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
|
||||
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
|
||||
|
||||
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
|
||||
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
|
||||
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
|
||||
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
|
||||
|
||||
kernel void kernel_diag_mask_inf(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
|
|||
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
|
||||
}
|
||||
|
||||
typedef void (flash_attn_ext_f16_t)(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
threadgroup half * shared,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]);
|
||||
|
||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
||||
template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_f16(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const short iq3 = tgpig[2];
|
||||
const short iq2 = tgpig[1];
|
||||
const short iq1 = tgpig[0]*Q;
|
||||
|
||||
const short D4 = D/4;
|
||||
const short D8 = D/8;
|
||||
const short Q8 = Q/8;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
const short TF = T/2; // shared memory size per query in (float)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
simdgroup_half8x8 lo[D8];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (short j = sgitg; j < Q; j += nsg) {
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
sq4[j*T4 + i] = (half4) q4[i];
|
||||
} else {
|
||||
sq4[j*T4 + i] = 0.0h;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < SH; i += NW) {
|
||||
ss[j*TF + i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
float S[Q] = { [0 ... Q-1] = 0.0h };
|
||||
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
|
||||
const uint nb21 = nb11;
|
||||
const uint nb22 = nb12;
|
||||
const uint nb23 = nb13;
|
||||
|
||||
// broadcast
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const short ik2 = iq2/rk2;
|
||||
const short ik3 = iq3/rk3;
|
||||
|
||||
// v indices
|
||||
const short iv2 = iq2/rv2;
|
||||
const short iv3 = iq3/rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
simdgroup_half8x8 mq[D8];
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||
|
||||
// prepare diagonal scale matrix
|
||||
simdgroup_float8x8 mscale(scale);
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
simdgroup_half8x8 mm;
|
||||
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
||||
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
||||
|
||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
float smax = -INFINITY;
|
||||
|
||||
// online softmax
|
||||
{
|
||||
float ms[Q];
|
||||
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const short p = tiisg;
|
||||
|
||||
const float m = M[j];
|
||||
const float s = ss[j*TF + p];
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
ms[j] = exp(m - M[j]);
|
||||
const float vs = exp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*TF + p] = vs;
|
||||
}
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*TF + C + tiisg] = ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (smax == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// O = diag(ms)*O
|
||||
{
|
||||
simdgroup_float8x8 mm;
|
||||
simdgroup_load(mm, ss + C, TF, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_multiply(lo[i], mm, lo[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
|
||||
simdgroup_float8x8 mv;
|
||||
simdgroup_load(mv, ss + 8*cc, TF, 0, false);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*TF + 0] = S[j];
|
||||
ss[j*TF + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (short sg = 1; sg < nsg; ++sg) {
|
||||
float S = { 0.0h };
|
||||
float M = { -FLT_MAX/2 };
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (sgitg == sg) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const float S0 = ss[j*TF + 0];
|
||||
const float S1 = ss[j*TF + sg*SH + 0];
|
||||
|
||||
const float M0 = ss[j*TF + 1];
|
||||
const float M1 = ss[j*TF + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
const float ms0 = exp(M0 - M);
|
||||
const float ms1 = exp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*TF + 0] = S;
|
||||
ss[j*TF + 1] = M;
|
||||
|
||||
ss[j*TF + C + j ] = ms0;
|
||||
ss[j*TF + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
simdgroup_half8x8 t;
|
||||
simdgroup_float8x8 ms0;
|
||||
simdgroup_float8x8 ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, TF, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load (t, sq + i*8, T, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// store result to shared memory (reuse sq)
|
||||
if (sgitg == 0) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], sq + i*8, T, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const float S = ss[j*TF + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
||||
|
||||
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec_f16(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const short iq3 = tgpig[2];
|
||||
const short iq2 = tgpig[1];
|
||||
const short iq1 = tgpig[0];
|
||||
|
||||
const short D4 = D/4;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
half4 lo[D4/NW];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 < ne01) {
|
||||
sq4[i] = (half4) q4[i];
|
||||
} else {
|
||||
sq4[i] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
lo[i/NW] = 0.0h;
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (short i = tiisg; i < SH/4; i += NW) {
|
||||
ss4[i] = 0.0h;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
float S = { 0.0h };
|
||||
float M = { -FLT_MAX/2 };
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
|
||||
const uint nb21 = nb11;
|
||||
const uint nb22 = nb12;
|
||||
const uint nb23 = nb13;
|
||||
|
||||
// broadcast
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const short ik2 = iq2 / rk2;
|
||||
const short ik3 = iq3 / rk3;
|
||||
|
||||
// v indices
|
||||
const short iv2 = iq2 / rv2;
|
||||
const short iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half4 mq[D4];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
mq[i] = sq4[i];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
float4 mqk = { 0.0h };
|
||||
|
||||
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
half4x4 mk;
|
||||
mk[0] = pk4[i + 0*(nb11/8)];
|
||||
mk[1] = pk4[i + 1*(nb11/8)];
|
||||
mk[2] = pk4[i + 2*(nb11/8)];
|
||||
mk[3] = pk4[i + 3*(nb11/8)];
|
||||
|
||||
mqk += (float4) (mq[i] * mk);
|
||||
}
|
||||
|
||||
// reduce the results from the threads in the simdgroup
|
||||
mqk += simd_shuffle_down(mqk, 16);
|
||||
mqk += simd_shuffle_down(mqk, 8);
|
||||
mqk += simd_shuffle_down(mqk, 4);
|
||||
mqk += simd_shuffle_down(mqk, 2);
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
if (tiisg == 0) {
|
||||
float4 mm = (float4) mp4[ic/4 + cc];
|
||||
mqk = mqk*scale + mm;
|
||||
|
||||
ss4[cc] = mqk;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// online softmax
|
||||
{
|
||||
const short p = tiisg;
|
||||
|
||||
const float m = M;
|
||||
const float s = ss[p];
|
||||
|
||||
M = simd_max(max(M, s));
|
||||
|
||||
const float ms = exp(m - M);
|
||||
const float vs = exp(s - M);
|
||||
|
||||
S = S*ms + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[p] = vs;
|
||||
|
||||
// O = diag(ms)*O
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
lo[i/NW] *= ms;
|
||||
}
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
}
|
||||
}
|
||||
|
||||
// store results to shared memory
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
sr4[i] = lo[ii/NW];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// parallel reduce
|
||||
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||
if (sgitg < r) {
|
||||
const float S0 = ss[ 0];
|
||||
const float S1 = ss[r*SH + 0];
|
||||
|
||||
const float M0 = ss[ 1];
|
||||
const float M1 = ss[r*SH + 1];
|
||||
|
||||
const float M = max(M0, M1);
|
||||
|
||||
const float ms0 = exp(M0 - M);
|
||||
const float ms1 = exp(M1 - M);
|
||||
|
||||
const float S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const float S = ss[0];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue