Optimizing Q4_K on metal

The first token always takes longer, I guess because
the metal kernel is being jit-compiled.
So, using n = 128 to measure time.

At this point Q4_K takes 29.5 ms / token
compared to 27.2 ms / token for Q4_0.
Quite a bit better than the initial attempt,
but still not good enough.
This commit is contained in:
Iwan Kawrakow 2023-06-06 20:07:37 +03:00
parent 6f8f39fbaf
commit 1e903f6b82
4 changed files with 27 additions and 31 deletions

View file

@ -1,18 +0,0 @@
---
Checks: >
bugprone-*,
-bugprone-easily-swappable-parameters,
-bugprone-implicit-widening-of-multiplication-result,
-bugprone-narrowing-conversions,
readability-*,
-readability-avoid-unconditional-preprocessor-if,
-readability-function-cognitive-complexity,
-readability-identifier-length,
-readability-implicit-bool-conversion,
-readability-magic-numbers,
-readability-uppercase-literal-suffix,
clang-analyzer-*,
-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
performance-*,
portability-*,
FormatStyle: none

View file

@ -41,8 +41,8 @@ endif
# keep standard at C11 and C++11 # keep standard at C11 and C++11
# -Ofast tends to produce faster code, but may not be available for some compilers. # -Ofast tends to produce faster code, but may not be available for some compilers.
#OPT = -Ofast OPT = -Ofast
OPT = -O3 #OPT = -O3
CFLAGS = -I. $(OPT) -std=c11 -fPIC CFLAGS = -I. $(OPT) -std=c11 -fPIC
CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
LDFLAGS = LDFLAGS =

View file

@ -526,8 +526,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1); GGML_ASSERT(ne12 == 1);
nth0 = 2; nth0 = 4;
nth1 = 32; nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
} break; } break;
case GGML_TYPE_F16: case GGML_TYPE_F16:

View file

@ -1,3 +1,4 @@
// 34.7 ms / token
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
@ -50,6 +51,19 @@ static inline uchar2 get_scale_min_k4(int j, device const uint8_t * q) {
} }
return r; return r;
} }
static inline uchar4 get_scale_min_k4_2(int j, device const uint8_t * q) {
uchar4 r;
if (j < 4) {
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
} else {
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
}
return r;
}
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
assert(k % QK_K == 0); assert(k % QK_K == 0);
@ -412,10 +426,10 @@ kernel void kernel_mul_mat_q4_k_f32(
const uint nth = tptg.x*tptg.y; const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y; const uint ith = tptg.y*tpitg.x + tpitg.y;
const int tid = tpitg.y; const int tid = tpitg.y; // 0...16
const int il = tid/8; const int il = tid/4; // 0...3
const int ir = tid%8; const int ir = tid%4; // 0...3
const int n = 4; const int n = 8;
const int is = 2*il; const int is = 2*il;
sum[ith] = 0.0f; sum[ith] = 0.0f;
@ -430,14 +444,14 @@ kernel void kernel_mul_mat_q4_k_f32(
const float dall = (float)((x + i)->d); const float dall = (float)((x + i)->d);
const float dmin = (float)((x + i)->dmin); const float dmin = (float)((x + i)->dmin);
const uchar2 sc1 = get_scale_min_k4(is, scales); const uchar4 sc = get_scale_min_k4_2(is, scales);
const float d1 = dall * sc1[0]; const float m1 = dmin * sc1[1];
const uchar2 sc2 = get_scale_min_k4(is+1, scales);
const float d2 = dall * sc2[0]; const float m2 = dmin * sc2[1];
float4 s = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) { for (int l = 0; l < n; ++l) {
sumf += y[l] * (d1 * (q[l] & 0xF) - m1) + y[l+32] * (d2 * (q[l] >> 4) - m2); s[0] += y[l] * (q[l] & 0xF); s[1] += y[l];
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
} }
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
} }
sum[ith] = sumf; sum[ith] = sumf;