From 1e903f6b821232d607514038d4bede8099cadc53 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 6 Jun 2023 20:07:37 +0300 Subject: [PATCH] Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. --- .clang-tidy | 18 ------------------ Makefile | 4 ++-- ggml-metal.m | 4 ++-- ggml-metal.metal | 32 +++++++++++++++++++++++--------- 4 files changed, 27 insertions(+), 31 deletions(-) delete mode 100644 .clang-tidy diff --git a/.clang-tidy b/.clang-tidy deleted file mode 100644 index 1a42b9abc..000000000 --- a/.clang-tidy +++ /dev/null @@ -1,18 +0,0 @@ ---- -Checks: > - bugprone-*, - -bugprone-easily-swappable-parameters, - -bugprone-implicit-widening-of-multiplication-result, - -bugprone-narrowing-conversions, - readability-*, - -readability-avoid-unconditional-preprocessor-if, - -readability-function-cognitive-complexity, - -readability-identifier-length, - -readability-implicit-bool-conversion, - -readability-magic-numbers, - -readability-uppercase-literal-suffix, - clang-analyzer-*, - -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, - performance-*, - portability-*, -FormatStyle: none diff --git a/Makefile b/Makefile index 0205f1959..40733b34c 100644 --- a/Makefile +++ b/Makefile @@ -41,8 +41,8 @@ endif # keep standard at C11 and C++11 # -Ofast tends to produce faster code, but may not be available for some compilers. -#OPT = -Ofast -OPT = -O3 +OPT = -Ofast +#OPT = -O3 CFLAGS = -I. $(OPT) -std=c11 -fPIC CXXFLAGS = -I. 
-I./examples $(OPT) -std=c++11 -fPIC LDFLAGS = diff --git a/ggml-metal.m b/ggml-metal.m index d5af5dc30..9e2454391 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -526,8 +526,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 2; - nth1 = 32; + nth0 = 4; + nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; } break; case GGML_TYPE_F16: diff --git a/ggml-metal.metal b/ggml-metal.metal index 8b4531521..c82d4717e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1,3 +1,4 @@ +// 34.7 ms / token #include <metal_stdlib> using namespace metal; @@ -50,6 +51,19 @@ static inline uchar2 get_scale_min_k4(int j, device const uint8_t * q) { } return r; } +static inline uchar4 get_scale_min_k4_2(int j, device const uint8_t * q) { + uchar4 r; + if (j < 4) { + r[0] = q[j+0] & 63; r[1] = q[j+4] & 63; + r[2] = q[j+1] & 63; r[3] = q[j+5] & 63; + } else { + r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); + r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); + } + return r; +} static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { assert(k % QK_K == 0); @@ -412,10 +426,10 @@ kernel void kernel_mul_mat_q4_k_f32( const uint nth = tptg.x*tptg.y; const uint ith = tptg.y*tpitg.x + tpitg.y; - const int tid = tpitg.y; - const int il = tid/8; - const int ir = tid%8; - const int n = 4; + const int tid = tpitg.y; // 0...15 + const int il = tid/4; // 0...3 + const int ir = tid%4; // 0...3 + const int n = 8; const int is = 2*il; sum[ith] = 0.0f; @@ -430,14 +444,14 @@ kernel void kernel_mul_mat_q4_k_f32( const float dall = (float)((x + i)->d); const float dmin = (float)((x + i)->dmin); - const uchar2 sc1 = get_scale_min_k4(is, scales); - const float d1 = dall * sc1[0]; const float m1 = dmin * sc1[1]; - const uchar2 sc2 = get_scale_min_k4(is+1, scales); - const float d2 = dall * sc2[0]; const float m2 = dmin * sc2[1]; + const 
uchar4 sc = get_scale_min_k4_2(is, scales); + float4 s = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < n; ++l) { - sumf += y[l] * (d1 * (q[l] & 0xF) - m1) + y[l+32] * (d2 * (q[l] >> 4) - m2); + s[0] += y[l] * (q[l] & 0xF); s[1] += y[l]; + s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32]; } + sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]); } sum[ith] = sumf;