move mask to shared mem
This commit is contained in:
parent
2aedbb354e
commit
984928109c
2 changed files with 32 additions and 23 deletions
|
@ -3277,7 +3277,7 @@ static void ggml_metal_encode_node(
|
|||
// the shared memory needed for the simdgroups to load the KV cache
|
||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
|
|
|
@ -2836,8 +2836,8 @@ kernel void kernel_flash_attn_ext(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
|
@ -2848,7 +2848,7 @@ kernel void kernel_flash_attn_ext(
|
|||
const short D8 = D/8;
|
||||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const short SF = sizeof(s_t)/sizeof(half);
|
||||
|
||||
|
@ -2933,9 +2933,6 @@ kernel void kernel_flash_attn_ext(
|
|||
simdgroup_load(mq[i], sq + i*8, T);
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
@ -2958,6 +2955,26 @@ kernel void kernel_flash_attn_ext(
|
|||
break;
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
// used to detect blocks full of -INF
|
||||
half smax = -INFINITY;
|
||||
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
||||
|
||||
const half m = pm[ic + tiisg];
|
||||
|
||||
ss[j*TS + C + tiisg] = m;
|
||||
smax = max(smax, m);
|
||||
}
|
||||
|
||||
smax = simd_max(smax);
|
||||
|
||||
if (smax == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
for (short cc = 0; cc < C/8; ++cc) {
|
||||
|
@ -3033,9 +3050,6 @@ kernel void kernel_flash_attn_ext(
|
|||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
float smax = -INFINITY;
|
||||
|
||||
// online softmax
|
||||
{
|
||||
float ms[Q];
|
||||
|
@ -3052,10 +3066,10 @@ kernel void kernel_flash_attn_ext(
|
|||
|
||||
if (has_mask) {
|
||||
// mqk = mqk + mask*slope
|
||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
||||
//s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
||||
s += slope*ss[j*TS + C + tiisg];
|
||||
}
|
||||
|
||||
smax = simd_max(max(smax, s));
|
||||
M[j] = simd_max(max(M[j], s));
|
||||
|
||||
ms[j] = exp(m - M[j]);
|
||||
|
@ -3069,19 +3083,14 @@ kernel void kernel_flash_attn_ext(
|
|||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
||||
ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (smax == -INFINITY) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// O = diag(ms)*O
|
||||
{
|
||||
s8x8_t mm;
|
||||
simdgroup_load(mm, ss + C, TS, 0, false);
|
||||
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
||||
|
||||
#pragma unroll
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
|
@ -3199,8 +3208,8 @@ kernel void kernel_flash_attn_ext(
|
|||
ss[j*TS + 0] = S;
|
||||
ss[j*TS + 1] = M;
|
||||
|
||||
ss[j*TS + C + j ] = ms0;
|
||||
ss[j*TS + C + j + sg*SH] = ms1;
|
||||
ss[j*TS + 2*C + j ] = ms0;
|
||||
ss[j*TS + 2*C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3209,8 +3218,8 @@ kernel void kernel_flash_attn_ext(
|
|||
s8x8_t ms0;
|
||||
s8x8_t ms1;
|
||||
|
||||
simdgroup_load(ms0, ss + C, TS, 0, false);
|
||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
||||
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
o8x8_t t;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue