fix scratch buffer size, re-enable vmm pool for all devices

This commit is contained in:
slaren 2023-12-26 01:28:39 +01:00
parent 23c6dd677b
commit da9fc775a3
2 changed files with 4 additions and 3 deletions

View file

@@ -6741,7 +6741,7 @@ static void ggml_cuda_pool_free_vmm(int device, void * ptr, size_t size) {
 }

 static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_size) {
-    if (device == g_main_device && g_device_caps[device].vmm) {
+    if (g_device_caps[device].vmm) {
         return ggml_cuda_pool_malloc_vmm(device, size, actual_size);
     } else {
         return ggml_cuda_pool_malloc_leg(device, size, actual_size);
@@ -6749,7 +6749,7 @@ static void * ggml_cuda_pool_malloc(int device, size_t size, size_t * actual_siz
 }

 static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
-    if (device == g_main_device && g_device_caps[device].vmm) {
+    if (g_device_caps[device].vmm) {
         ggml_cuda_pool_free_vmm(device, ptr, size);
     } else {
         ggml_cuda_pool_free_leg(device, ptr, size);

View file

@@ -9519,7 +9519,8 @@ struct llama_context * llama_new_context_with_model(
         ctx->alloc = ggml_allocr_new_from_buffer(ctx->buf_alloc);
 #if defined(GGML_USE_CUBLAS) && !defined(LLAMA_GGML_BACKEND_CUDA_TEST)
         if (model->n_gpu_layers > 0) {
-            ggml_cuda_set_scratch_size(alloc_size);
+            // the CPU buffer adds this padding in case the malloc buffer is not aligned, so we need to do the same for the GPU buffer, since we use the same offsets
+            ggml_cuda_set_scratch_size(alloc_size + 64);
             LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MiB\n", __func__, alloc_size / 1024.0 / 1024.0);
             // calculate total VRAM usage