wip 5
This commit is contained in:
parent
dc2a27f2a2
commit
2aedbb354e
1 changed files with 13 additions and 13 deletions
|
@ -2879,7 +2879,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
if (iq1 + j < ne01) {
|
if (iq1 + j < ne01) {
|
||||||
sq4[j*T4 + i] = (q4_t) q4[i];
|
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||||
} else {
|
} else {
|
||||||
sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
|
sq4[j*T4 + i] = (q4_t) 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2892,7 +2892,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
// zero out shared memory SH
|
// zero out shared memory SH
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
for (short i = tiisg; i < SH; i += NW) {
|
for (short i = tiisg; i < SH; i += NW) {
|
||||||
ss[j*TS + i] = (s_t) 0.0f;
|
ss[j*TS + i] = 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3044,7 +3044,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
const float m = M[j];
|
const float m = M[j];
|
||||||
|
|
||||||
// scale and apply the logitcap / mask
|
// scale and apply the logitcap / mask
|
||||||
float s = ((float)(ss[j*TS + tiisg]))*scale;
|
float s = ss[j*TS + tiisg]*scale;
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
if (logit_softcap != 0.0f) {
|
||||||
s = logit_softcap*precise::tanh(s);
|
s = logit_softcap*precise::tanh(s);
|
||||||
|
@ -3064,12 +3064,12 @@ kernel void kernel_flash_attn_ext(
|
||||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*TS + tiisg] = (s_t) vs;
|
ss[j*TS + tiisg] = vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (tiisg < Q) {
|
if (tiisg < Q) {
|
||||||
ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
|
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3157,8 +3157,8 @@ kernel void kernel_flash_attn_ext(
|
||||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
ss[j*TS + 0] = (s_t) S[j];
|
ss[j*TS + 0] = S[j];
|
||||||
ss[j*TS + 1] = (s_t) M[j];
|
ss[j*TS + 1] = M[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3196,18 +3196,16 @@ kernel void kernel_flash_attn_ext(
|
||||||
S = S0*ms0 + S1*ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
ss[j*TS + 0] = (s_t) S;
|
ss[j*TS + 0] = S;
|
||||||
ss[j*TS + 1] = (s_t) M;
|
ss[j*TS + 1] = M;
|
||||||
|
|
||||||
ss[j*TS + C + j ] = (s_t) ms0;
|
ss[j*TS + C + j ] = ms0;
|
||||||
ss[j*TS + C + j + sg*SH] = (s_t) ms1;
|
ss[j*TS + C + j + sg*SH] = ms1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||||
{
|
{
|
||||||
o8x8_t t;
|
|
||||||
|
|
||||||
s8x8_t ms0;
|
s8x8_t ms0;
|
||||||
s8x8_t ms1;
|
s8x8_t ms1;
|
||||||
|
|
||||||
|
@ -3215,6 +3213,8 @@ kernel void kernel_flash_attn_ext(
|
||||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
|
o8x8_t t;
|
||||||
|
|
||||||
simdgroup_load (t, so + i*8, T, 0, false);
|
simdgroup_load (t, so + i*8, T, 0, false);
|
||||||
simdgroup_multiply(t, ms1, t);
|
simdgroup_multiply(t, ms1, t);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue