This commit is contained in:
Georgi Gerganov 2024-11-07 12:32:59 +02:00
parent dc2a27f2a2
commit 2aedbb354e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2879,7 +2879,7 @@ kernel void kernel_flash_attn_ext(
if (iq1 + j < ne01) {
sq4[j*T4 + i] = (q4_t) q4[i];
} else {
sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
sq4[j*T4 + i] = (q4_t) 0.0f;
}
}
}
@ -2892,7 +2892,7 @@ kernel void kernel_flash_attn_ext(
// zero out shared memory SH
for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) {
ss[j*TS + i] = (s_t) 0.0f;
ss[j*TS + i] = 0.0f;
}
}
@ -3044,7 +3044,7 @@ kernel void kernel_flash_attn_ext(
const float m = M[j];
// scale and apply the logitcap / mask
float s = ((float)(ss[j*TS + tiisg]))*scale;
float s = ss[j*TS + tiisg]*scale;
if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
@ -3064,12 +3064,12 @@ kernel void kernel_flash_attn_ext(
S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*TS + tiisg] = (s_t) vs;
ss[j*TS + tiisg] = vs;
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
ss[tiisg*TS + C + tiisg] = ms[tiisg];
}
}
@ -3157,8 +3157,8 @@ kernel void kernel_flash_attn_ext(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) {
if (tiisg == 0) {
ss[j*TS + 0] = (s_t) S[j];
ss[j*TS + 1] = (s_t) M[j];
ss[j*TS + 0] = S[j];
ss[j*TS + 1] = M[j];
}
}
}
@ -3196,18 +3196,16 @@ kernel void kernel_flash_attn_ext(
S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[j*TS + 0] = (s_t) S;
ss[j*TS + 1] = (s_t) M;
ss[j*TS + 0] = S;
ss[j*TS + 1] = M;
ss[j*TS + C + j ] = (s_t) ms0;
ss[j*TS + C + j + sg*SH] = (s_t) ms1;
ss[j*TS + C + j ] = ms0;
ss[j*TS + C + j + sg*SH] = ms1;
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
o8x8_t t;
s8x8_t ms0;
s8x8_t ms1;
@ -3215,6 +3213,8 @@ kernel void kernel_flash_attn_ext(
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
for (short i = 0; i < D8; ++i) {
o8x8_t t;
simdgroup_load (t, so + i*8, T, 0, false);
simdgroup_multiply(t, ms1, t);