k_cache: be able to use IQ4_NL - just CUDA for now

This commit is contained in:
Iwan Kawrakow 2024-03-20 18:47:18 +02:00
parent d68030b820
commit 9711e1eed2
2 changed files with 64 additions and 0 deletions

View file

@ -1590,6 +1590,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q4_1") { if (s == "q4_1") {
return GGML_TYPE_Q4_1; return GGML_TYPE_Q4_1;
} }
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
if (s == "q5_0") { if (s == "q5_0") {
return GGML_TYPE_Q5_0; return GGML_TYPE_Q5_0;
} }

View file

@ -6831,6 +6831,54 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh)); memcpy(dsti->qh, &qh, sizeof(qh));
} }
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -8586,6 +8634,17 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f16_f16_cuda( static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne, const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -10986,6 +11045,8 @@ static void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * s
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {