From 309af7fce92c071cd6ef872cb8ec1e802fee641f Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 24 Apr 2023 07:16:43 +0200 Subject: [PATCH] Add q4_2 and q4_3 CLBlast support, improve code --- ggml-opencl.cpp | 73 ++++++++++++++++++++++------------------- ggml_clblast_dequant.cl | 68 +++++++++++++++++++++++++++++--------- 2 files changed, 93 insertions(+), 48 deletions(-) diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index fa462ff8e..904e32ac9 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -30,8 +30,7 @@ cl_device_id device; cl_context context; cl_command_queue queue; cl_program program; -cl_kernel kernel_q4_0, kernel_q4_1; -bool cl_initialized = false; +cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS]; @@ -127,16 +126,9 @@ void ggml_cl_init() { clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clCreateContext"); queue = clCreateCommandQueue(context, device, 0, &err); - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clCreateCommandQueue"); free(platforms); free(devices); @@ -145,16 +137,13 @@ void ggml_cl_init() { // Prepare dequantize kernels kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); - fflush(stdout); - }; + CL_CHECK(err, "clCreateKernel"); kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); - if(err < 0) { - printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); - fflush(stdout); - }; - cl_initialized = true; + CL_CHECK(err, "clCreateKernel"); + kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err); + CL_CHECK(err, "clCreateKernel"); + kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err); + CL_CHECK(err, "clCreateKernel"); } void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { @@ -166,10 +155,36 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra events[2] = NULL; events[3] = NULL; - bool dequant = (btype == 2 || btype == 3); - cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; + cl_kernel kernel; + size_t global, local, qb_size; + bool dequant = btype >= 2 && btype < 6; + if (dequant) { + global = n * k; + + switch (btype) { + case 2: + kernel = kernel_q4_0; + local = 16; + qb_size = global * (sizeof(float) + local) / 32; + break; + case 3: + kernel = kernel_q4_1; + local = 16; + qb_size = global * (sizeof(float) * 2 + local) / 32; + break; + case 4: + kernel = kernel_q4_2; + local = 8; + qb_size = global * (sizeof(short) + local) / 16; + break; + case 5: + kernel = kernel_q4_3; + local = 8; + qb_size = global * (sizeof(short) * 2 + local) / 16; + break; + } + } - size_t global = n * k, local = 16, qb_size; cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; size_t buf_size_a, buf_size_qb, buf_size_b, buf_size_c; @@ -177,7 +192,6 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra // Prepare buffers cl_buffer_a = ggml_cl_pool_malloc(m * k * sizeof(float), &buf_size_a); if (dequant) { - qb_size = global * (sizeof(float) * (btype == 2 ? 1 : 2) + 16) / 32; cl_buffer_qb = ggml_cl_pool_malloc(qb_size, &buf_size_qb); } cl_buffer_b = ggml_cl_pool_malloc(n*k*sizeof(float), &buf_size_b); @@ -186,23 +200,16 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra if (dequant) { err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); - if(err < 0) { - printf("Error setting OpenCL kernel args: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clSetKernelArg"); clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); } else { clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); } clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); - clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); if (dequant) { err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); - if(err < 0) { - printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); - fflush(stdout); - } + CL_CHECK(err, "clEnqueueNDRangeKernel"); } clWaitForEvents(dequant ? 4 : 3, events); clReleaseEvent(events[0]); diff --git a/ggml_clblast_dequant.cl b/ggml_clblast_dequant.cl index 47d39f8a3..99474fdb3 100644 --- a/ggml_clblast_dequant.cl +++ b/ggml_clblast_dequant.cl @@ -1,27 +1,26 @@ #define MULTILINE_QUOTE(...) #__VA_ARGS__ const char * clblast_dequant = MULTILINE_QUOTE( -struct __attribute__ ((packed)) block_q4_0 +struct block_q4_0 { float d; uchar qs[16]; }; __kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) { - uint i, l; - i = get_global_id(0) / 32; - l = get_local_id(0); + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); - float d = blocks[i].d; + const float d = blocks[i].d; - uchar vi = blocks[i].qs[l]; + const uchar vi = blocks[i].qs[l]; - uint index = i*32 + l*2; + const uint index = i*32 + l*2; result[index + 0] = ((vi & 0xf) - 8)*d; result[index + 1] = ((vi >> 4) - 8)*d; } -struct __attribute__ ((packed)) block_q4_1 +struct block_q4_1 { float d; float m; @@ -29,16 +28,55 @@ struct __attribute__ ((packed)) block_q4_1 }; __kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) { - uint i, l; - i = get_global_id(0) / 32; - l = get_local_id(0); + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); - float d = blocks[i].d; - float m = blocks[i].m; + const float d = blocks[i].d; + const float m = blocks[i].m; - uchar vi = blocks[i].qs[l]; + const uchar vi = blocks[i].qs[l]; - uint index = i*32 + l*2; + const uint index = i*32 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +struct block_q4_2 +{ + ushort d; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (const half*) &blocks[i].d);; + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct block_q4_3 +{ + ushort d; + ushort m; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (const half*) &(blocks[i].d)); + const float m = vload_half(0, (const half*) &(blocks[i].m)); + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; result[index + 0] = (vi & 0xf) * d + m; result[index + 1] = (vi >> 4) * d + m; }