This commit is contained in:
Georgi Gerganov 2024-11-06 22:04:07 +02:00
parent 01c7f11224
commit 2335086fd3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8; const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) { float4x4 reg_f;
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
} }
reg = (type4x4) reg_f;
} }
template <typename type4x4> template <typename type4x4>
@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8; const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) { float4x4 reg_f;
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
} }
reg = (type4x4) reg_f;
} }
template <typename type4x4> template <typename type4x4>
@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
const int gh_mv = il ? 12 : 0; const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4; const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1 // extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[i/2][2*(i%2)+0] = d * x0 + md; reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg[i/2][2*(i%2)+1] = d * x1 + md; reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
} }
reg = (type4x4) reg_f;
} }
template <typename type4x4> template <typename type4x4>
@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
const int gh_mv = il ? 12 : 0; const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4; const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1 // extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[i/2][2*(i%2)+0] = d * x0 + m; reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg[i/2][2*(i%2)+1] = d * x1 + m; reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
} }
reg = (type4x4) reg_f;
} }
template <typename type4x4> template <typename type4x4>
@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
device const int8_t * qs = ((device const int8_t *)xb->qs); device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d; const half d = xb->d;
float4x4 reg_f;
for (int i = 0; i < 16; i++) { for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (qs[i + 16*il] * d); reg_f[i/4][i%4] = (qs[i + 16*il] * d);
} }
reg = (type4x4) reg_f;
} }
template <typename type4x4> template <typename type4x4>
@ -2768,6 +2788,7 @@ template<
typename s_t, // attention accumulation types typename s_t, // attention accumulation types
typename s8x8_t, typename s8x8_t,
typename o_t, typename o_t,
typename o4_t,
typename o8x8_t, typename o8x8_t,
typename block_q, typename block_q,
short nl_k, short nl_k,
@ -2835,6 +2856,7 @@ kernel void kernel_flash_attn_ext(
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@ -2854,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
if (iq1 + j < ne01) { if (iq1 + j < ne01) {
sq4[j*T4 + i] = (q4_t) q4[i]; sq4[j*T4 + i] = (q4_t) q4[i];
} else { } else {
sq4[j*T4 + i] = 0.0h; sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
} }
} }
} }
@ -2867,7 +2889,7 @@ kernel void kernel_flash_attn_ext(
// zero out shared memory SH // zero out shared memory SH
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) { for (short i = tiisg; i < SH; i += NW) {
ss[j*TS + i] = 0.0f; ss[j*TS + i] = (s_t) 0.0f;
} }
} }
@ -3024,12 +3046,12 @@ kernel void kernel_flash_attn_ext(
S[j] = S[j]*ms[j] + simd_sum(vs); S[j] = S[j]*ms[j] + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns) // the P matrix from the paper (Q rows, C columns)
ss[j*TS + tiisg] = vs; ss[j*TS + tiisg] = (s_t) vs;
} }
// create a QxQ diagonal matrix for rescaling the output // create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) { if (tiisg < Q) {
ss[tiisg*TS + C + tiisg] = ms[tiisg]; ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
} }
} }
@ -3114,8 +3136,8 @@ kernel void kernel_flash_attn_ext(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer) // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
if (tiisg == 0) { if (tiisg == 0) {
ss[j*TS + 0] = S[j]; ss[j*TS + 0] = (s_t) S[j];
ss[j*TS + 1] = M[j]; ss[j*TS + 1] = (s_t) M[j];
} }
} }
} }
@ -3139,11 +3161,11 @@ kernel void kernel_flash_attn_ext(
// the first simdgroup accumulates the results from the other simdgroups // the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) { if (sgitg == 0) {
for (short j = 0; j < Q; ++j) { for (short j = 0; j < Q; ++j) {
const float S0 = ss[j*TS + 0]; const float S0 = (s_t) ss[j*TS + 0];
const float S1 = ss[j*TS + sg*SH + 0]; const float S1 = (s_t) ss[j*TS + sg*SH + 0];
const float M0 = ss[j*TS + 1]; const float M0 = (s_t) ss[j*TS + 1];
const float M1 = ss[j*TS + sg*SH + 1]; const float M1 = (s_t) ss[j*TS + sg*SH + 1];
M = max(M0, M1); M = max(M0, M1);
@ -3153,11 +3175,11 @@ kernel void kernel_flash_attn_ext(
S = S0*ms0 + S1*ms1; S = S0*ms0 + S1*ms1;
if (tiisg == 0) { if (tiisg == 0) {
ss[j*TS + 0] = S; ss[j*TS + 0] = (s_t) S;
ss[j*TS + 1] = M; ss[j*TS + 1] = (s_t) M;
ss[j*TS + C + j ] = ms0; ss[j*TS + C + j ] = (s_t) ms0;
ss[j*TS + C + j + sg*SH] = ms1; ss[j*TS + C + j + sg*SH] = (s_t) ms1;
} }
} }
@ -3196,7 +3218,7 @@ kernel void kernel_flash_attn_ext(
const float S = ss[j*TS + 0]; const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) { for (short i = tiisg; i < D4; i += NW) {
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
} }
} }
} }
@ -3213,7 +3235,7 @@ kernel void kernel_flash_attn_ext(
half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \
half, simdgroup_half8x8, \ half, simdgroup_half8x8, \
half, simdgroup_half8x8 half, half4, simdgroup_half8x8
#else #else
#define S_T float #define S_T float
#define S4_T float4 #define S4_T float4
@ -3225,9 +3247,11 @@ kernel void kernel_flash_attn_ext(
half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \ float, simdgroup_float8x8, \
half, simdgroup_half8x8 half, half4, simdgroup_half8x8
#endif #endif
// TODO: static_assert
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t; typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;