Optimizing Q4_K on Metal
The first token always takes longer, presumably because the Metal kernel is being JIT-compiled, so I am using n = 128 to measure time. At this point Q4_K takes 29.5 ms/token compared to 27.2 ms/token for Q4_0. That is quite a bit better than the initial attempt, but still not good enough.
This commit is contained in:
parent
6f8f39fbaf
commit
1e903f6b82
4 changed files with 27 additions and 31 deletions
18
.clang-tidy
18
.clang-tidy
|
@ -1,18 +0,0 @@
|
||||||
---
|
|
||||||
Checks: >
|
|
||||||
bugprone-*,
|
|
||||||
-bugprone-easily-swappable-parameters,
|
|
||||||
-bugprone-implicit-widening-of-multiplication-result,
|
|
||||||
-bugprone-narrowing-conversions,
|
|
||||||
readability-*,
|
|
||||||
-readability-avoid-unconditional-preprocessor-if,
|
|
||||||
-readability-function-cognitive-complexity,
|
|
||||||
-readability-identifier-length,
|
|
||||||
-readability-implicit-bool-conversion,
|
|
||||||
-readability-magic-numbers,
|
|
||||||
-readability-uppercase-literal-suffix,
|
|
||||||
clang-analyzer-*,
|
|
||||||
-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
|
|
||||||
performance-*,
|
|
||||||
portability-*,
|
|
||||||
FormatStyle: none
|
|
4
Makefile
4
Makefile
|
@ -41,8 +41,8 @@ endif
|
||||||
|
|
||||||
# keep standard at C11 and C++11
|
# keep standard at C11 and C++11
|
||||||
# -Ofast tends to produce faster code, but may not be available for some compilers.
|
# -Ofast tends to produce faster code, but may not be available for some compilers.
|
||||||
#OPT = -Ofast
|
OPT = -Ofast
|
||||||
OPT = -O3
|
#OPT = -O3
|
||||||
CFLAGS = -I. $(OPT) -std=c11 -fPIC
|
CFLAGS = -I. $(OPT) -std=c11 -fPIC
|
||||||
CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
|
CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC
|
||||||
LDFLAGS =
|
LDFLAGS =
|
||||||
|
|
|
@ -526,8 +526,8 @@ void ggml_metal_graph_compute(
|
||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
GGML_ASSERT(ne12 == 1);
|
GGML_ASSERT(ne12 == 1);
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 4;
|
||||||
nth1 = 32;
|
nth1 = 16;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
// 34.7 ms / token
|
||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
@ -50,6 +51,19 @@ static inline uchar2 get_scale_min_k4(int j, device const uint8_t * q) {
|
||||||
}
|
}
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
// Fetches the packed 6-bit (scale, min) values for two consecutive entries,
// j and j+1, of the byte array q in a single call.
//
// Result layout:
//   r[0], r[1] = scale, min for entry j
//   r[2], r[3] = scale, min for entry j+1
// (the caller multiplies r[0]/r[2] by the block's d and r[1]/r[3] by dmin).
//
// NOTE(review): both entries are decoded by the same branch, so j == 3 would
// decode entry 4 with the "j < 4" formula. The visible caller only passes
// even j (is = 2*il), which keeps j and j+1 in the same branch; confirm
// before calling this with odd j.
static inline uchar4 get_scale_min_k4_2(int j, device const uint8_t * q) {
|
||||||
|
uchar4 r;
|
||||||
|
if (j < 4) { // values are the low 6 bits of single bytes
|
||||||
|
r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
|
||||||
|
r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
|
||||||
|
} else { // low 4 bits come from a nibble of q[j+4] / q[j+5], high 2 bits from the top 2 bits of an earlier byte
|
||||||
|
r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
|
||||||
|
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
||||||
|
r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
|
||||||
|
r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
|
static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
|
||||||
assert(k % QK_K == 0);
|
assert(k % QK_K == 0);
|
||||||
|
@ -412,10 +426,10 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
const uint nth = tptg.x*tptg.y;
|
const uint nth = tptg.x*tptg.y;
|
||||||
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
const uint ith = tptg.y*tpitg.x + tpitg.y;
|
||||||
|
|
||||||
const int tid = tpitg.y;
|
const int tid = tpitg.y; // 0...15
|
||||||
const int il = tid/8;
|
const int il = tid/4; // 0...3
|
||||||
const int ir = tid%8;
|
const int ir = tid%4; // 0...3
|
||||||
const int n = 4;
|
const int n = 8;
|
||||||
const int is = 2*il;
|
const int is = 2*il;
|
||||||
|
|
||||||
sum[ith] = 0.0f;
|
sum[ith] = 0.0f;
|
||||||
|
@ -430,14 +444,14 @@ kernel void kernel_mul_mat_q4_k_f32(
|
||||||
const float dall = (float)((x + i)->d);
|
const float dall = (float)((x + i)->d);
|
||||||
const float dmin = (float)((x + i)->dmin);
|
const float dmin = (float)((x + i)->dmin);
|
||||||
|
|
||||||
const uchar2 sc1 = get_scale_min_k4(is, scales);
|
const uchar4 sc = get_scale_min_k4_2(is, scales);
|
||||||
const float d1 = dall * sc1[0]; const float m1 = dmin * sc1[1];
|
|
||||||
const uchar2 sc2 = get_scale_min_k4(is+1, scales);
|
|
||||||
const float d2 = dall * sc2[0]; const float m2 = dmin * sc2[1];
|
|
||||||
|
|
||||||
|
float4 s = {0.f, 0.f, 0.f, 0.f};
|
||||||
for (int l = 0; l < n; ++l) {
|
for (int l = 0; l < n; ++l) {
|
||||||
sumf += y[l] * (d1 * (q[l] & 0xF) - m1) + y[l+32] * (d2 * (q[l] >> 4) - m2);
|
s[0] += y[l] * (q[l] & 0xF); s[1] += y[l];
|
||||||
|
s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
|
||||||
}
|
}
|
||||||
|
sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
|
||||||
|
|
||||||
}
|
}
|
||||||
sum[ith] = sumf;
|
sum[ith] = sumf;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue