k_cache: be able to use Q5_1 on CODA

This commit is contained in:
Iwan Kawrakow 2024-03-20 17:33:14 +02:00
parent 5e09ce41a8
commit 5d8822b096

View file

@ -6783,9 +6783,9 @@ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
dsti->d = d; dsti->d = d;
uint32_t qh = 0; uint32_t qh = 0;
for (int j = 0; j < QK4_0/2; ++j) { for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id; const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id; const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
@ -6797,6 +6797,40 @@ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh)); memcpy(dsti->qh, &qh, sizeof(qh));
} }
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -8541,6 +8575,17 @@ static void ggml_cpy_f32_q5_0_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f16_f16_cuda( static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -10941,6 +10986,8 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {