ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

This commit is contained in:
Georgi Gerganov 2024-01-31 19:17:16 +02:00
parent 2ddc9bbef1
commit 8ad92dc1ec
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 79 additions and 62 deletions

View file

@ -349,9 +349,9 @@ kernel void kernel_sum_rows(
}
kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -366,9 +366,9 @@ kernel void kernel_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float lmax = -INFINITY;
@ -435,14 +435,14 @@ kernel void kernel_soft_max(
}
kernel void kernel_soft_max_4(
device const float * src0,
device const float * src1,
device float * dst,
device const char * src0,
device const char * src1,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
@ -452,15 +452,15 @@ kernel void kernel_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
// parallel max
float4 lmax4 = -INFINITY;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@ -486,7 +486,7 @@ kernel void kernel_soft_max_4(
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
@ -2144,13 +2144,11 @@ kernel void kernel_flash_attn_ext_f16(
}
}
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
// pointer to the mask
device const float * mp = (device const float *) (mask + (ir%ne31)*nb31);
device const half * mp = (device const half *) (mask + iq1*nb31);
// prepare diagonal scale matrix
simdgroup_float8x8 mscale(scale);
simdgroup_half8x8 mscale(scale);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
@ -2176,8 +2174,8 @@ kernel void kernel_flash_attn_ext_f16(
// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q8; ++j) {
simdgroup_float8x8 mm;
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(float)) + ic + 8*cc, nb31/sizeof(float), 0, false);
simdgroup_half8x8 mm;
simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);