move mask to shared mem
This commit is contained in:
parent
2aedbb354e
commit
984928109c
2 changed files with 32 additions and 23 deletions
|
@ -3277,7 +3277,7 @@ static void ggml_metal_encode_node(
|
||||||
// the shared memory needed for the simdgroups to load the KV cache
|
// the shared memory needed for the simdgroups to load the KV cache
|
||||||
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
||||||
//
|
//
|
||||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
||||||
|
|
||||||
int64_t nsgmax = 2;
|
int64_t nsgmax = 2;
|
||||||
|
|
||||||
|
|
|
@ -2836,8 +2836,8 @@ kernel void kernel_flash_attn_ext(
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint3 ntg[[threads_per_threadgroup]],
|
uint3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const short nsg = ntg.y; // number of simdgroups
|
const short nsg = ntg.y; // number of simdgroups
|
||||||
|
|
||||||
const int iq3 = tgpig[2];
|
const int iq3 = tgpig[2];
|
||||||
|
@ -2848,7 +2848,7 @@ kernel void kernel_flash_attn_ext(
|
||||||
const short D8 = D/8;
|
const short D8 = D/8;
|
||||||
const short D16 = D/16;
|
const short D16 = D/16;
|
||||||
const short NW = N_SIMDWIDTH;
|
const short NW = N_SIMDWIDTH;
|
||||||
const short SH = (C + Q); // shared memory per simdgroup in (half)
|
const short SH = (2*C + Q); // shared memory per simdgroup in (half)
|
||||||
|
|
||||||
const short SF = sizeof(s_t)/sizeof(half);
|
const short SF = sizeof(s_t)/sizeof(half);
|
||||||
|
|
||||||
|
@ -2933,9 +2933,6 @@ kernel void kernel_flash_attn_ext(
|
||||||
simdgroup_load(mq[i], sq + i*8, T);
|
simdgroup_load(mq[i], sq + i*8, T);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pointer to the mask
|
|
||||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
|
||||||
|
|
||||||
const bool has_mask = mask != q;
|
const bool has_mask = mask != q;
|
||||||
|
|
||||||
float slope = 1.0f;
|
float slope = 1.0f;
|
||||||
|
@ -2958,6 +2955,26 @@ kernel void kernel_flash_attn_ext(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_mask) {
|
||||||
|
// used to detect blocks full of -INF
|
||||||
|
half smax = -INFINITY;
|
||||||
|
|
||||||
|
for (short j = 0; j < Q; ++j) {
|
||||||
|
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
||||||
|
|
||||||
|
const half m = pm[ic + tiisg];
|
||||||
|
|
||||||
|
ss[j*TS + C + tiisg] = m;
|
||||||
|
smax = max(smax, m);
|
||||||
|
}
|
||||||
|
|
||||||
|
smax = simd_max(smax);
|
||||||
|
|
||||||
|
if (smax == -INFINITY) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Q*K^T
|
// Q*K^T
|
||||||
{
|
{
|
||||||
for (short cc = 0; cc < C/8; ++cc) {
|
for (short cc = 0; cc < C/8; ++cc) {
|
||||||
|
@ -3033,9 +3050,6 @@ kernel void kernel_flash_attn_ext(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// used to detect blocks full of -INF
|
|
||||||
float smax = -INFINITY;
|
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
{
|
{
|
||||||
float ms[Q];
|
float ms[Q];
|
||||||
|
@ -3052,10 +3066,10 @@ kernel void kernel_flash_attn_ext(
|
||||||
|
|
||||||
if (has_mask) {
|
if (has_mask) {
|
||||||
// mqk = mqk + mask*slope
|
// mqk = mqk + mask*slope
|
||||||
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
//s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
|
||||||
|
s += slope*ss[j*TS + C + tiisg];
|
||||||
}
|
}
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
ms[j] = exp(m - M[j]);
|
ms[j] = exp(m - M[j]);
|
||||||
|
@ -3069,19 +3083,14 @@ kernel void kernel_flash_attn_ext(
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (tiisg < Q) {
|
if (tiisg < Q) {
|
||||||
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip -INF blocks
|
|
||||||
if (smax == -INFINITY) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// O = diag(ms)*O
|
// O = diag(ms)*O
|
||||||
{
|
{
|
||||||
s8x8_t mm;
|
s8x8_t mm;
|
||||||
simdgroup_load(mm, ss + C, TS, 0, false);
|
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
|
@ -3199,8 +3208,8 @@ kernel void kernel_flash_attn_ext(
|
||||||
ss[j*TS + 0] = S;
|
ss[j*TS + 0] = S;
|
||||||
ss[j*TS + 1] = M;
|
ss[j*TS + 1] = M;
|
||||||
|
|
||||||
ss[j*TS + C + j ] = ms0;
|
ss[j*TS + 2*C + j ] = ms0;
|
||||||
ss[j*TS + C + j + sg*SH] = ms1;
|
ss[j*TS + 2*C + j + sg*SH] = ms1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3209,8 +3218,8 @@ kernel void kernel_flash_attn_ext(
|
||||||
s8x8_t ms0;
|
s8x8_t ms0;
|
||||||
s8x8_t ms1;
|
s8x8_t ms1;
|
||||||
|
|
||||||
simdgroup_load(ms0, ss + C, TS, 0, false);
|
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
||||||
simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
|
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
||||||
|
|
||||||
for (short i = 0; i < D8; ++i) {
|
for (short i = 0; i < D8; ++i) {
|
||||||
o8x8_t t;
|
o8x8_t t;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue