Change memory pool synchronization mechanism to a spin lock
General code cleanup
This commit is contained in:
parent
c832e7c793
commit
d774e05428
3 changed files with 67 additions and 54 deletions
83
ggml-cuda.cu
83
ggml-cuda.cu
|
@ -31,9 +31,9 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
|
|||
|
||||
#define QK4_3 16
|
||||
typedef struct {
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
} block_q4_3;
|
||||
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
|
||||
|
||||
|
@ -151,29 +151,44 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st
|
|||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
// lock-free, thread safe buffer pool for cuda
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 16
|
||||
struct cuda_buffer {
|
||||
std::atomic_uintptr_t ptr { 0 };
|
||||
size_t size { 0 };
|
||||
};
|
||||
|
||||
static cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
||||
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
struct cuda_buffer * b = &cuda_buffer_pool[i];
|
||||
if (b->size >= size) {
|
||||
uintptr_t ptr = atomic_load(&b->ptr);
|
||||
if (ptr) {
|
||||
if (std::atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
|
||||
*actual_size = b->size;
|
||||
return (void *) ptr;
|
||||
}
|
||||
}
|
||||
struct scoped_spin_lock {
|
||||
std::atomic_flag& lock;
|
||||
scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
|
||||
while (lock.test_and_set(std::memory_order_acquire)) {
|
||||
; // spin
|
||||
}
|
||||
}
|
||||
~scoped_spin_lock() {
|
||||
lock.clear(std::memory_order_release);
|
||||
}
|
||||
scoped_spin_lock(const scoped_spin_lock&) = delete;
|
||||
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
|
||||
};
|
||||
|
||||
struct cuda_buffer {
|
||||
void * ptr = nullptr;
|
||||
size_t size = 0;
|
||||
};
|
||||
|
||||
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.size >= size && b.ptr != nullptr) {
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
|
||||
*actual_size = size;
|
||||
|
@ -181,31 +196,31 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||
}
|
||||
|
||||
void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
struct cuda_buffer * b = &cuda_buffer_pool[i];
|
||||
uintptr_t p = std::atomic_load(&b->ptr);
|
||||
if (p == 0) {
|
||||
if (std::atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
|
||||
b->size = size;
|
||||
return;
|
||||
}
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.ptr == nullptr) {
|
||||
b.ptr = ptr;
|
||||
b.size = size;
|
||||
return;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
cublasHandle_t cublasH = NULL;
|
||||
cudaStream_t cudaStream = NULL;
|
||||
cublasHandle_t g_cublasH = NULL;
|
||||
cudaStream_t g_cudaStream = NULL;
|
||||
|
||||
void ggml_init_cublas(void) {
|
||||
if (cublasH == NULL) {
|
||||
if (g_cublasH == NULL) {
|
||||
// create cublas handle, bind a stream
|
||||
CUBLAS_CHECK(cublasCreate(&cublasH));
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
||||
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||
|
|
|
@ -24,10 +24,8 @@ extern "C" {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
|
||||
extern cublasHandle_t cublasH;
|
||||
extern cudaStream_t cudaStream;
|
||||
extern cublasHandle_t g_cublasH;
|
||||
extern cudaStream_t g_cudaStream;
|
||||
|
||||
void ggml_init_cublas(void);
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
|
||||
|
|
32
ggml.c
32
ggml.c
|
@ -7550,19 +7550,19 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, ne00,
|
||||
d_Y, ne10,
|
||||
&beta, d_D, ne01));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
// zT = y * xT
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
|
@ -7574,7 +7574,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
}
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
|
@ -7770,12 +7770,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, CUDA_R_16F, ne00,
|
||||
d_Y, CUDA_R_16F, ne10,
|
||||
|
@ -7784,7 +7784,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
CUBLAS_GEMM_DEFAULT));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
const float * x = wdata;
|
||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||
|
@ -7802,7 +7802,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
|
@ -8013,9 +8013,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
// copy and dequantize on device
|
||||
CUDA_CHECK(
|
||||
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
||||
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
||||
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
#else
|
||||
{
|
||||
|
@ -8031,18 +8031,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, ne00,
|
||||
d_Y, ne10,
|
||||
&beta, d_D, ne01));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
// zT = y * xT
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
|
@ -8055,7 +8055,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue