wip2
This commit is contained in:
parent
01c7f11224
commit
2335086fd3
1 changed files with 52 additions and 28 deletions
|
@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
||||
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
||||
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
|
@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
||||
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
||||
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
|
@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
|||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
|
@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
|||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
||||
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
|
@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
|||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
|
@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
|||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
||||
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
|
@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const half d = xb->d;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
|
@ -2768,6 +2788,7 @@ template<
|
|||
typename s_t, // attention accumulation types
|
||||
typename s8x8_t,
|
||||
typename o_t,
|
||||
typename o4_t,
|
||||
typename o8x8_t,
|
||||
typename block_q,
|
||||
short nl_k,
|
||||
|
@ -2835,6 +2856,7 @@ kernel void kernel_flash_attn_ext(
|
|||
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // reuse query data for accumulation
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
|
@ -2854,7 +2876,7 @@ kernel void kernel_flash_attn_ext(
|
|||
if (iq1 + j < ne01) {
|
||||
sq4[j*T4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*T4 + i] = 0.0h;
|
||||
sq4[j*T4 + i] = (q4_t) (float4) 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2867,7 +2889,7 @@ kernel void kernel_flash_attn_ext(
|
|||
// zero out shared memory SH
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
for (short i = tiisg; i < SH; i += NW) {
|
||||
ss[j*TS + i] = 0.0f;
|
||||
ss[j*TS + i] = (s_t) 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3024,12 +3046,12 @@ kernel void kernel_flash_attn_ext(
|
|||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*TS + tiisg] = vs;
|
||||
ss[j*TS + tiisg] = (s_t) vs;
|
||||
}
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (tiisg < Q) {
|
||||
ss[tiisg*TS + C + tiisg] = ms[tiisg];
|
||||
ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3114,8 +3136,8 @@ kernel void kernel_flash_attn_ext(
|
|||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
if (tiisg == 0) {
|
||||
ss[j*TS + 0] = S[j];
|
||||
ss[j*TS + 1] = M[j];
|
||||
ss[j*TS + 0] = (s_t) S[j];
|
||||
ss[j*TS + 1] = (s_t) M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3139,11 +3161,11 @@ kernel void kernel_flash_attn_ext(
|
|||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
const float S0 = ss[j*TS + 0];
|
||||
const float S1 = ss[j*TS + sg*SH + 0];
|
||||
const float S0 = (s_t) ss[j*TS + 0];
|
||||
const float S1 = (s_t) ss[j*TS + sg*SH + 0];
|
||||
|
||||
const float M0 = ss[j*TS + 1];
|
||||
const float M1 = ss[j*TS + sg*SH + 1];
|
||||
const float M0 = (s_t) ss[j*TS + 1];
|
||||
const float M1 = (s_t) ss[j*TS + sg*SH + 1];
|
||||
|
||||
M = max(M0, M1);
|
||||
|
||||
|
@ -3153,11 +3175,11 @@ kernel void kernel_flash_attn_ext(
|
|||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[j*TS + 0] = S;
|
||||
ss[j*TS + 1] = M;
|
||||
ss[j*TS + 0] = (s_t) S;
|
||||
ss[j*TS + 1] = (s_t) M;
|
||||
|
||||
ss[j*TS + C + j ] = ms0;
|
||||
ss[j*TS + C + j + sg*SH] = ms1;
|
||||
ss[j*TS + C + j ] = (s_t) ms0;
|
||||
ss[j*TS + C + j + sg*SH] = (s_t) ms1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3196,7 +3218,7 @@ kernel void kernel_flash_attn_ext(
|
|||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
|
||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3213,7 +3235,7 @@ kernel void kernel_flash_attn_ext(
|
|||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8, \
|
||||
half, simdgroup_half8x8
|
||||
half, half4, simdgroup_half8x8
|
||||
#else
|
||||
#define S_T float
|
||||
#define S4_T float4
|
||||
|
@ -3225,9 +3247,11 @@ kernel void kernel_flash_attn_ext(
|
|||
half, half4x4, simdgroup_half8x8, \
|
||||
half, half4x4, simdgroup_half8x8, \
|
||||
float, simdgroup_float8x8, \
|
||||
half, simdgroup_half8x8
|
||||
half, half4, simdgroup_half8x8
|
||||
#endif
|
||||
|
||||
// TODO: static_assert
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue