diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0baa989a3..2d2e5a90e 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1,5 +1,7 @@
 #include <stdint.h>
+#include <stdio.h>
 #include <cuda_fp16.h>
+#include <atomic>
 #include "ggml-cuda.h"
 
 typedef uint16_t ggml_fp16_t;
@@ -35,8 +37,6 @@ typedef struct {
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
-
-
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
     }
 }
 
-extern "C" {
-    __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_0;
-        dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+// Host-side launchers for the dequantization kernels above.
+// nb = k / QK truncates, so k is presumably a multiple of the block size - TODO confirm against callers.
+// NOTE(review): the <<<nb, 1, 0, stream>>> launch configuration was restored from a garbled source; verify.
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+// lock-free, thread safe buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+// value kept in cuda_buffer::ptr while exactly one thread owns the slot
+#define CUDA_BUFFER_BUSY ((uintptr_t) 1)
+struct cuda_buffer {
+    std::atomic_uintptr_t ptr; // 0 = empty, CUDA_BUFFER_BUSY = claimed, anything else = pooled device pointer
+    size_t size;               // touched only by the thread that currently owns the slot
+};
+
+static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
+
+// Hands out a pooled device buffer of at least `size` bytes, or cudaMallocs a new one.
+// The capacity of the returned buffer is written to *actual_size.
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        uintptr_t ptr = b->ptr.load();
+        if (ptr == 0 || ptr == CUDA_BUFFER_BUSY) {
+            continue;
+        }
+        // claim the slot before looking at size, so no other thread can change it under us
+        if (!b->ptr.compare_exchange_strong(ptr, CUDA_BUFFER_BUSY)) {
+            continue;
+        }
+        if (b->size >= size) {
+            *actual_size = b->size;
+            b->ptr.store(0); // slot is empty again
+            return (void *) ptr;
+        }
+        b->ptr.store(ptr); // too small for this request, put the buffer back
     }
 
-    __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_1;
-        dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
-    }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
 
-    __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_2;
-        dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+// Puts a device buffer back into the pool, recording `size` as its capacity.
+// When every slot is occupied the buffer is released with cudaFree instead.
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        struct cuda_buffer * b = &cuda_buffer_pool[i];
+        uintptr_t p = 0;
+        // claim an empty slot, write size, and only then publish the pointer
+        if (b->ptr.compare_exchange_strong(p, CUDA_BUFFER_BUSY)) {
+            b->size = size;
+            b->ptr.store((uintptr_t) ptr);
+            return;
+        }
     }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
 
-    __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_3;
-        dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+cublasHandle_t cublasH = NULL;
+cudaStream_t cudaStream = NULL;
+
+// Creates the global cuBLAS handle and its stream on first use; later calls do nothing.
+void ggml_init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
diff --git a/ggml-cuda.h b/ggml-cuda.h
index be140606a..40877ecd5 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -1,7 +1,38 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+#define CUDA_CHECK(err)  \
+    do {  \
+        cudaError_t err_ = (err);  \
+        if (err_ != cudaSuccess) {  \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,  \
+                cudaGetErrorString(err_));  \
+            exit(1);  \
+        }  \
+    } while (0)
+
+#define CUBLAS_CHECK(err)  \
+    do {  \
+        cublasStatus_t err_ = (err);  \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {  \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
+            exit(1);  \
+        }  \
+    } while (0)
+
+
+
+extern cublasHandle_t cublasH;
+extern cudaStream_t cudaStream;
+
+void ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void ggml_cuda_pool_free(void * ptr, size_t size);
+
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
diff --git a/ggml.c b/ggml.c
index fb5fd1f7e..6d8796e1a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -148,88 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
 #include "ggml-cuda.h"
-
-#define CUDA_CHECK(err)  \
-    do {  \
-        cudaError_t err_ = (err);  \
-        if (err_ != cudaSuccess) {  \
-            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,  \
-                cudaGetErrorString(err_));  \
-            exit(1);  \
-        }  \
-    } while (0)
-
-#define CUBLAS_CHECK(err)  \
-    do {  \
-        cublasStatus_t err_ = (err);  \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {  \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
-            exit(1);  \
-        }  \
-    } while (0)
-
-// lock-free, thread safe buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
-struct cuda_buffer {
-    atomic_uintptr_t ptr;
-    size_t size;
-};
-
-static struct cuda_buffer cuda_buffer_pool[MAX_CUDA_BUFFERS] = {0};
-
-static void * cuda_pool_malloc(size_t size, size_t * actual_size) {
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        if (b->size >= size) {
-            uintptr_t ptr = atomic_load(&b->ptr);
-            if (ptr) {
-                if (atomic_compare_exchange_strong(&b->ptr, &ptr, 0)) {
-                    *actual_size = b->size;
-                    return (void *) ptr;
-                }
-            }
-        }
-    }
-
-    void * ptr;
-    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
-    *actual_size = size;
-    return ptr;
-}
-
-static void cuda_pool_free(void * ptr, size_t size) {
-    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        struct cuda_buffer * b = &cuda_buffer_pool[i];
-        uintptr_t p = atomic_load(&b->ptr);
-        if (p == 0) {
-            if (atomic_compare_exchange_strong(&b->ptr, &p, (uintptr_t) ptr)) {
-                b->size = size;
-                return;
-            }
-        }
-    }
-    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
-    CUDA_CHECK(cudaFree(ptr));
-}
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
-    if (cublasH == NULL) {
-        // create cublas handle, bind a stream
-        CUBLAS_CHECK(cublasCreate(&cublasH));
-
-        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
-        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
-        // configure logging to stdout
-        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
-    }
-}
 #endif
 
 #undef MIN
@@ -3764,7 +3683,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
-            init_cublas();
+            ggml_init_cublas();
         #endif
 
         is_first_call = false;
@@ -7617,9 +7536,9 @@ static void ggml_compute_forward_mul_mat_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7656,9 +7575,9 @@ static void ggml_compute_forward_mul_mat_f32(
         }
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
@@ -7815,9 +7734,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
         float * const wdata = params->wdata;
 #endif
@@ -7884,9 +7803,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
 #endif
         /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
 
@@ -8061,10 +7980,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
         const int d_ne = ne11 * ne01;
 
         size_t x_size, y_size, d_size, q_size;
-        float *d_X = cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-        float *d_Y = cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-        float *d_D = cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-        float *d_Q = cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+        float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+        float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+        float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+        float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
         void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
         if (type == GGML_TYPE_Q4_0) {
@@ -8137,10 +8056,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
 #if defined(GGML_USE_CUBLAS)
         CUDA_CHECK(cudaStreamSynchronize(cudaStream));
-        cuda_pool_free(d_X, x_size);
-        cuda_pool_free(d_Y, y_size);
-        cuda_pool_free(d_D, d_size);
-        cuda_pool_free(d_Q, q_size);
+        ggml_cuda_pool_free(d_X, x_size);
+        ggml_cuda_pool_free(d_Y, y_size);
+        ggml_cuda_pool_free(d_D, d_size);
+        ggml_cuda_pool_free(d_Q, q_size);
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 