metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel
This commit is contained in:
parent
260cdb2d08
commit
105332cc17
2 changed files with 361 additions and 32 deletions
119
ggml-metal.m
119
ggml-metal.m
|
@ -183,6 +183,8 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
|
@ -621,12 +623,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
|
@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
switch (ne00) {
|
||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (ne00) {
|
||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||
GGML_ASSERT(false && "add template specialization for this size");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extend if necessary
|
||||
|
@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
// half8x8 kernel
|
||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
|
||||
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
// half1x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
|
||||
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
nsg *= 2;
|
||||
}
|
||||
nsg /= 2;
|
||||
|
||||
// require power of 2
|
||||
//{
|
||||
// int64_t nsgm = 1;
|
||||
// while (nsgm < nsg) {
|
||||
// nsgm *= 2;
|
||||
// }
|
||||
// GGML_ASSERT(nsg == nsgm);
|
||||
//}
|
||||
|
||||
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
||||
|
||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
|
|
274
ggml-metal.metal
274
ggml-metal.metal
|
@ -2494,6 +2494,280 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f
|
|||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
|
||||
|
||||
#define HALF_MAX_HALF half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||
|
||||
template<int64_t D, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec_f16(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne31,
|
||||
constant uint64_t & nb31,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant float & scale,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const short iq3 = tgpig[2];
|
||||
const short iq2 = tgpig[1];
|
||||
const short iq1 = tgpig[0];
|
||||
|
||||
const short D4 = D/4;
|
||||
const short D8 = D/8;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + 1); // shared memory per simdgroup in (half)
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const short T4 = T/4; // shared memory size per query in (half4)
|
||||
|
||||
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
||||
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
half4 lo[D4/NW];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
if (iq1 < ne01) {
|
||||
sq4[i] = (half4) q4[i];
|
||||
} else {
|
||||
sq4[i] = 0.0h;
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
lo[i/NW] = 0.0h;
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (short i = tiisg; i < SH/4; i += NW) {
|
||||
ss4[i] = 0.0h;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
half S = { 0.0h };
|
||||
half M = { -HALF_MAX_HALF };
|
||||
|
||||
// assume K and V are same shape
|
||||
const short ne22 = ne12;
|
||||
const short ne23 = ne13;
|
||||
|
||||
const uint nb21 = nb11;
|
||||
const uint nb22 = nb12;
|
||||
const uint nb23 = nb13;
|
||||
|
||||
// broadcast
|
||||
const short rk2 = ne02/ne12;
|
||||
const short rk3 = ne03/ne13;
|
||||
|
||||
const short rv2 = ne02/ne22;
|
||||
const short rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const short ik2 = iq2 / rk2;
|
||||
const short ik3 = iq3 / rk3;
|
||||
|
||||
// v indices
|
||||
const short iv2 = iq2 / rv2;
|
||||
const short iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half4 mq[D4];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
mq[i] = sq4[i];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
half4 mqk = { 0.0h };
|
||||
|
||||
device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
half4x4 mk;
|
||||
mk[0] = pk4[i + 0*(nb11/8)];
|
||||
mk[1] = pk4[i + 1*(nb11/8)];
|
||||
mk[2] = pk4[i + 2*(nb11/8)];
|
||||
mk[3] = pk4[i + 3*(nb11/8)];
|
||||
|
||||
mqk += mq[i] * mk;
|
||||
}
|
||||
|
||||
// reduce the results from the threads in the simdgroup
|
||||
mqk += simd_shuffle_down(mqk, 16);
|
||||
mqk += simd_shuffle_down(mqk, 8);
|
||||
mqk += simd_shuffle_down(mqk, 4);
|
||||
mqk += simd_shuffle_down(mqk, 2);
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
if (tiisg == 0) {
|
||||
half4 mm = mp4[ic/4 + cc];
|
||||
mqk = mqk*scale + mm;
|
||||
|
||||
ss4[cc] = mqk;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// online softmax
|
||||
{
|
||||
const short p = tiisg;
|
||||
|
||||
const half m = M;
|
||||
const half s = ss[p];
|
||||
|
||||
M = simd_max(max(M, s));
|
||||
|
||||
const half ms = exp(m - M);
|
||||
const half vs = exp(s - M);
|
||||
|
||||
S = S*ms + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[p] = vs;
|
||||
|
||||
// O = diag(ms)*O
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
lo[i/NW] *= ms;
|
||||
}
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
}
|
||||
}
|
||||
|
||||
// store results to shared memory
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
sr4[i] = lo[ii/NW];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// parallel reduce
|
||||
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||
if (sgitg < r) {
|
||||
const half S0 = ss[ 0];
|
||||
const half S1 = ss[r*SH + 0];
|
||||
|
||||
const half M0 = ss[ 1];
|
||||
const half M1 = ss[r*SH + 1];
|
||||
|
||||
const half M = max(M0, M1);
|
||||
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
|
||||
const half S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const half S = ss[0];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128, 32>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256, 32>;
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue