Added OpenCL DMMV kernels

This commit is contained in:
Concedo 2023-06-12 19:31:09 +08:00 committed by 0cc4m
parent f558e4c297
commit 6e20827f93

View file

@ -100,15 +100,13 @@ struct __attribute__((packed)) block_q6_K
half d; half d;
}; };
__kernel void convert_fp16_to_fp32(__global half *x, __global float *y) __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
{
const uint i = get_global_id(0); const uint i = get_global_id(0);
y[i] = vload_half(0, &x[i]); y[i] = vload_half(0, &x[i]);
} }
void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const int iqs, float *v0, float *v1) void dequantize_q4_0(__global const struct block_q4_0* x, const int ib, const int iqs, float* v0, float* v1) {
{
const float d = vload_half(0, &x[ib].d); const float d = vload_half(0, &x[ib].d);
const uint8_t vui = x[ib].qs[iqs]; const uint8_t vui = x[ib].qs[iqs];
@ -118,8 +116,8 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
*v0 = (vi0 - 8)*d; *v0 = (vi0 - 8)*d;
*v1 = (vi1 - 8)*d; *v1 = (vi1 - 8)*d;
} void dequantize_q4_1(__global const struct block_q4_1 *x, const int ib, const int iqs, float *v0, float *v1) }
{ void dequantize_q4_1(__global const struct block_q4_1* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d); const float d = vload_half(0, &x[ib].d);
const float m = vload_half(0, &x[ib].m); const float m = vload_half(0, &x[ib].m);
@ -130,8 +128,8 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
*v0 = vi0*d + m; *v0 = vi0*d + m;
*v1 = vi1*d + m; *v1 = vi1*d + m;
} void dequantize_q5_0(__global const struct block_q5_0 *x, const int ib, const int iqs, float *v0, float *v1) }
{ void dequantize_q5_0(__global const struct block_q5_0* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d); const float d = vload_half(0, &x[ib].d);
uint32_t qh = x[ib].qh; uint32_t qh = x[ib].qh;
@ -144,8 +142,8 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
*v0 = x0*d; *v0 = x0*d;
*v1 = x1*d; *v1 = x1*d;
} void dequantize_q5_1(__global const struct block_q5_1 *x, const int ib, const int iqs, float *v0, float *v1) }
{ void dequantize_q5_1(__global const struct block_q5_1* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d); const float d = vload_half(0, &x[ib].d);
const float m = vload_half(0, &x[ib].m); const float m = vload_half(0, &x[ib].m);
@ -159,8 +157,8 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
*v0 = x0*d + m; *v0 = x0*d + m;
*v1 = x1*d + m; *v1 = x1*d + m;
} void dequantize_q8_0(__global const struct block_q8_0 *x, const int ib, const int iqs, float *v0, float *v1) }
{ void dequantize_q8_0(__global const struct block_q8_0* x, const int ib, const int iqs, float* v0, float* v1) {
const float d = vload_half(0, &x[ib].d); const float d = vload_half(0, &x[ib].d);
const int8_t vi0 = x[ib].qs[iqs + 0]; const int8_t vi0 = x[ib].qs[iqs + 0];
@ -168,8 +166,8 @@ void dequantize_q4_0(__global const struct block_q4_0 *x, const int ib, const in
*v0 = vi0*d; *v0 = vi0*d;
*v1 = vi1*d; *v1 = vi1*d;
} void convert_f16(__global half *x, const int ib, const int iqs, float *v0, float *v1) }
{ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float* v1){
*v0 = vload_half(0, &x[ib + 0]); *v0 = vload_half(0, &x[ib + 0]);
*v1 = vload_half(0, &x[ib + 1]); *v1 = vload_half(0, &x[ib + 1]);
} }
@ -397,6 +395,95 @@ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int i
} }
void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
__global const float * y = yy + 64*j + ir;
__global const uint8_t * q = x[ib].qs + 32*j + ir;
const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}
*result = sum;
}
void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int j = iqs / 64;
const int ir = (iqs - 64*j)/2;
const int is = 2*j;
__global const float * y = yy + 64*j + ir;
__global const uint8_t * ql = x[ib].qs + 32*j + ir;
__global const uint8_t * qh = x[ib].qh + ir;
const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;
uint8_t hm = 1 << is;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
}
hm <<= 1;
for (int k = 0; k < 4; ++k) {
sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
}
*result = sum;
}
void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int ip = iqs / 128; // 0 or 1
const int il = (iqs - 128*ip)/8; // 0...15
const int is = 8*ip;
__global const float * y = yy + 128*ip + il;
const float d = vload_half(0, &x[ib].d);
__global const uint8_t * ql = x[ib].ql + 64*ip + il;
const uint8_t * qh = x[ib].qh + 32*ip + il;
__global const int8_t * sc = x[ib].scales + is;
*result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
+ y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+ y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+ y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+ y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+ y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+ y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+ y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
}
); );
@ -566,9 +653,12 @@ std::array<std::string, 2> mul_str_values = {
"mul_f32", "float" "mul_f32", "float"
}; };
std::array<std::string, 6> dmmv_k_str_values = { std::array<std::string, 15> dmmv_k_str_values = {
"dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K", "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
"dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K", "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
"dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
"dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
"dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
}; };
std::string& replace(std::string& s, const std::string& from, const std::string& to) { std::string& replace(std::string& s, const std::string& from, const std::string& to) {
@ -867,6 +957,9 @@ void ggml_cl_init(void) {
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err)); CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
// mul kernel // mul kernel
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err)); CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));