/*
 * Patch summary: replaces the old in-file CLBlast wrapper in ggml.c, deletes
 * ggml_blas_adapter.c, and adds ggml_clblast_dequant.cl. The new wrapper
 * uploads A/B/C, optionally runs a q4_0 / q4_1 dequantize kernel on B
 * (btype 2 or 3), calls CLBlastSgemm and reads C back.
 *
 * This note sits ahead of the first file header of the patch, in the area
 * patch tools treat as leading commentary.
 *
 * Review fixes folded in (none of them changes a hunk's line count):
 *  . build_program_from_source: err is a cl_int and a program-creation failure
 *    prints the OpenCL code; perror only reports errno, which OpenCL does not set.
 *  . ggml_cl_sgemm_wrapper: the result is read back only when CLBlastSgemm
 *    succeeded. Before, the read always waited on events[0], which was released
 *    right after the uploads and may be stale if the GEMM was never enqueued.
 *  . mul_mat_f32 / mul_mat_f16_f32: btype is passed as GGML_TYPE_F32.
 *    NOTE(review): params->type looks like the task phase, not a tensor type;
 *    confirm against the struct definition.
 *  . mul_mat_q_f32: src0->data is cast to char * before the byte offsets are
 *    added (arithmetic on void * is a compiler extension).
 *
 * Still open:
 *  . NOTE(review): the operands of several #include lines are missing in this
 *    copy of the patch; restore them from the original commit.
 *  . NOTE(review): ggml_cl_init indexes platforms[] and devices[] with
 *    unchecked environment values and only logs context, queue and kernel
 *    creation failures before continuing.
 *  . The post-image blob id on the index line was not regenerated.
 */
diff --git a/ggml.c b/ggml.c index bb618ea43..ad4526298 100644 --- a/ggml.c +++ b/ggml.c @@ -143,28 +143,213 @@ inline static void* ggml_aligned_malloc(size_t size) { } \ } while (0) -#if GGML_USE_CLBLAST -#ifndef GGML_USE_OPENBLAS -#define GGML_USE_OPENBLAS -#endif - -#define CL_TARGET_OPENCL_VERSION 110 -#include - -cl_platform_id platform; -cl_device_id device; -cl_context context; -cl_command_queue queue; -cl_event event; -bool cl_initialized = false; -#endif - #if defined(GGML_USE_ACCELERATE) #include #elif defined(GGML_USE_OPENBLAS) #include #elif defined(GGML_USE_CUBLAS) #include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +#define CL_TARGET_OPENCL_VERSION 110 +#include +#include +#include + +cl_platform_id platform; +cl_device_id device; +cl_context context; +cl_command_queue queue; +cl_program program; +cl_kernel kernel_q4_0, kernel_q4_1; +bool cl_initialized = false; +size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; + +cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { + cl_program program; + char *program_log; + size_t program_size, log_size; + cl_int err; + + program_size = strlen(program_buffer); + + program = clCreateProgramWithSource(ctx, 1, + (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + fprintf(stderr, "OpenCL error creating program: %d\n", err); + exit(1); + } + + err = clBuildProgram(program, 0, NULL, NULL, NULL, NULL); + if(err < 0) { + + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, + log_size + 1, program_log, NULL); + printf("%s\n", program_log); + free(program_log); + exit(1); + } + + return program; +} + +static void ggml_cl_init(void) { + cl_int err = 0; + char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); + char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); + int plat_num = 
(KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); + int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES)); + printf("\nInitializing CLBlast (First Run)..."); + printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[plat_num]; + char platform_buffer[1024]; + clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[dev_num]; + char device_buffer[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); + printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL context: %d\n", err); + fflush(stdout); + } + queue = clCreateCommandQueue(context, device, 0, &err); + + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Command Queue: %d\n", err); + fflush(stdout); + } + + free(platforms); + free(devices); + + program = build_program_from_source(context, device, clblast_dequant); + + // Prepare dequantize kernels + kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_0 kernel: %d\n", err); + fflush(stdout); + }; + kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); + if(err < 0) { + printf("Error creating OpenCL dequantize q4_1 kernel: %d\n", err); + fflush(stdout); + }; 
+ cl_initialized = true; +} + +static void ggml_cl_sgemm_wrapper(const CBLAS_ORDER order, const CBLAS_TRANSPOSE trans_a, const CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) { + cl_int err = 0; + + cl_event events[4]; + events[0] = NULL; + events[1] = NULL; + events[2] = NULL; + events[3] = NULL; + + if (!cl_initialized) { + ggml_cl_init(); + } + + bool dequant = (btype == 2 || btype == 3); + cl_kernel kernel = btype == 2 ? kernel_q4_0 : kernel_q4_1; + + size_t global = n * k, local = 16, qb_size; + cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; + + // Prepare buffers + cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, m*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer A: %d\n", err); + fflush(stdout); + } + if (dequant) { + qb_size = global * (sizeof(float) * (btype == 2 ? 
1 : 2) + 16) / 32; + cl_buffer_qb = clCreateBuffer(context, CL_MEM_READ_ONLY, qb_size, NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer QB: %d\n", err); + fflush(stdout); + } + } + cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer B: %d\n", err); + fflush(stdout); + } + cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); + if (err != CL_SUCCESS) { + printf("Error creating OpenCL Buffer C: %d\n", err); + fflush(stdout); + } + + if (dequant) { + err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); + err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); + if(err < 0) { + printf("Error setting OpenCL kernel args: %d\n", err); + fflush(stdout); + } + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1); + } else { + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events); + clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2); + if (dequant) { + err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3); + if(err < 0) { + printf("Error enqueueing OpenCL dequantize kernel: %d\n", err); + fflush(stdout); + } + } + clWaitForEvents(dequant ? 4 : 3, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + clReleaseEvent(events[2]); + if (dequant) { + clReleaseEvent(events[3]); + } + + // Call the SGEMM routine. 
+ CLBlastStatusCode status = CLBlastSgemm(order, + trans_a, trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, events); + // Read back only if the GEMM was enqueued; otherwise events[0] may still be the released upload event + if (status == CLBlastSuccess) { + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); + clWaitForEvents(2, events); + clReleaseEvent(events[0]); + clReleaseEvent(events[1]); + } else { + printf("Error running CLBlast SGEMM: %d\n", (int) status); + fflush(stdout); + } + clReleaseMemObject(cl_buffer_a); + if (dequant) { + clReleaseMemObject(cl_buffer_qb); + } + clReleaseMemObject(cl_buffer_b); + clReleaseMemObject(cl_buffer_c); +} #endif #undef MIN @@ -3705,6 +3890,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize cuBLAS #if defined(GGML_USE_CUBLAS) ggml_init_cublas(); + #elif defined(GGML_USE_CLBLAST) + ggml_cl_init(); #endif is_first_call = false; @@ -7464,84 +7651,6 @@ static bool ggml_compute_forward_mul_mat_use_blas( return false; } -#ifdef GGML_USE_CLBLAST -static bool ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { - cl_int err = 0; - - if (!cl_initialized) { - cl_uint num_platforms; - clGetPlatformIDs(0, NULL, &num_platforms); - cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); - clGetPlatformIDs(num_platforms, platforms, NULL); - platform = platforms[0]; - cl_uint num_devices; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); - cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); - device = devices[0]; - context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: 
%d\n", err); - fflush(stdout); - } - queue = clCreateCommandQueue(context, device, 0, &err); - event = NULL; - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } - - free(platforms); - free(devices); - cl_initialized = true; - } - - // Prepare buffers - cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_WRITE, m*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer A: %d\n", err); - fflush(stdout); - } - cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer B: %d\n", err); - fflush(stdout); - } - cl_mem cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, m*n*sizeof(float), NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer C: %d\n", err); - fflush(stdout); - } - - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, m*k*sizeof(float), host_a, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, n*k*sizeof(float), host_b, 0, NULL, NULL); - clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); - - // Call the SGEMM routine. 
- CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, &event); - - // Wait for completion - if (status == CLBlastSuccess) { - clWaitForEvents(1, &event); - clReleaseEvent(event); - } - - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL); - - clReleaseMemObject(cl_buffer_a); - clReleaseMemObject(cl_buffer_b); - clReleaseMemObject(cl_buffer_c); -} -#endif #endif static void ggml_compute_forward_mul_mat_f32( @@ -7663,14 +7772,14 @@ static void ggml_compute_forward_mul_mat_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT -#ifdef GGML_USE_CLBLAST ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, - 0.0f, d, ne01); + 0.0f, d, ne01, + GGML_TYPE_F32); #else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, @@ -7892,6 +8001,19 @@ static void ggml_compute_forward_mul_mat_f16_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); +#elif defined(GGML_USE_CLBLAST) + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01, + GGML_TYPE_F32); #else const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); @@ -7899,13 +8021,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); // zT = y * xT -#ifdef GGML_USE_CLBLAST - ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, 
ne10, - 1.0f, y, ne10, - x, ne10, - 0.0f, d, ne01); -#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -8135,6 +8250,7 @@ static void ggml_compute_forward_mul_mat_q_f32( dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); CUDA_CHECK(cudaGetLastError()); #else +#ifndef GGML_USE_CLBLAST { size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { @@ -8143,6 +8259,9 @@ static void ggml_compute_forward_mul_mat_q_f32( } } const float * x = wdata; +#else + const void* x = (char *) src0->data + i03*nb03 + i02*nb02; +#endif #endif @@ -8160,14 +8279,14 @@ static void ggml_compute_forward_mul_mat_q_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT -#ifdef GGML_USE_CLBLAST ggml_cl_sgemm_wrapper(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, x, ne10, - 0.0f, d, ne01); + 0.0f, d, ne01, + type); #else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, diff --git a/ggml_blas_adapter.c b/ggml_blas_adapter.c deleted file mode 100644 index 50faf3807..000000000 --- a/ggml_blas_adapter.c +++ /dev/null @@ -1,153 +0,0 @@ -//this is a drop-in for all CLBlast related code, to keep the main ggml.c unmodified -// we will imitate the function definition from OpenBLAS instead, replaced as necessary. 
-//windows binaries for clblast obtained from https://github.com/CNugteren/CLBlast (apache license) -//windows binaries for opencl obtained from https://github.com/KhronosGroup/OpenCL-SDK (apache license) - -#if GGML_USE_OPENBLAS -#include -#include -#include - -#if GGML_USE_CLBLAST - -#define CL_TARGET_OPENCL_VERSION 110 -#include - -cl_platform_id platform; -cl_device_id device; -cl_context context; -cl_command_queue queue; -bool cl_initialized = false; -size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0; -cl_mem cl_buffer_a, cl_buffer_b, cl_buffer_c; - -static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) { - cl_int err = 0; - - cl_event events[2]; - events[0] = NULL; - events[1] = NULL; - - if (!cl_initialized) { - char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM"); - char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES"); - int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM)); - int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 
0 : atoi(KCPP_CLBLAST_DEVICES)); - printf("\nInitializing CLBlast (First Run)..."); - printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); - cl_uint num_platforms; - clGetPlatformIDs(0, NULL, &num_platforms); - cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); - clGetPlatformIDs(num_platforms, platforms, NULL); - platform = platforms[plat_num]; - char platform_buffer[1024]; - clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); - cl_uint num_devices; - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); - cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); - clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); - device = devices[dev_num]; - char device_buffer[1024]; - clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); - printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); - context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL context: %d\n", err); - fflush(stdout); - } - queue = clCreateCommandQueue(context, device, 0, &err); - - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Command Queue: %d\n", err); - fflush(stdout); - } - - free(platforms); - free(devices); - - cl_size_a = m * k * sizeof(float); - cl_size_b = n * k * sizeof(float); - cl_size_c = m * n * sizeof(float); - - // Prepare buffers - cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer A: %d\n", err); - fflush(stdout); - } - cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer B: %d\n", err); - fflush(stdout); - } - cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, 
NULL, &err); - if (err != CL_SUCCESS) { - printf("Error creating OpenCL Buffer C: %d\n", err); - fflush(stdout); - } - - cl_initialized = true; - } - - // Resize buffers if too small - if (m * k * sizeof(float) > cl_size_a) { - clReleaseMemObject(cl_buffer_a); - cl_size_a = m * k * sizeof(float); - cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_a, NULL, NULL); - } - if (n * k * sizeof(float) > cl_size_b) { - clReleaseMemObject(cl_buffer_b); - cl_size_b = n * k * sizeof(float); - cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, cl_size_b, NULL, NULL); - } - if (m * n * sizeof(float) > cl_size_c) { - clReleaseMemObject(cl_buffer_c); - cl_size_c = m * n * sizeof(float); - cl_buffer_c = clCreateBuffer(context, CL_MEM_READ_WRITE, cl_size_c, NULL, NULL); - } - - clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, cl_size_a, host_a, 0, NULL, events); - clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, cl_size_b, host_b, 0, NULL, events + 1); - // buffer c is not required for this use case - // clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, cl_size_c, host_c, 0, NULL, NULL); - - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - - // Call the SGEMM routine. 
- CLBlastStatusCode status = CLBlastSgemm(order, - trans_a, trans_b, - m, n, k, - alpha, - cl_buffer_a, 0, lda, - cl_buffer_b, 0, ldb, - beta, - cl_buffer_c, 0, ldc, - &queue, events); - - clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1); - - // Wait for completion - if (status == CLBlastSuccess) { - clWaitForEvents(2, events); - clReleaseEvent(events[0]); - clReleaseEvent(events[1]); - } -} - -#endif -#endif - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) -#if GGML_USE_CLBLAST -#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ -ggml_cl_sgemm_wrapper(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ -}) -#else -#define do_blas_sgemm(Order, TransA, TransB,M, N, K,alpha, A, lda, B, ldb, beta, C, ldc) ({\ -cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\ -}) -#endif -#endif diff --git a/ggml_clblast_dequant.cl b/ggml_clblast_dequant.cl new file mode 100644 index 000000000..47d39f8a3 --- /dev/null +++ b/ggml_clblast_dequant.cl @@ -0,0 +1,46 @@ +#define MULTILINE_QUOTE(...) 
#__VA_ARGS__ +const char * clblast_dequant = MULTILINE_QUOTE( + +struct __attribute__ ((packed)) block_q4_0 +{ + float d; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) { + uint i, l; + i = get_global_id(0) / 32; + l = get_local_id(0); + + float d = blocks[i].d; + + uchar vi = blocks[i].qs[l]; + + uint index = i*32 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct __attribute__ ((packed)) block_q4_1 +{ + float d; + float m; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) { + uint i, l; + i = get_global_id(0) / 32; + l = get_local_id(0); + + float d = blocks[i].d; + float m = blocks[i].m; + + uchar vi = blocks[i].qs[l]; + + uint index = i*32 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +);