metal: improvement for Q4_K driver
This commit is contained in:
parent
804c78dcc9
commit
aa4b7d29a2
1 changed files with 49 additions and 21 deletions
|
@ -972,25 +972,42 @@ template <typename addr_uint16_p,typename addr_block_q_p, typename type4x4>
|
|||
class q4_K_driver {
|
||||
public:
|
||||
uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2;
|
||||
float coef1, coef2, sumy;
|
||||
float coef1, coef2, sumy1, sumy2;
|
||||
uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset;
|
||||
|
||||
void init(int il) {
|
||||
q_offset = (il/4) * 16 + 4 * (il%4);
|
||||
d_mask1 = il < 8 ? 0x3F3F : 0x0F0F; d_mask2 = il < 8 ? 0x0000 : 0xC0C0;
|
||||
d_loc1 = il < 8 ? il/4 : il/4 + 2; d_loc2 = il < 8 ? il/4 : il/4 - 2;
|
||||
m_mask1 = il < 8 ? 0x3F3F : 0xF0F0;
|
||||
m_loc1 = il/4 + 2; m_loc2 = il/4;
|
||||
}
|
||||
|
||||
void get_scales(addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) {
|
||||
#if QK_K == 256
|
||||
const float d = (float)(xb->d);
|
||||
const float min = (float)(xb->dmin);
|
||||
addr_uint16_p sc = (addr_uint16_p)xb->scales;
|
||||
uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2);
|
||||
uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4);
|
||||
m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2);
|
||||
dl1 = as_type<uchar2>(d_int)[0] * d, ml1 = as_type<uchar2>(m_int)[0] * min;
|
||||
dl2 = as_type<uchar2>(d_int)[1] * d, ml2 = as_type<uchar2>(m_int)[1] * min;
|
||||
#else
|
||||
dl1 = (float)(xb->d[0]) * (xb->scales[0]&0xF); dl2 = (float)(xb->d[0]) * (xb->scales[1]&0xF);
|
||||
ml1 = (float)(xb->d[1]) * (xb->scales[0]>>4); ml2 = (float)(xb->d[1]) * (xb->scales[1]>>4);
|
||||
#endif
|
||||
}
|
||||
|
||||
void get_scales2(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
|
||||
q_offset = (il/4) * 16 + 8 * (il&1);
|
||||
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
|
||||
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
|
||||
#if QK_K == 256
|
||||
d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192;
|
||||
d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4;
|
||||
m_mask1 = il < 8 ? 63 : 0xF0;
|
||||
m_loc1 = il/2 + 4; m_loc2 = il/2;
|
||||
mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8;
|
||||
coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f;
|
||||
#if QK_K == 256
|
||||
q_offset = (il/4) * 16 + 8 * (il&1);
|
||||
#else
|
||||
q_offset = 8 * (il&1);
|
||||
#endif
|
||||
}
|
||||
|
||||
void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) {
|
||||
#if QK_K == 256
|
||||
const float d = (float)(xb->d);
|
||||
const float min = (float)(xb->dmin);
|
||||
uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2);
|
||||
|
@ -1004,23 +1021,34 @@ class q4_K_driver {
|
|||
}
|
||||
|
||||
void inner_product_pre(int il, thread float4x4 & yl){
|
||||
fix_y_v2(coef1, coef2, sumy, yl);
|
||||
sumy1 = 0.f; sumy2 = 0.f;
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
sumy1 += yl[i/4 ][i%4]; sumy1 += yl[i/4 ][i%4+1];
|
||||
sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1];
|
||||
yl[i/4 ][i%4 ] = yl[i/4][i%4];
|
||||
yl[i/4 ][i%4+1] = 1/256.f * yl[i/4][i%4+1];
|
||||
yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4];
|
||||
yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1];
|
||||
}
|
||||
}
|
||||
|
||||
void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){
|
||||
float dl, ml;
|
||||
get_scales(xb, il, dl, ml);
|
||||
float dl1, ml1, dl2, ml2;
|
||||
float sum2 = 0.f;
|
||||
get_scales(xb, il, dl1, ml1, dl2, ml2);
|
||||
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
|
||||
for (int i = 0; i < 16; i += 2) {
|
||||
sum += yl[i/4][i%4] * (q[i/2] & mask1);
|
||||
sum += yl[i/4][i%4+1] * (q[i/2] & mask2);
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
sum += yl[i/4 ][i%4 ] * ((q[i/2]&0x000F));
|
||||
sum += yl[i/4 ][i%4+1] * ((q[i/2]&0x0F00));
|
||||
sum2 += yl[i/4+2][i%4 ] * ((q[i/2]&0x00F0));
|
||||
sum2 += yl[i/4+2][i%4+1] * ((q[i/2]&0xF000));
|
||||
}
|
||||
sum = dl * sum - ml * sumy;
|
||||
sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2;
|
||||
}
|
||||
|
||||
void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) {
|
||||
float dl, ml;
|
||||
get_scales(xb, il, dl, ml);
|
||||
get_scales2(xb, il, dl, ml);
|
||||
addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset;
|
||||
for (int i = 0; i < 16; i += 2) {
|
||||
reg[i/4][i%4] = coef1 * dl * (q[i/2] & mask1) - ml;
|
||||
|
@ -1465,7 +1493,7 @@ template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv<b
|
|||
template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv<block_q8_0, N_DST, N_SIMDGROUP, 2, 8, q8_0_driver>;
|
||||
template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8, q2_K_driver>;
|
||||
template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8, q3_K_driver>;
|
||||
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 8, q4_K_driver>;
|
||||
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;
|
||||
template [[host_name("kernel_mul_mv_q5_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q5_K, N_DST, N_SIMDGROUP, QK_NL, 8, q5_K_driver>;
|
||||
#if QK_K == 256
|
||||
template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q6_K, N_DST, N_SIMDGROUP, QK_NL, 64, q6_K_driver>;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue