From aa237447594151d2e3da6a829e32321870c06fc8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 11 Jun 2023 22:45:24 +0300 Subject: [PATCH] metal : Q3_K cleanup --- ggml-metal.m | 2 +- ggml-metal.metal | 67 +++++++++++++----------------------------------- 2 files changed, 19 insertions(+), 50 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 4ad9bca1c..ae56ce2e1 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -579,7 +579,7 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne12 == 1); nth0 = 4; - nth1 = 16; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; } break; case GGML_TYPE_Q3_K: diff --git a/ggml-metal.metal b/ggml-metal.metal index 433319bb0..09e12a879 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -724,10 +724,9 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i assert(k % QK_K == 0); const int nb = k / QK_K; - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; - //uint32_t aux[4]; uint16_t aux[8]; thread const int8_t * scales = (thread const int8_t*)aux; @@ -739,12 +738,6 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i device const uint8_t * h = x[i].hmask; uint8_t m = 1; - //device const uint32_t * a = (device const uint32_t *)x[i].scales; - //aux[0] = (a[0] & kmask2) | (((a[2] >> 0) & kmask1) << 4); - //aux[1] = (a[1] & kmask2) | (((a[2] >> 2) & kmask1) << 4); - //aux[2] = ((a[0] >> 4) & kmask2) | (((a[2] >> 4) & kmask1) << 4); - //aux[3] = ((a[1] >> 4) & kmask2) | (((a[2] >> 6) & kmask1) << 4); - device const uint16_t * a = (device const uint16_t *)x[i].scales; aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); @@ -1061,8 +1054,6 @@ kernel void kernel_mul_mat_q3_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - //const uint32_t kmask1 = 0x03030303; - //const uint32_t kmask2 = 0x0f0f0f0f; const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -1090,28 +1081,16 @@ kernel void kernel_mul_mat_q3_k_f32( const uint8_t m = 1 << (4*ip + il); const int shift = 2*il; - const int is = 8*ip + 2*il; - // is = 0 -> shift1 = 0, shift2 = 0, accessing aux[0] -> (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); ip = 0, il = 0 - // is = 2 -> shift1 = 0, shift2 = 0, accessing aux[1] -> (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); ip = 0, il = 1 - // is = 4 -> shift1 = 0, shift2 = 2, accessing aux[2] -> (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4); ip = 0, il = 2 - // is = 6 -> shift1 = 0, shift2 = 2, accessing aux[3] -> (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4); ip = 0, il = 3 - // is = 8 -> shift1 = 4, shift2 = 4, accessing aux[4] -> ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4); ip = 1, il = 0 - // is =10 -> shift1 = 4, shift2 = 4, accessing aux[5] -> ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4); ip = 1, il = 1 - // is =12 -> shift1 = 4, shift2 = 6, accessing aux[6] -> ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); ip = 1, il = 2 - // is =14 -> shift1 = 4, shift2 = 6, accessing aux[7] -> ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); ip = 1, il = 3 - - const int s_shift1 = 4*ip; - const int s_shift2 = s_shift1 + 2*(il/2); + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + 2*(il/2); const int ik = 4 + (il%2); - uint16_t aux[8]; - thread const int8_t * scales = (thread const int8_t*)aux; - const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - float sumf = 0; + //float sumf = 0; + float sumf1 = 0, sumf2 = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { const float d_all = (float)(x[i].d); @@ -1125,36 +1104,26 @@ kernel void kernel_mul_mat_q3_k_f32( float s = 0; for (int l = 0; l < n; ++l) { - s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); + s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); } - sumf += s * d_all * (scales[0] - 32); + float d = d_all * s; + sumf1 += d * scales[0]; + sumf2 += d; + //sumf += d_all * s * (scales[0] - 32); s = 0; for (int l = 0; l < n; ++l) { - s += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); + s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); } - sumf += s * d_all * (scales[1] - 32); - - //const float d1 = d_all * (scales[0] - 32); - //for (int l = 0; l < n; ++l) { - // sumf += y[l+ 0] * d1 * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); - //} - - //const float d2 = d_all * (scales[1] - 32); - //for (int l = 0; l < n; ++l) { - // sumf += y[l+16] * d2 * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); - //} - - //float2 s = {0.f, 0.f}; - //for (int l = 0; l < n; ++l) { - // s[0] += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); - // s[1] += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); - //} - //sumf += d_all * (s[0] * (scales[0] - 32) + s[1] * (scales[1] - 32)); + d = d_all * s; + sumf1 += d * scales[1]; + sumf2 += d; + //sumf += d_all * s * (scales[1] - 32); } - sum[ith] = sumf; + //sum[ith] = sumf; + sum[ith] = sumf1 - 32.f*sumf2; // // Accumulate the sum from all threads in the threadgroup