Improve cuBLAS performance by using a memory pool
This commit is contained in:
parent
2510c1831f
commit
e8797a9aed
1 changed files with 82 additions and 45 deletions
127
ggml.c
127
ggml.c
|
@ -152,25 +152,69 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
|||
#include <cuda_runtime.h>
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
#define CUDA_CHECK(err) \
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
#define CUDA_CHECK(err) \
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||
exit(1); \
|
||||
} \
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// lock-free, thread safe buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 16
|
||||
struct cuda_buffer {
|
||||
atomic_uintptr_t ptr;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
|
||||
|
||||
static void * cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
struct cuda_buffer * b = &cuda_buffer_pool[i];
|
||||
if (b->size >= size) {
|
||||
uintptr_t ptr = atomic_load(&b->ptr);
|
||||
if (ptr) {
|
||||
if (atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
|
||||
*actual_size = b->size;
|
||||
return (void *) ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
|
||||
*actual_size = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void cuda_pool_free(void * ptr, size_t size) {
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
struct cuda_buffer * b = &cuda_buffer_pool[i];
|
||||
uintptr_t p = atomic_load(&b->ptr);
|
||||
if (p == 0) {
|
||||
if (atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
|
||||
b->size = size;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
static cublasHandle_t cublasH = NULL;
|
||||
static cudaStream_t cudaStream = NULL;
|
||||
static void init_cublas(void) {
|
||||
|
@ -7566,18 +7610,16 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
size_t x_size, y_size, d_size;
|
||||
float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
#endif
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
|
@ -7614,9 +7656,9 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
cuda_pool_free(d_X, x_size);
|
||||
cuda_pool_free(d_Y, y_size);
|
||||
cuda_pool_free(d_D, d_size);
|
||||
#endif
|
||||
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
@ -7766,18 +7808,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
#if defined(GGML_USE_CUBLAS)
|
||||
ggml_fp16_t * const wdata = params->wdata;
|
||||
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
size_t x_size, y_size, d_size;
|
||||
float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
#else
|
||||
float * const wdata = params->wdata;
|
||||
#endif
|
||||
|
@ -7844,9 +7884,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
cuda_pool_free(d_X, x_size);
|
||||
cuda_pool_free(d_Y, y_size);
|
||||
cuda_pool_free(d_D, d_size);
|
||||
#endif
|
||||
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
||||
|
||||
|
@ -8014,20 +8054,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
float *d_Q = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
||||
size_t x_size, y_size, d_size, q_size;
|
||||
float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
float *d_Q = cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
||||
|
||||
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
||||
if (type == GGML_TYPE_Q4_0) {
|
||||
|
@ -8100,10 +8137,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
CUDA_CHECK(cudaFree(d_Q));
|
||||
cuda_pool_free(d_X, x_size);
|
||||
cuda_pool_free(d_Y, y_size);
|
||||
cuda_pool_free(d_D, d_size);
|
||||
cuda_pool_free(d_Q, q_size);
|
||||
#endif
|
||||
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue