metal : minor
This commit is contained in:
parent
0ad44baf33
commit
134c81c78d
1 changed file with 3 additions and 6 deletions
|
@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
|
||||
// Q*K^T
|
||||
{
|
||||
simdgroup_half8x8 mk;
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
|
||||
|
||||
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true);
|
||||
simdgroup_half8x8 mk;
|
||||
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
}
|
||||
|
@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
// O = diag(ms)*O
|
||||
{
|
||||
simdgroup_half8x8 mm;
|
||||
|
||||
simdgroup_load(mm, ss + C, T, 0, false);
|
||||
|
||||
for (int64_t i = 0; i < D8; ++i) {
|
||||
|
@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
simdgroup_half8x8 mv;
|
||||
|
||||
for (int cc = 0; cc < C/8; ++cc) {
|
||||
simdgroup_half8x8 mp;
|
||||
simdgroup_load(mp, ss + 8*cc, T, 0, false);
|
||||
|
@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
for (int64_t i = 0; i < D8; ++i) {
|
||||
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
simdgroup_half8x8 mv;
|
||||
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue