move mask to shared mem

This commit is contained in:
Georgi Gerganov 2024-11-07 13:12:10 +02:00
parent 2aedbb354e
commit 984928109c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 32 additions and 23 deletions

View file

@ -3277,7 +3277,7 @@ static void ggml_metal_encode_node(
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;

View file

@ -2836,8 +2836,8 @@ kernel void kernel_flash_attn_ext(
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
@ -2848,7 +2848,7 @@ kernel void kernel_flash_attn_ext(
const short D8 = D/8;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
const short SF = sizeof(s_t)/sizeof(half);
@ -2933,9 +2933,6 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(mq[i], sq + i*8, T);
}
// pointer to the mask
device const half * mp = (device const half *) (mask + iq1*nb31);
const bool has_mask = mask != q;
float slope = 1.0f;
@ -2958,6 +2955,26 @@ kernel void kernel_flash_attn_ext(
break;
}
if (has_mask) {
// used to detect blocks full of -INF
half smax = -INFINITY;
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
const half m = pm[ic + tiisg];
ss[j*TS + C + tiisg] = m;
smax = max(smax, m);
}
smax = simd_max(smax);
if (smax == -INFINITY) {
continue;
}
}
// Q*K^T
{
for (short cc = 0; cc < C/8; ++cc) {
@ -3033,9 +3050,6 @@ kernel void kernel_flash_attn_ext(
}
}
// used to detect blocks full of -INF
float smax = -INFINITY;
// online softmax
{
float ms[Q];
@ -3052,10 +3066,10 @@ kernel void kernel_flash_attn_ext(
if (has_mask) {
// mqk = mqk + mask*slope
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
//s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
s += slope*ss[j*TS + C + tiisg];
}
smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
ms[j] = exp(m - M[j]);
@ -3069,19 +3083,14 @@ kernel void kernel_flash_attn_ext(
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*TS + C + tiisg] = ms[tiisg];
ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
}
}
// skip -INF blocks
if (smax == -INFINITY) {
continue;
}
// O = diag(ms)*O
{
s8x8_t mm;
simdgroup_load(mm, ss + C, TS, 0, false);
simdgroup_load(mm, ss + 2*C, TS, 0, false);
#pragma unroll
for (short i = 0; i < D8; ++i) {
@ -3199,8 +3208,8 @@ kernel void kernel_flash_attn_ext(
ss[j*TS + 0] = S;
ss[j*TS + 1] = M;
ss[j*TS + C + j ] = ms0;
ss[j*TS + C + j + sg*SH] = ms1;
ss[j*TS + 2*C + j ] = ms0;
ss[j*TS + 2*C + j + sg*SH] = ms1;
}
}
@ -3209,8 +3218,8 @@ kernel void kernel_flash_attn_ext(
s8x8_t ms0;
s8x8_t ms1;
simdgroup_load(ms0, ss + C, TS, 0, false);
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
for (short i = 0; i < D8; ++i) {
o8x8_t t;