Deduplicate dequant kernels
This commit is contained in:
parent
67dbd356b6
commit
0df55da4ca
1 changed files with 46 additions and 94 deletions
140
ggml-opencl.cpp
140
ggml-opencl.cpp
|
@ -74,88 +74,6 @@ __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
|
|||
y[i] = vload_half(0, &x[i]);
|
||||
}
|
||||
|
||||
|
||||
__kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
|
||||
constant uint qk = QK4_0;
|
||||
|
||||
const uint i = get_global_id(0) / qk;
|
||||
const uint j = get_local_id(0);
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
const int x0 = (x[i].qs[j] & 0xf) - 8;
|
||||
const int x1 = (x[i].qs[j] >> 4) - 8;
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d;
|
||||
y[i*qk + j + qk/2] = x1*d;
|
||||
}
|
||||
|
||||
__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
|
||||
constant uint qk = QK4_1;
|
||||
|
||||
const uint i = get_global_id(0) / qk;
|
||||
const uint j = get_local_id(0);
|
||||
|
||||
const float d = x[i].d;
|
||||
const float m = x[i].m;
|
||||
|
||||
const int x0 = (x[i].qs[j] & 0xf);
|
||||
const int x1 = (x[i].qs[j] >> 4);
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d + m;
|
||||
y[i*qk + j + qk/2] = x1*d + m;
|
||||
}
|
||||
|
||||
__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
|
||||
constant uint qk = QK5_0;
|
||||
|
||||
const uint i = get_global_id(0) / qk;
|
||||
const uint j = get_local_id(0);
|
||||
|
||||
const float d = vload_half(0, (__global half*) &x[i].d);
|
||||
|
||||
uint32_t qh = x[i].qh;
|
||||
|
||||
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
|
||||
|
||||
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
|
||||
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d;
|
||||
y[i*qk + j + qk/2] = x1*d;
|
||||
}
|
||||
|
||||
__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
|
||||
constant uint qk = QK5_1;
|
||||
|
||||
const uint i = get_global_id(0) / qk;
|
||||
const uint j = get_local_id(0);
|
||||
|
||||
const float d = vload_half(0, (__global half*) &x[i].d);
|
||||
const float m = vload_half(0, (__global half*) &x[i].m);
|
||||
|
||||
uint32_t qh = x[i].qh;
|
||||
|
||||
const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
|
||||
|
||||
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
|
||||
const int x1 = (x[i].qs[j] >> 4) | xh_1;
|
||||
|
||||
y[i*qk + j + 0 ] = x0*d + m;
|
||||
y[i*qk + j + qk/2] = x1*d + m;
|
||||
}
|
||||
|
||||
__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
|
||||
constant uint qk = QK8_0;
|
||||
const uint i = get_global_id(0) / qk;
|
||||
const uint j = get_local_id(0);
|
||||
|
||||
const float d = x[i].d;
|
||||
y[i*qk + j] = x[i].qs[j]*d;
|
||||
}
|
||||
|
||||
void dequantize_q4_0(__global const struct block_q4_0* x, const int ib, const int iqs, float* v0, float* v1) {
|
||||
const float d = x[ib].d;
|
||||
|
||||
|
@ -223,6 +141,30 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
|
|||
}
|
||||
);
|
||||
|
||||
std::string dequant_template = MULTILINE_QUOTE(
|
||||
__kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
|
||||
const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
|
||||
|
||||
if (i >= get_global_size(0)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint qk = QUANT_K;
|
||||
const uint qr = QUANT_R;
|
||||
|
||||
const int ib = i/qk; // block index
|
||||
const int iqs = (i%qk)/qr; // quant index
|
||||
const int iybs = i - i%qk; // y block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
float v0, v1;
|
||||
DEQUANT_FUNC(x, ib, iqs, &v0, &v1);
|
||||
y[iybs + iqs + 0] = v0;
|
||||
y[iybs + iqs + y_offset] = v1;
|
||||
}
|
||||
);
|
||||
|
||||
std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
|
||||
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
|
||||
const int block_size = get_local_size(0);
|
||||
|
@ -265,10 +207,19 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
|
|||
}
|
||||
);
|
||||
|
||||
std::array<std::string, 5> dequant_mul_mat_vec_str_keys = {
|
||||
std::array<std::string, 5> dequant_str_keys = {
|
||||
"KERNEL_NAME", "X_TYPE", "QUANT_K", "QUANT_R", "DEQUANT_FUNC"
|
||||
};
|
||||
|
||||
std::array<std::string, 30> dequant_str_values = {
|
||||
"dequantize_row_q4_0", "struct block_q4_0", "QK4_0", "QR4_0", "dequantize_q4_0",
|
||||
"dequantize_row_q4_1", "struct block_q4_1", "QK4_1", "QR4_1", "dequantize_q4_1",
|
||||
"dequantize_row_q5_0", "struct block_q5_0", "QK5_0", "QR5_0", "dequantize_q5_0",
|
||||
"dequantize_row_q5_1", "struct block_q5_1", "QK5_1", "QR5_1", "dequantize_q5_1",
|
||||
"dequantize_row_q8_0", "struct block_q8_0", "QK8_0", "QR8_0", "dequantize_q8_0",
|
||||
"convert_row_f16", "half", "32", "1", "convert_f16"
|
||||
};
|
||||
|
||||
std::array<std::string, 30> dequant_mul_mat_vec_str_values = {
|
||||
"dequantize_mul_mat_vec_q4_0", "struct block_q4_0", "QK4_0", "QR4_0", "dequantize_q4_0",
|
||||
"dequantize_mul_mat_vec_q4_1", "struct block_q4_1", "QK4_1", "QR4_1", "dequantize_q4_1",
|
||||
|
@ -290,12 +241,15 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
|
|||
std::string generate_kernels() {
|
||||
std::stringstream src;
|
||||
src << program_source << '\n';
|
||||
for (size_t i = 0; i < dequant_mul_mat_vec_str_values.size(); i += dequant_mul_mat_vec_str_keys.size()) {
|
||||
std::string kernel = dequant_mul_mat_vec_template;
|
||||
for (size_t j = 0; j < dequant_mul_mat_vec_str_keys.size(); j++) {
|
||||
replace(kernel, dequant_mul_mat_vec_str_keys[j], dequant_mul_mat_vec_str_values[i + j]);
|
||||
for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
|
||||
std::string dequant_kernel = dequant_template;
|
||||
std::string dmmv_kernel = dequant_mul_mat_vec_template;
|
||||
for (size_t j = 0; j < dequant_str_keys.size(); j++) {
|
||||
replace(dequant_kernel, dequant_str_keys[j], dequant_str_values[i + j]);
|
||||
replace(dmmv_kernel, dequant_str_keys[j], dequant_mul_mat_vec_str_values[i + j]);
|
||||
}
|
||||
src << kernel << '\n';
|
||||
src << dequant_kernel << '\n';
|
||||
src << dmmv_kernel << '\n';
|
||||
}
|
||||
return src.str();
|
||||
}
|
||||
|
@ -314,7 +268,7 @@ static cl_device_id device;
|
|||
static cl_context context;
|
||||
static cl_command_queue queue;
|
||||
static cl_program program;
|
||||
static cl_kernel convert_fp16_to_fp32_cl;
|
||||
static cl_kernel convert_row_f16_cl;
|
||||
static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
|
||||
static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
|
||||
static bool fp16_support;
|
||||
|
@ -396,7 +350,7 @@ void ggml_cl_init(void) {
|
|||
program = build_program_from_source(context, device, kernel_src.c_str());
|
||||
|
||||
// FP16 to FP32 kernel
|
||||
convert_fp16_to_fp32_cl = clCreateKernel(program, "convert_fp16_to_fp32", &err);
|
||||
convert_row_f16_cl = clCreateKernel(program, "convert_row_f16", &err);
|
||||
CL_CHECK(err, "clCreateKernel");
|
||||
|
||||
// Dequantize kernels
|
||||
|
@ -439,7 +393,7 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
|
|||
case GGML_TYPE_Q8_0:
|
||||
return &dequantize_row_q8_0_cl;
|
||||
case GGML_TYPE_F16:
|
||||
return &convert_fp16_to_fp32_cl;
|
||||
return &convert_row_f16_cl;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -776,7 +730,6 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
|||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL), "ggml_cl_h2d_tensor_2d");
|
||||
|
||||
// compute
|
||||
// dequantize_mul_mat_vec(__global void * vx, __local float* tmp, __global float * y, __global float * dst, __global int ncols, __global int vx_type) {
|
||||
const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
|
||||
const size_t local = CL_DMMV_BLOCK_SIZE;
|
||||
const cl_int ncols = ne00;
|
||||
|
@ -790,11 +743,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
|||
} else { // general dequantization kernel + CLBlast matrix matrix multiplication
|
||||
// convert src0 to fp32 on device
|
||||
const size_t global = x_ne;
|
||||
const size_t local = ggml_blck_size(type) / 2;
|
||||
CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q), "clSetKernelArg");
|
||||
CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X), "clSetKernelArg");
|
||||
CL_CHECK(clFinish(queue), "clFinish");
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, type == GGML_TYPE_F16 ? NULL : &local, 0, NULL, NULL), "clEnqueueNDRangeKernel");
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, 0, NULL, NULL), "clEnqueueNDRangeKernel");
|
||||
|
||||
// copy src1 to device
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL), "ggml_cl_h2d_tensor_2d");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue