k_cache: be able to use Q5_1 on Metal
This commit is contained in:
parent
fef4a23e2c
commit
d68030b820
2 changed files with 76 additions and 3 deletions
|
@ -174,7 +174,7 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||||
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||||
|
@ -599,7 +599,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||||
|
@ -740,6 +740,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -2438,7 +2439,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||||
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
|
|
|
@ -2461,6 +2461,78 @@ kernel void kernel_cpy_f32_q5_0(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_cpy_f32_q5_1(
|
||||||
|
device const float * src0,
|
||||||
|
device void * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant int64_t & ne01,
|
||||||
|
constant int64_t & ne02,
|
||||||
|
constant int64_t & ne03,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant uint64_t & nb03,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant int64_t & ne1,
|
||||||
|
constant int64_t & ne2,
|
||||||
|
constant int64_t & ne3,
|
||||||
|
constant uint64_t & nb0,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint64_t & nb2,
|
||||||
|
constant uint64_t & nb3,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i03 = tgpig[2];
|
||||||
|
const int64_t i02 = tgpig[1];
|
||||||
|
const int64_t i01 = tgpig[0];
|
||||||
|
|
||||||
|
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
|
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||||
|
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||||
|
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||||
|
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
|
||||||
|
|
||||||
|
device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||||
|
|
||||||
|
for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
|
||||||
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||||
|
|
||||||
|
float max = src[0];
|
||||||
|
float min = src[0];
|
||||||
|
|
||||||
|
for (int j = 1; j < QK5_1; j++) {
|
||||||
|
const float v = src[j];
|
||||||
|
min = v < min ? v : min;
|
||||||
|
max = v > max ? v : max;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = (max - min) / 31;
|
||||||
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
|
dst_data[i00/QK5_1].d = d;
|
||||||
|
dst_data[i00/QK5_1].m = min;
|
||||||
|
|
||||||
|
uint32_t qh = 0;
|
||||||
|
for (int j = 0; j < QK5_1/2; ++j) {
|
||||||
|
const float x0 = (src[0 + j] - min)*id;
|
||||||
|
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||||
|
|
||||||
|
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||||
|
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||||
|
|
||||||
|
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||||
|
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||||
|
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||||
|
}
|
||||||
|
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_concat(
|
kernel void kernel_concat(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue