Q3_K for QK_K = 64
This commit is contained in:
parent
8dba28c00a
commit
0099570f04
2 changed files with 50 additions and 18 deletions
11
ggml-metal.m
11
ggml-metal.m
|
@ -743,10 +743,17 @@ void ggml_metal_graph_compute(
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q5_K) {
|
else if (src0t == GGML_TYPE_Q3_K) {
|
||||||
|
#ifdef GGML_QKK_64
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
#else
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
else if (src0t == GGML_TYPE_Q5_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||||
|
|
|
@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
// broadcast, simd group number is ntg / 32
|
// broadcast, simd group number is ntg / 32
|
||||||
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
|
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||||
if (tpitg < i) {
|
if (tpitg < i) {
|
||||||
sum[tpitg] += sum[tpitg + i];
|
sum[tpitg] += sum[tpitg + i];
|
||||||
}
|
}
|
||||||
|
@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if QK_K == 256
|
||||||
kernel void kernel_mul_mat_q3_K_f32(
|
kernel void kernel_mul_mat_q3_K_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
@ -1363,8 +1364,6 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
|
||||||
float yl[16];
|
float yl[16];
|
||||||
|
|
||||||
#if QK_K == 256
|
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x0303;
|
const uint16_t kmask1 = 0x0303;
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
|
|
||||||
|
@ -1392,7 +1391,7 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
const int q_offset = 32*ip + l0;
|
const int q_offset = 32*ip + l0;
|
||||||
const int y_offset = 128*ip + 32*il + l0;
|
const int y_offset = 128*ip + 32*il + l0;
|
||||||
|
|
||||||
const int step = sizeof(block_q3_K) * nb;
|
const int step = sizeof(block_q3_K) * nb / 2;
|
||||||
|
|
||||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
|
@ -1434,17 +1433,47 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
sumf1[row] += d * scales[1];
|
sumf1[row] += d * scales[1];
|
||||||
sumf2[row] += d;
|
sumf2[row] += d;
|
||||||
|
|
||||||
q += step/2;
|
q += step;
|
||||||
h += step/2;
|
h += step;
|
||||||
a += step/2;
|
a += step;
|
||||||
dh += step/2;
|
dh += step;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
y1 += 2 * QK_K;
|
y1 += 2 * QK_K;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
||||||
|
const float tot = simd_sum(sumf);
|
||||||
|
if (tiisg == 0) {
|
||||||
|
dst[r1*ne0 + first_row + row] = tot;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
|
kernel void kernel_mul_mat_q3_K_f32(
|
||||||
|
device const void * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
|
const int nb = ne00/QK_K;
|
||||||
|
|
||||||
|
const int64_t r0 = tgpig.x;
|
||||||
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
|
const int row = 2 * r0 + sgitg;
|
||||||
|
|
||||||
|
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
||||||
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
const int ix = tiisg/4;
|
const int ix = tiisg/4;
|
||||||
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
||||||
const int im = il/8; // 0, 0, 1, 1
|
const int im = il/8; // 0, 0, 1, 1
|
||||||
|
@ -1481,17 +1510,14 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
}
|
}
|
||||||
const float sumf = sum1[0] + sum1[1] * 1.f/4.f + sum1[2] * 1.f/16.f + sum1[3] * 1.f/64.f +
|
const float sumf = sum1[0] + sum1[1] * 1.f/4.f + sum1[2] * 1.f/16.f + sum1[3] * 1.f/64.f +
|
||||||
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
|
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
const float tot = simd_sum(sumf);
|
||||||
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
if (tiisg == 0) {
|
||||||
const float tot = simd_sum(sumf);
|
dst[r1*ne0 + row] = tot;
|
||||||
if (tiisg == 0) {
|
|
||||||
dst[r1*ne0 + first_row + row] = tot;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
kernel void kernel_mul_mat_q4_K_f32(
|
kernel void kernel_mul_mat_q4_K_f32(
|
||||||
|
@ -1789,7 +1815,6 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||||
|
|
||||||
for (int i = ix; i < nb; i += 8) {
|
for (int i = ix; i < nb; i += 8) {
|
||||||
|
|
||||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
yl[l+0] = y[l+ 0];
|
yl[l+0] = y[l+ 0];
|
||||||
yl[l+4] = y[l+16];
|
yl[l+4] = y[l+16];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue