Improve cuBLAS performance by using a memory pool (#1094)
* Improve cuBLAS performance by using a memory pool * Move cuda specific definitions to ggml-cuda.h/cu * Add CXX flags to nvcc * Change memory pool synchronization mechanism to a spin lock General code cleanup
This commit is contained in:
parent
25d7abbd1f
commit
50cb666b8a
4 changed files with 170 additions and 109 deletions
116
ggml-cuda.cu
116
ggml-cuda.cu
|
@ -1,5 +1,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <atomic>
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
typedef uint16_t ggml_fp16_t;
|
||||
|
@ -29,14 +31,12 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
|
|||
|
||||
#define QK4_3 16
|
||||
typedef struct {
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
} block_q4_3;
|
||||
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
|
||||
|
||||
|
||||
|
||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
|
@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_1;
|
||||
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_1;
|
||||
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_2;
|
||||
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_2;
|
||||
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_3;
|
||||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_3;
|
||||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 16
|
||||
|
||||
struct scoped_spin_lock {
|
||||
std::atomic_flag& lock;
|
||||
scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
|
||||
while (lock.test_and_set(std::memory_order_acquire)) {
|
||||
; // spin
|
||||
}
|
||||
}
|
||||
~scoped_spin_lock() {
|
||||
lock.clear(std::memory_order_release);
|
||||
}
|
||||
scoped_spin_lock(const scoped_spin_lock&) = delete;
|
||||
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
|
||||
};
|
||||
|
||||
struct cuda_buffer {
|
||||
void * ptr = nullptr;
|
||||
size_t size = 0;
|
||||
};
|
||||
|
||||
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.size >= size && b.ptr != nullptr) {
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
|
||||
*actual_size = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.ptr == nullptr) {
|
||||
b.ptr = ptr;
|
||||
b.size = size;
|
||||
return;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
cublasHandle_t g_cublasH = NULL;
|
||||
cudaStream_t g_cudaStream = NULL;
|
||||
|
||||
void ggml_init_cublas(void) {
|
||||
if (g_cublasH == NULL) {
|
||||
// create cublas handle, bind a stream
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
||||
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue