wip 5
This commit is contained in:
parent
dc2a27f2a2
commit
2aedbb354e
1 changed files with 13 additions and 13 deletions
|
@ -2879,7 +2879,7 @@ kernel void kernel_flash_attn_ext(
|
|||
if (iq1 + j < ne01) {
|
||||
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
|
||||
sq4[j*T4 + i] = (q4_t) 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2892,7 +2892,7 @@ kernel void kernel_flash_attn_ext(
|
|||
// zero out shared memory SH
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < SH; i += NW) {
|
||||
ss[j*TS + i] = (s_t) 0.0f;
|
||||
ss[j*TS + i] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3044,7 +3044,7 @@ kernel void kernel_flash_attn_ext(
|
|||
const float m = M[j];
|
||||
|
||||
// scale and apply the logitcap / mask
|
||||
float s = ((float)(ss[j*TS + tiisg]))*scale;
|
||||
float s = ss[j*TS + tiisg]*scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s = logit_softcap*precise::tanh(s);
|
||||
|
@ -3064,12 +3064,12 @@ kernel void kernel_flash_attn_ext(
|
|||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*TS + tiisg] = (s_t) vs;
|
||||
ss[j*TS + tiisg] = vs;
|
||||
}
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
|
||||
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3157,8 +3157,8 @@ kernel void kernel_flash_attn_ext(
|
|||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*TS + 0] = (s_t) S[j];
|
||||
ss[j*TS + 1] = (s_t) M[j];
|
||||
ss[j*TS + 0] = S[j];
|
||||
ss[j*TS + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3196,18 +3196,16 @@ kernel void kernel_flash_attn_ext(
|
|||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*TS + 0] = (s_t) S;
|
||||
ss[j*TS + 1] = (s_t) M;
|
||||
ss[j*TS + 0] = S;
|
||||
ss[j*TS + 1] = M;
|
||||
|
||||
ss[j*TS + C + j ] = (s_t) ms0;
|
||||
ss[j*TS + C + j + sg*SH] = (s_t) ms1;
|
||||
ss[j*TS + C + j ] = ms0;
|
||||
ss[j*TS + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
{
|
||||
o8x8_t t;
|
||||
|
||||
s8x8_t ms0;
|
||||
s8x8_t ms1;
|
||||
|
||||
|
@ -3215,6 +3213,8 @@ kernel void kernel_flash_attn_ext(
|
|||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
o8x8_t t;
|
||||
|
||||
simdgroup_load (t, so + i*8, T, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue