metal : minor

This commit is contained in:
Georgi Gerganov 2024-01-28 22:23:40 +02:00
parent 0ad44baf33
commit 134c81c78d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@@ -2127,15 +2127,14 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
// Q*K^T
{
simdgroup_half8x8 mk;
for (int cc = 0; cc < C/8; ++cc) {
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, Q>(0.h);
device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D8; ++i) {
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true);
simdgroup_half8x8 mk;
simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
@@ -2192,7 +2191,6 @@ kernel void kernel_flash_attn_ext_f16(
// O = diag(ms)*O
{
simdgroup_half8x8 mm;
simdgroup_load(mm, ss + C, T, 0, false);
for (int64_t i = 0; i < D8; ++i) {
@@ -2202,8 +2200,6 @@ kernel void kernel_flash_attn_ext_f16(
// O = O + (Q*K^T)*V
{
simdgroup_half8x8 mv;
for (int cc = 0; cc < C/8; ++cc) {
simdgroup_half8x8 mp;
simdgroup_load(mp, ss + 8*cc, T, 0, false);
@@ -2211,6 +2207,7 @@ kernel void kernel_flash_attn_ext_f16(
for (int64_t i = 0; i < D8; ++i) {
device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
simdgroup_half8x8 mv;
simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
simdgroup_multiply_accumulate(lo[i], mp, mv, lo[i]);