metal: improvement for Q4_K driver

This commit is contained in:
lshzh-ww 2023-08-30 23:05:22 -04:00
parent 804c78dcc9
commit aa4b7d29a2

View file

@ -972,25 +972,42 @@ template <typename addr_uint16_p,typename addr_block_q_p, typename type4x4>
class q4_K_driver {
public:
uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2;
float coef1, coef2, sumy;
float coef1, coef2, sumy1, sumy2;
uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset;
void init(int il) {
q_offset = (il/4) * 16 + 4 * (il%4);
d_mask1 = il < 8 ? 0x3F3F : 0x0F0F; d_mask2 = il < 8 ? 0x0000 : 0xC0C0;
d_loc1 = il < 8 ? il/4 : il/4 + 2; d_loc2 = il < 8 ? il/4 : il/4 - 2;
m_mask1 = il < 8 ? 0x3F3F : 0xF0F0;
m_loc1 = il/4 + 2; m_loc2 = il/4;
}
void get_scales(addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) {
#if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
addr_uint16_p sc = (addr_uint16_p)xb->scales;
uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2);
uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4);
m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2);
dl1 = as_type<uchar2>(d_int)[0] * d, ml1 = as_type<uchar2>(m_int)[0] * min;
dl2 = as_type<uchar2>(d_int)[1] * d, ml2 = as_type<uchar2>(m_int)[1] * min;
#else
dl1 = (float)(xb->d[0]) * (xb->scales[0]&0xF); dl2 = (float)(xb->d[0]) * (xb->scales[1]&0xF);
ml1 = (float)(xb->d[1]) * (xb->scales[0]>>4); ml2 = (float)(xb->d[1]) * (xb->scales[1]>>4);
#endif
}
void get_scales2(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
q_offset = (il/4) * 16 + 8 * (il&1);
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
#if QK_K == 256
d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192;
d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4;
m_mask1 = il < 8 ? 63 : 0xF0;
m_loc1 = il/2 + 4; m_loc2 = il/2;
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
#if QK_K == 256
q_offset = (il/4) * 16 + 8 * (il&1);
#else
q_offset = 8 * (il&1);
#endif
}
void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
#if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2);
@ -1004,23 +1021,34 @@ class q4_K_driver {
}
void inner_product_pre(int il, thread float4x4 & yl){
fix_y_v2(coef1, coef2, sumy, yl);
sumy1 = 0.f; sumy2 = 0.f;
for (int i = 0; i < 8; i += 2) {
sumy1 += yl[i/4 ][i%4]; sumy1 += yl[i/4 ][i%4+1];
sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1];
yl[i/4 ][i%4 ] = yl[i/4][i%4];
yl[i/4 ][i%4+1] = 1/256.f * yl[i/4][i%4+1];
yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4];
yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1];
}
}
void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){
float dl, ml;
get_scales(xb, il, dl, ml);
float dl1, ml1, dl2, ml2;
float sum2 = 0.f;
get_scales(xb, il, dl1, ml1, dl2, ml2);
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
for (int i = 0; i < 16; i += 2) {
sum += yl[i/4][i%4] * (q[i/2] & mask1);
sum += yl[i/4][i%4+1] * (q[i/2] & mask2);
for (int i = 0; i < 8; i += 2) {
sum += yl[i/4 ][i%4 ] * ((q[i/2]&0x000F));
sum += yl[i/4 ][i%4+1] * ((q[i/2]&0x0F00));
sum2 += yl[i/4+2][i%4 ] * ((q[i/2]&0x00F0));
sum2 += yl[i/4+2][i%4+1] * ((q[i/2]&0xF000));
}
sum = dl * sum - ml * sumy;
sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2;
}
void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) {
float dl, ml;
get_scales(xb, il, dl, ml);
get_scales2(xb, il, dl, ml);
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
for (int i = 0; i < 16; i += 2) {
reg[i/4][i%4] = coef1 * dl * (q[i/2] & mask1) - ml;
@ -1465,7 +1493,7 @@ template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<b
template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2, 8, q8_0_driver>;
template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8, q2_K_driver>;
template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8, q3_K_driver>;
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 8, q4_K_driver>;
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;
template [[host_name("kernel_mul_mv_q5_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q5_K, N_DST, N_SIMDGROUP, QK_NL, 8, q5_K_driver>;
#if QK_K == 256
template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q6_K, N_DST, N_SIMDGROUP, QK_NL, 64, q6_K_driver>;