Add ability to use Q5_0, Q5_1, and IQ4_NL for quantized K cache (#6183)

* k_cache: be able to use Q5_0

* k_cache: be able to use Q5_1 on CODA

* k_cache: be able to use Q5_0 on Metal

* k_cache: be able to use Q5_1 on Metal

* k_cache: be able to use IQ4_NL - just CUDA for now

* k_cache: be able to use IQ4_NL on Metal

* k_cache: add newly added supported types to llama-bench and CUDA supports_op

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2024-03-21 08:27:57 +01:00 committed by GitHub
parent c5b8595e3f
commit 76aa30a263
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 424 additions and 15 deletions

View file

@ -6757,6 +6757,123 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
}
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_0 * dsti = (block_q5_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -8490,6 +8607,39 @@ static void ggml_cpy_f32_q4_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -10888,6 +11038,12 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@ -11304,6 +11460,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}