Replace buffer pool with static buffers a, b, qb, c

Fix compile warnings
This commit is contained in:
0cc4m 2023-04-24 22:08:51 +02:00
parent ae73887fb9
commit daa5df51f7
3 changed files with 50 additions and 98 deletions

View file

@ -4,71 +4,18 @@
#include <cstdio>
#include <cstring>
#include "ggml.h"
#include <ggml_clblast_dequant.cl>
struct scoped_spin_lock {
std::atomic_flag& lock;
scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
while (lock.test_and_set(std::memory_order_acquire)) {
; // spin
}
}
~scoped_spin_lock() {
lock.clear(std::memory_order_release);
}
scoped_spin_lock(const scoped_spin_lock&) = delete;
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};
struct cl_buffer {
cl_mem mem;
size_t size = 0;
};
cl_platform_id platform;
cl_device_id device;
cl_context context;
cl_command_queue queue;
cl_program program;
cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3;
size_t cl_size_a = 0, cl_size_b = 0, cl_size_c = 0;
static cl_buffer g_cl_buffer_pool[MAX_CL_BUFFERS];
static std::atomic_flag g_cl_pool_lock = ATOMIC_FLAG_INIT;
cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cl_pool_lock);
for (int i = 0; i < MAX_CL_BUFFERS; ++i) {
cl_buffer& b = g_cl_buffer_pool[i];
if (b.size >= size && b.size != 0) {
cl_mem mem = b.mem;
*actual_size = b.size;
b.size = 0;
return mem;
}
}
cl_int err;
cl_mem mem = clCreateBuffer(context, 0, size, NULL, &err);
*actual_size = size;
CL_CHECK(err, "clCreateBuffer");
return mem;
}
void ggml_cl_pool_free(cl_mem mem, size_t size) {
scoped_spin_lock lock(g_cl_pool_lock);
for (int i = 0; i < MAX_CL_BUFFERS; ++i) {
cl_buffer& b = g_cl_buffer_pool[i];
if (b.size == 0) {
b.mem = mem;
b.size = size;
return;
}
}
fprintf(stderr, "WARNING: cl buffer pool full, increase MAX_CL_BUFFERS\n");
clReleaseMemObject(mem);
}
cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
cl_program program;
@ -102,7 +49,7 @@ cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const cha
return program;
}
void ggml_cl_init() {
void ggml_cl_init(void) {
cl_int err = 0;
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
@ -146,6 +93,21 @@ void ggml_cl_init() {
CL_CHECK(err, "clCreateKernel");
}
void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
if (req_size <= *cur_size) {
return;
}
// Reallocate buffer with enough space
if (*cur_size > 0) {
clReleaseMemObject(*buf);
}
cl_int err;
*buf = clCreateBuffer(context, flags, req_size, NULL, &err);
*cur_size = req_size;
CL_CHECK(err, "clCreateBuffer");
}
void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype) {
cl_int err = 0;
@ -156,8 +118,8 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
events[3] = NULL;
cl_kernel kernel;
size_t global, local, qb_size;
bool dequant = btype >= 2 && btype < 6;
size_t global, local, size_qb;
const bool dequant = btype >= 2 && btype < 6;
if (dequant) {
global = n * k;
@ -165,48 +127,48 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
case 2:
kernel = kernel_q4_0;
local = 16;
qb_size = global * (sizeof(float) + local) / 32;
size_qb = global * (sizeof(float) + local) / 32;
break;
case 3:
kernel = kernel_q4_1;
local = 16;
qb_size = global * (sizeof(float) * 2 + local) / 32;
size_qb = global * (sizeof(float) * 2 + local) / 32;
break;
case 4:
kernel = kernel_q4_2;
local = 8;
qb_size = global * (sizeof(short) + local) / 16;
size_qb = global * (sizeof(short) + local) / 16;
break;
case 5:
kernel = kernel_q4_3;
local = 8;
qb_size = global * (sizeof(short) * 2 + local) / 16;
size_qb = global * (sizeof(short) * 2 + local) / 16;
break;
}
}
cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
size_t buf_size_a, buf_size_qb, buf_size_b, buf_size_c;
const size_t size_a = m * k * sizeof(float);
const size_t size_b = n * k * sizeof(float);
const size_t size_c = m * n * sizeof(float);
// Prepare buffers
cl_buffer_a = ggml_cl_pool_malloc(m * k * sizeof(float), &buf_size_a);
ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
if (dequant) {
cl_buffer_qb = ggml_cl_pool_malloc(qb_size, &buf_size_qb);
ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
}
cl_buffer_b = ggml_cl_pool_malloc(n*k*sizeof(float), &buf_size_b);
cl_buffer_c = ggml_cl_pool_malloc(m*n*sizeof(float), &buf_size_c);
ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
if (dequant) {
err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
CL_CHECK(err, "clSetKernelArg");
clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, qb_size, host_b, 0, NULL, events + 1);
clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, events + 1);
} else {
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1);
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, events + 1);
}
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events);
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, events);
if (dequant) {
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, events + 1, events + 3);
CL_CHECK(err, "clEnqueueNDRangeKernel");
@ -219,28 +181,20 @@ void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose tra
clReleaseEvent(events[3]);
}
// Call the SGEMM routine.
CLBlastStatusCode status = CLBlastSgemm(order,
trans_a, trans_b,
m, n, k,
alpha,
cl_buffer_a, 0, lda,
cl_buffer_b, 0, ldb,
beta,
cl_buffer_c, 0, ldc,
&queue, events);
CLBlastSgemm(order,
trans_a, trans_b,
m, n, k,
alpha,
cl_buffer_a, 0, lda,
cl_buffer_b, 0, ldb,
beta,
cl_buffer_c, 0, ldc,
&queue, events);
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1);
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, events, events + 1);
// Wait for completion
clWaitForEvents(2, events);
clReleaseEvent(events[0]);
clReleaseEvent(events[1]);
ggml_cl_pool_free(cl_buffer_a, buf_size_a);
if (dequant) {
ggml_cl_pool_free(cl_buffer_qb, buf_size_qb);
}
ggml_cl_pool_free(cl_buffer_b, buf_size_b);
ggml_cl_pool_free(cl_buffer_c, buf_size_c);
}

View file

@ -22,7 +22,7 @@ cl_mem ggml_cl_pool_malloc(size_t size, size_t * actual_size);
void ggml_cl_pool_free(cl_mem mem, size_t size);
cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer);
void ggml_cl_init();
void ggml_cl_init(void);
void ggml_cl_sgemm_wrapper(const CLBlastLayout order, const CLBlastTranspose trans_a, const CLBlastTranspose trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);

8
ggml.c
View file

@ -8031,7 +8031,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
else {
GGML_ASSERT(false);
}
#else
#elif !defined(GGML_USE_CLBLAST)
float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
@ -8050,8 +8050,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
CUDA_CHECK(cudaGetLastError());
#elif defined(GGML_USE_CLBLAST)
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
#else
#ifndef GGML_USE_CLBLAST
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
@ -8060,9 +8061,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}
const float * x = wdata;
#else
const void* x = src0->data + i03*nb03 + i02*nb02;
#endif
#endif