diff --git a/ggml-metal.metal b/ggml-metal.metal index 2a6fcd616..afc8977d6 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -972,25 +972,42 @@ template class q4_K_driver { public: uint16_t d_mask1, d_mask2, m_mask1, mask1, mask2; - float coef1, coef2, sumy; + float coef1, coef2, sumy1, sumy2; uint16_t d_loc1, d_loc2, m_loc1, m_loc2, q_offset; void init(int il) { + q_offset = (il/4) * 16 + 4 * (il%4); + d_mask1 = il < 8 ? 0x3F3F : 0x0F0F; d_mask2 = il < 8 ? 0x0000 : 0xC0C0; + d_loc1 = il < 8 ? il/4 : il/4 + 2; d_loc2 = il < 8 ? il/4 : il/4 - 2; + m_mask1 = il < 8 ? 0x3F3F : 0xF0F0; + m_loc1 = il/4 + 2; m_loc2 = il/4; + } + + void get_scales(addr_block_q_p xb, int il, thread float & dl1, thread float & ml1, thread float & dl2, thread float & ml2) { + #if QK_K == 256 + const float d = (float)(xb->d); + const float min = (float)(xb->dmin); + addr_uint16_p sc = (addr_uint16_p)xb->scales; + uint16_t d_int = (sc[d_loc1] & d_mask1) | ((sc[d_loc2] & d_mask2) >> 2); + uint16_t m_int = il < 8 ? (sc[m_loc1] & m_mask1) : ((sc[m_loc1] & m_mask1) >> 4); + m_int = m_int | ((sc[m_loc2] & d_mask2) >> 2); + dl1 = as_type(d_int)[0] * d, ml1 = as_type(m_int)[0] * min; + dl2 = as_type(d_int)[1] * d, ml2 = as_type(m_int)[1] * min; +#else + dl1 = (float)(xb->d[0]) * (xb->scales[0]&0xF); dl2 = (float)(xb->d[0]) * (xb->scales[1]&0xF); + ml1 = (float)(xb->d[1]) * (xb->scales[0]>>4); ml2 = (float)(xb->d[1]) * (xb->scales[1]>>4); +#endif + } + + void get_scales2(addr_block_q_p xb, int il, thread float & dl, thread float & ml) { + q_offset = (il/4) * 16 + 8 * (il&1); + mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8; + coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f; +#if QK_K == 256 d_mask1 = il < 8 ? 63 : 0x0F; d_mask2 = il < 8 ? 0 : 192; d_loc1 = il < 8 ? il/2 : 4 + il/2; d_loc2 = il < 8 ? il/2 : il/2 - 4; m_mask1 = il < 8 ? 63 : 0xF0; m_loc1 = il/2 + 4; m_loc2 = il/2; - mask1 = (il%4) < 2 ? 0x000F : 0x00F0; mask2 = mask1 << 8; - coef1 = (il%4) < 2 ? 1.f : 1/16.f; coef2 = coef1 / 256.f; -#if QK_K == 256 - q_offset = (il/4) * 16 + 8 * (il&1); -#else - q_offset = 8 * (il&1); -#endif - } - - void get_scales(addr_block_q_p xb, int il, thread float & dl, thread float & ml) { -#if QK_K == 256 const float d = (float)(xb->d); const float min = (float)(xb->dmin); uint16_t d_int = (xb->scales[d_loc1] & d_mask1) | ((xb->scales[d_loc2] & d_mask2) >> 2); @@ -1004,23 +1021,34 @@ class q4_K_driver { } void inner_product_pre(int il, thread float4x4 & yl){ - fix_y_v2(coef1, coef2, sumy, yl); + sumy1 = 0.f; sumy2 = 0.f; + for (int i = 0; i < 8; i += 2) { + sumy1 += yl[i/4 ][i%4]; sumy1 += yl[i/4 ][i%4+1]; + sumy2 += yl[2+i/4][i%4]; sumy2 += yl[2+i/4][i%4+1]; + yl[i/4 ][i%4 ] = yl[i/4][i%4]; + yl[i/4 ][i%4+1] = 1/256.f * yl[i/4][i%4+1]; + yl[i/4+2][i%4 ] = 1/16.f * yl[2+i/4][i%4]; + yl[i/4+2][i%4+1] = 1/4096.f * yl[2+i/4][i%4+1]; + } } void inner_product(addr_block_q_p xb, int il, thread float4x4 & yl, thread float & sum){ - float dl, ml; - get_scales(xb, il, dl, ml); + float dl1, ml1, dl2, ml2; + float sum2 = 0.f; + get_scales(xb, il, dl1, ml1, dl2, ml2); addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; - for (int i = 0; i < 16; i += 2) { - sum += yl[i/4][i%4] * (q[i/2] & mask1); - sum += yl[i/4][i%4+1] * (q[i/2] & mask2); + for (int i = 0; i < 8; i += 2) { + sum += yl[i/4 ][i%4 ] * ((q[i/2]&0x000F)); + sum += yl[i/4 ][i%4+1] * ((q[i/2]&0x0F00)); + sum2 += yl[i/4+2][i%4 ] * ((q[i/2]&0x00F0)); + sum2 += yl[i/4+2][i%4+1] * ((q[i/2]&0xF000)); } - sum = dl * sum - ml * sumy; + sum = dl1 * sum - ml1 * sumy1 + dl2 * sum2 - ml2 * sumy2; } void dequantize(addr_block_q_p xb, int il, thread type4x4 & reg) { float dl, ml; - get_scales(xb, il, dl, ml); + get_scales2(xb, il, dl, ml); addr_uint16_p q = (addr_uint16_p)xb->qs + q_offset; for (int i = 0; i < 16; i += 2) { reg[i/4][i%4] = coef1 * dl * (q[i/2] & mask1) - ml; @@ -1465,7 +1493,7 @@ template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv; template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv; template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv; -template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv; +template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv; template [[host_name("kernel_mul_mv_q5_K_f32")]] kernel mat_mv_t kernel_mat_mv; #if QK_K == 256 template [[host_name("kernel_mul_mv_q6_K_f32")]] kernel mat_mv_t kernel_mat_mv;