cuda : improve cuda pool efficiency using virtual memory
This commit is contained in:
parent
7082d24cec
commit
0d77fbd774
2 changed files with 150 additions and 26 deletions
3
Makefile
3
Makefile
|
@ -367,9 +367,10 @@ endif # LLAMA_BLIS
|
|||
|
||||
ifdef LLAMA_CUBLAS
|
||||
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
|
||||
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
|
||||
MK_LDFLAGS += -lcuda -L/usr/lib/wsl/lib -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
|
||||
OBJS += ggml-cuda.o
|
||||
MK_NVCCFLAGS = -use_fast_math
|
||||
|
||||
ifndef JETSON_EOL_MODULE_DETECT
|
||||
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
|
||||
endif # JETSON_EOL_MODULE_DETECT
|
||||
|
|
173
ggml-cuda.cu
173
ggml-cuda.cu
|
@ -88,6 +88,7 @@
|
|||
#define __trap abort
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
// CUDA 10.2 does not have these macro definitions.
|
||||
|
@ -213,6 +214,24 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
// driver API
|
||||
#define CU_CHECK(err) \
|
||||
do { \
|
||||
CUresult err_ = (err); \
|
||||
if (err_ != CUDA_SUCCESS) { \
|
||||
int id; \
|
||||
cuDeviceGet(&id, 0); \
|
||||
const char * err_str; \
|
||||
cuGetErrorString(err_, &err_str); \
|
||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
err_str); \
|
||||
fprintf(stderr, "%s\n", #err); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
GGML_ASSERT(!"CUDA error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
|
@ -6543,13 +6562,18 @@ struct scoped_spin_lock {
|
|||
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
|
||||
};
|
||||
|
||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||
|
||||
#if 0
|
||||
#define DEBUG_CUDA_MALLOC
|
||||
struct cuda_buffer {
|
||||
void * ptr = nullptr;
|
||||
size_t size = 0;
|
||||
};
|
||||
|
||||
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
|
||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||
|
||||
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
|
||||
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
@ -6557,7 +6581,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
int nnz = 0;
|
||||
size_t max_size = 0, tot_size = 0;
|
||||
size_t max_size = 0;
|
||||
#endif
|
||||
size_t best_diff = 1ull << 36;
|
||||
int ibest = -1;
|
||||
|
@ -6566,7 +6590,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||
if (b.ptr != nullptr) {
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
++nnz;
|
||||
tot_size += b.size;
|
||||
if (b.size > max_size) max_size = b.size;
|
||||
#endif
|
||||
if (b.size >= size) {
|
||||
|
@ -6593,15 +6616,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
|
||||
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
|
||||
#endif
|
||||
void * ptr;
|
||||
size_t look_ahead_size = (size_t) (1.05 * size);
|
||||
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
|
||||
*actual_size = look_ahead_size;
|
||||
g_cuda_pool_size[id] += look_ahead_size;
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
|
||||
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
|
||||
#endif
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
@ -6620,7 +6644,106 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
|||
}
|
||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
g_cuda_pool_size[id] -= size;
|
||||
}
|
||||
#else
|
||||
|
||||
static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
|
||||
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
|
||||
static const size_t CUDA_POOL_MAX_SIZE = 1ull << 36; // 64 GB
|
||||
|
||||
//#define DEBUG_CUDA_MALLOC
|
||||
|
||||
#define ggml_cuda_pool_malloc(size, actual_size) ggml_cuda_pool_malloc_(size, actual_size, #size " " #actual_size)
|
||||
static void * ggml_cuda_pool_malloc_(size_t size, size_t * actual_size, const char * call) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
|
||||
|
||||
if (size > avail) {
|
||||
size_t reserve_size = size - avail;
|
||||
|
||||
// allocate more physical memory
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = id;
|
||||
|
||||
// get the minimum allocation granularity for this device
|
||||
size_t granularity = 0;
|
||||
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
|
||||
// round up to the nearest granularity
|
||||
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
|
||||
|
||||
GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_MAX_SIZE);
|
||||
|
||||
CUmemGenericAllocationHandle handle;
|
||||
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
|
||||
|
||||
// reserve virtual address space (if not already reserved)
|
||||
if (g_cuda_pool_addr[id] == 0) {
|
||||
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_MAX_SIZE, 0, 0, 0));
|
||||
}
|
||||
|
||||
// map at the end of the pool
|
||||
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
|
||||
|
||||
// set access
|
||||
CUmemAccessDesc access = {};
|
||||
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
access.location.id = id;
|
||||
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
|
||||
|
||||
// add to the pool
|
||||
g_cuda_pool_handles[id].push_back(handle);
|
||||
g_cuda_pool_size[id] += reserve_size;
|
||||
|
||||
printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB) [%s]\n",
|
||||
id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
|
||||
(unsigned long long) (reserve_size/1024/1024), call);
|
||||
}
|
||||
|
||||
GGML_ASSERT(g_cuda_pool_addr[id] != 0);
|
||||
|
||||
void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
|
||||
*actual_size = size;
|
||||
g_cuda_pool_used[id] += size;
|
||||
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr, call);
|
||||
#endif
|
||||
|
||||
return ptr;
|
||||
|
||||
GGML_UNUSED(call);
|
||||
}
|
||||
|
||||
#define ggml_cuda_pool_free(ptr, size) ggml_cuda_pool_free_(ptr, size, #ptr " " #size)
|
||||
static void ggml_cuda_pool_free_(void * ptr, size_t size, const char * call) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
printf("cuda pool[%d]: free %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr, call);
|
||||
#endif
|
||||
|
||||
g_cuda_pool_used[id] -= size;
|
||||
|
||||
// all deallocations must be in reverse order of the allocations
|
||||
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
|
||||
|
||||
GGML_UNUSED(call);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static bool g_cublas_loaded = false;
|
||||
|
||||
|
@ -7437,13 +7560,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
}
|
||||
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
}
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
}
|
||||
}
|
||||
else {
|
||||
float * src0_ddq_as_f32 = nullptr;
|
||||
|
@ -7800,14 +7923,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
|
|||
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
|
||||
}
|
||||
|
||||
if (src0_asf > 0) {
|
||||
ggml_cuda_pool_free(src0_ddf, src0_asf);
|
||||
if (dst_asf > 0) {
|
||||
ggml_cuda_pool_free(dst_ddf, dst_asf);
|
||||
}
|
||||
if (src1_asf > 0) {
|
||||
ggml_cuda_pool_free(src1_ddf, src1_asf);
|
||||
}
|
||||
if (dst_asf > 0) {
|
||||
ggml_cuda_pool_free(dst_ddf, dst_asf);
|
||||
if (src0_asf > 0) {
|
||||
ggml_cuda_pool_free(src0_ddf, src0_asf);
|
||||
}
|
||||
|
||||
if (dst->backend == GGML_BACKEND_CPU) {
|
||||
|
@ -8119,17 +8242,17 @@ static void ggml_cuda_op_mul_mat(
|
|||
CUDA_CHECK(ggml_cuda_set_device(id));
|
||||
|
||||
// free buffers again when done
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
|
||||
}
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
|
||||
}
|
||||
if (src1_asq[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
|
||||
}
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
|
||||
}
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8497,12 +8620,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
if (ptrs_src_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||
}
|
||||
if (ptrs_dst_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
||||
}
|
||||
if (ptrs_src_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue