cuda malloc:

- added functionality to find the smallest fitting buffer instead of the first buffer found that is >= the requested size
-- this prevents two sequential buffer allocations from grabbing a huge buffer for a small tensor and then requiring a fresh buffer for the 2nd tensor
-- in my test this saved 1 GB of VRAM that is now free for more offloading (see the sketch below for the idea)
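To illustrate the idea outside the pool code (the real change is in the ggml_cuda_pool_malloc hunk further down; pick_buffer, buffer_slot and the 5% threshold here are illustrative stand-ins):

// Sketch: best-fit selection over a cached buffer pool.
// With first-fit, a 16 MB request could claim a cached 2000 MB cuBLAS buffer,
// forcing a fresh cudaMalloc once the big buffer is needed again;
// best-fit keeps the small request in the smallest suitable buffer.
#include <cstddef>
#include <cstdint>

struct buffer_slot { void * ptr; size_t size; };

// returns the slot with the least wasted space, or nullptr if nothing fits
buffer_slot * pick_buffer(buffer_slot * pool, int n, size_t size) {
    size_t min_size_diff = SIZE_MAX;
    const size_t wiggle  = (size_t)(size * 0.05); // "close enough" threshold
    buffer_slot * best_fit = nullptr;
    for (int i = 0; i < n; ++i) {
        buffer_slot & b = pool[i];
        if (b.ptr != nullptr && b.size >= size) {
            const size_t diff = b.size - size;
            if (diff < min_size_diff) {   // smaller leftover = better fit
                best_fit      = &b;
                min_size_diff = diff;
                if (diff < wiggle) break; // near-perfect fit, stop searching
            }
        }
    }
    return best_fit; // caller falls back to cudaMalloc on nullptr
}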

cuda free buffers:
- added a helper function that frees all unused buffers on a device, to prevent huge F32 buffers from cuBLAS needlessly occupying VRAM after token ingestion (usage sketch below)
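The intended call pattern looks roughly like this (a simplified stand-in for the real main loop shown in the first diff hunk below; ingest_prompt and the token loop are illustrative, while ggml_cuda_pool_free_all is the declaration this commit adds to ggml-cuda.h, with a negative device id meaning "all devices"):

// Compiles against the declaration added in this commit; linking requires ggml-cuda.
#include <cstddef>
#include <vector>

void ggml_cuda_pool_free_all(int device_id); // declared in ggml-cuda.h by this commit

void ingest_prompt(const std::vector<int> & tokens, int n_batch) {
    for (size_t i = 0; i < tokens.size(); i += n_batch) {
        // ... evaluate tokens[i .. i+n_batch) on the GPU ...
    }
    // batched ingestion allocated large F32 cuBLAS buffers; they are of a
    // different size than single-token eval needs, so drop the unused ones
    if (n_batch > 1 && tokens.size() > 1) {
        ggml_cuda_pool_free_all(-1); // negative id = all devices; only frees buffers not in use
    }
}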

libfalcon:
- corrected the vram_overhead calculation to account for the actual non-weight buffers needed during inference
- added vram_overhead for n_batch > 1, as this switches ingestion into a 32-bit dequantization mode for cuBLAS, which needs almost 2 GB of VRAM buffers
- corrected the automated layer distribution so VRAM is filled with as many layers as possible (see the budget sketch below)
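To make the budget concrete, here is a small standalone sketch of the layer-distribution check. The inequality and the vram_* terms mirror the falcon_model_load_internal hunks below; the free-VRAM and per-layer numbers are made up for illustration:

// Illustrative only: how many layers fit, using the same budget terms the
// loader uses (reserved headroom, estimated overhead, scratch, per-layer size).
#include <cstddef>
#include <cstdio>

int main() {
    const size_t MB            = 1024 * 1024;
    size_t vram_free           = 12000 * MB;                          // assumed free VRAM at load time (12 GB card)
    const size_t vram_reserved = 512 * MB;                            // headroom left for other processes
    const size_t vram_overhead = 1250 * MB + (1024 + 288 + 256) * MB; // 40B model with n_batch > 1
    const size_t vram_scratch  = 0;                                   // scratch currently disabled
    const size_t vram_layer    = 300 * MB;                            // assumed per-layer weight size

    const int n_layer = 60;                                           // FALCON_40B has 60 layers
    int n_gpu_layers  = 0;
    for (int i = 0; i < n_layer; ++i) {
        if (vram_free <= vram_overhead + vram_scratch + vram_reserved + vram_layer) {
            break; // the next layer would push us past the budget: stop offloading
        }
        vram_free -= vram_layer; // simulate loading this layer into VRAM
        ++n_gpu_layers;
    }
    printf("layers offloaded to GPU: %d of %d\n", n_gpu_layers, n_layer);
    return 0;
}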

From here on it is recommended to use --ngl 100 and -b 1 for CUDA processing.
In addition, -t is recommended at 1, or at one less than the number of CPU cores (depends on the CPU and GPU used).
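For example, a typical invocation following these recommendations might look like the following (the falcon_main binary name and the model file name are assumptions about the local build and setup; the flags are the ones recommended above):

./falcon_main -m models/falcon-40b-q4_0.bin -p "Hello" --ngl 100 -b 1 -t 1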
John 2023-06-19 05:03:28 +02:00
parent b4028edb9a
commit 7c8249ff6b
5 changed files with 104 additions and 49 deletions

View file

@@ -397,7 +397,9 @@ int main(int argc, char ** argv) {
         embd.erase(embd.begin(), embd.begin() + i);
     }
 }
+// We have buffers from the warmup run that won't all align with a batched run
+if (params.n_batch > 1 && embd.size() > 1)
+    ggml_cuda_pool_free_all(-1);
 // evaluate tokens in batches
 // embd is typically prepared beforehand to fit within a batch, but not always
 for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
@@ -411,7 +413,9 @@ int main(int argc, char ** argv) {
     }
     n_past += n_eval;
 }
+// frees unused allocations, those during batch processing are of different size than single token eval
+if (params.n_batch > 1 && embd.size() > 1)
+    ggml_cuda_pool_free_all(-1);
 if (embd.size() > 0 && !path_session.empty()) {
     session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
     n_session_consumed = session_tokens.size();

View file

@@ -1428,17 +1428,29 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
 scoped_spin_lock lock(g_cuda_pool_lock);
 int id;
 CUDA_CHECK(cudaGetDevice(&id));
+size_t min_size_diff = SIZE_MAX;
+size_t min_size_diff_ok = size * 0.05; // wiggle room
+cuda_buffer* best_fit = nullptr; // candidate pointer
 for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
     cuda_buffer& b = g_cuda_buffer_pool[id][i];
     if (b.size >= size && b.ptr != nullptr) {
-        void * ptr = b.ptr;
-        *actual_size = b.size;
-        b.ptr = nullptr;
-        b.size = 0;
-        return ptr;
+        size_t size_diff = b.size - size;
+        if (size_diff < min_size_diff) {
+            best_fit = &b;
+            min_size_diff = size_diff;
+            if (size_diff < min_size_diff_ok) {
+                break;
+            }
+        }
     }
 }
+if (best_fit != nullptr) {
+    *actual_size = best_fit->size;
+    void * ptr = best_fit->ptr;
+    best_fit->ptr = nullptr;
+    best_fit->size = 0;
+    return ptr;
+}
 void * ptr;
 CUDA_CHECK(cudaMalloc((void **) &ptr, size));
 *actual_size = size;
@@ -1462,6 +1474,30 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
+// free all buffers that are not currently in use
+void ggml_cuda_pool_free_all(int device_id) {
+    while (atomic_flag_test_and_set(&g_cuda_pool_lock)) {}
+    int start_id = (device_id < 0) ? 0 : device_id;
+    int end_id = (device_id < 0) ? GGML_CUDA_MAX_DEVICES : device_id + 1;
+    for (int id = start_id; id < end_id; ++id) {
+        for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+            cuda_buffer* b = &(g_cuda_buffer_pool[id][i]);
+            if (b->ptr != NULL) {
+                cudaError_t err = cudaFree(b->ptr);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "ERROR: CUDA buffer free failed: %s\n", cudaGetErrorString(err));
+                } else {
+                    b->ptr = NULL;
+                    b->size = 0;
+                }
+            }
+        }
+    }
+    atomic_flag_clear(&g_cuda_pool_lock);
+}
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default

View file

@@ -20,6 +20,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
+void ggml_cuda_pool_free_all(int device_id);
 // TODO: export these with GGML_API
 void * ggml_cuda_host_malloc(size_t size);
 void ggml_cuda_host_free(void * ptr);

View file

@@ -168,6 +168,8 @@ struct falcon_model {
     std::vector<falcon_layer> layers;
     int n_gpu_layers;
+    int i_gpu_start;
+    int i_gpu_last;
     // context
     struct ggml_context * ctx = NULL;
@@ -911,7 +913,9 @@ struct falcon_context_params falcon_context_default_params() {
     struct falcon_context_params result = {
         /*.n_ctx =*/ 512,
         /*.n_batch =*/ 512,
-        /*.gpu_layers =*/ 0,
+        /*.n_gpu_layers =*/ 0,
+        /*.i_gpu_start =*/ -1,
+        /*.i_gpu_last =*/ -1,
         /*.main_gpu =*/ 0,
         /*.tensor_split =*/ {0},
         /*.seed =*/ -1,
@@ -1068,7 +1072,7 @@ static void falcon_model_load_internal(
     {
         switch (hparams.n_layer) {
             case 32: model.type = e_model::FALCON_7B; break;
-            case 40: model.type = e_model::FALCON_40B; break;
+            case 60: model.type = e_model::FALCON_40B; break;
             default:
             {
                 if (hparams.version == 7) {
@@ -1166,23 +1170,30 @@ if (n_gpu_layers > 0)
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
 #endif
     size_t vram_total=0;
     size_t vram_free=0;
-    size_t vram_reserved=1024*1024*512; //will be adapted by model
+    const size_t vram_reserved=512*MB; // that amount of VRAM is to stay free on GPU (headroom for other processes - may be reduced in pure server environments)
+    size_t vram_overhead = 1250*MB; // this amount of vram is estimated for non weight storage buffers on VRAM (no big difference between 7B and 40B, needs to increase when more work is offloaded in the future)
+    // cublas is used in 32 bit mode, temporary cuda storage/conversion buffers are needed for batch ingestion ( could be run in 16 bit mode without performance downgrade and save half the VRAM)
+    if (model.type == FALCON_40B && n_batch > 1)
+        vram_overhead += (1024+288+256) * MB;
+    if (model.type == FALCON_7B && n_batch > 1)
+        vram_overhead += (315+80+78) * MB;
 #if defined(GGML_USE_CUBLAS)
     cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu but I don't want to make Johannes life harder by modifying that yet
-    fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (already used: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+    fprintf(stderr, "%s: VRAM free: %7.2f MB of %7.2f MB (in use: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
 #endif
     // prepare memory for the weights
     size_t vram_weights = 0;
     size_t vram_scratch = 0;
-    size_t vram_overhead = 0;
     (void) vram_scratch;
     (void) n_batch;
     // calculate scratch buffer size and allocate it
 #ifdef GGML_USE_CUBLAS
-    vram_scratch = n_batch * MB;
+    // vram_scratch = n_batch * MB;
+    vram_scratch = 0; // these are not used until we have multi operation support
     ggml_cuda_set_scratch_size(vram_scratch);
     if (n_gpu_layers > 0) {
@@ -1203,22 +1214,7 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     ml->ggml_ctx = ctx;
     model.tok_embeddings = ml->get_tensor("transformer.word_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
-    // I did not analyze the cause but that's the overhead that is dynamically added to the VRAM at first inference
-    // same goes with reserved, most likely we can skip both for a proper size calculation.
-    // If the below values are not correct GPU memory will fill up to 100%, resulting in a extreme slowdown of inference
-    if (model.type == FALCON_40B)
-    {
-        vram_reserved=1900*MB;
-        vram_overhead+=2700*MB;
-    }
-    else
-    {
-        vram_reserved=768*MB;
-        vram_overhead+=1200*MB;
-    }
     ggml_backend backend_norm;
     ggml_backend backend_output;
     // disabled norm/output offloading until further tests, causes silent crash at the moment
@@ -1240,10 +1236,8 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     if (backend_norm != GGML_BACKEND_CPU)
     {
-        vram_weights += ggml_nbytes(model.output_norm);
-        vram_weights += ggml_nbytes(model.output_norm_b);
-        vram_free -= ggml_nbytes(model.output_norm);
-        vram_free -= ggml_nbytes(model.output_norm_b);
+        vram_weights += ggml_nbytes(model.output_norm) + ggml_nbytes(model.output_norm_b);
+        vram_free -= ggml_nbytes(model.output_norm) + ggml_nbytes(model.output_norm_b);
     }
     if (backend_output != GGML_BACKEND_CPU)
     {
@@ -1252,12 +1246,14 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     }
     const int i_gpu_start = n_layer - n_gpu_layers;
-    int i_gpu_end = n_layer; // allows to terminate the offloading earlier. TODO: instead do a proper calculation run and determine the start before the loop
+    int i_gpu_last = n_layer; // allows to terminate the offloading earlier. TODO: instead do a proper calculation run and determine the start before the loop
+    model.i_gpu_start = i_gpu_start;
+    model.i_gpu_last = i_gpu_last;
     model.layers.resize(n_layer);
     for (uint32_t i = 0; i < n_layer; ++i) {
-        const ggml_backend backend = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-        const ggml_backend backend_split = (int(i) < i_gpu_start || int(i) >= i_gpu_end) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+        const ggml_backend backend = (int(i) < i_gpu_start || int(i) > i_gpu_last) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+        const ggml_backend backend_split = (int(i) < i_gpu_start || int(i) > i_gpu_last) ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
         auto & layer = model.layers[i];
@@ -1288,14 +1284,15 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
         vram_layer = calculate_layer_vram_bytes(layer);
         vram_weights += vram_layer;
         vram_free = (vram_layer > vram_free) ? 0 : vram_free - vram_layer; // simulate the layer being loaded in VRAM
-        if (vram_free <= (vram_overhead+vram_scratch+vram_reserved))
+        // test if we have enough VRAM to load the next layer
+        if (i < n_layer && vram_free <= (vram_overhead+vram_scratch+vram_reserved+vram_layer))
         {
             // this needs some polishing (instead of fiddling with --ngl I'd like the option to auto-fill the vram with as many layers as possible as an alternative)
-            fprintf(stderr, "WARNING: Not enough VRAM to load the model as configured - at layer %d of %d\n", i, n_layer);
+            fprintf(stderr, "INFO: Not enough VRAM to load all requested layers - at layer %d of %d: skipping\n", i, n_layer);
             n_gpu_layers = i+1;
             model.n_gpu_layers = n_gpu_layers;
-            i_gpu_end = i;
+            i_gpu_last = i;
+            model.i_gpu_last = i_gpu_last;
         }
     }
@@ -1335,7 +1332,7 @@ size_t vram_reserved=1024*1024*512; //will be adapted by model
     if (n_gpu_layers > (int) hparams.n_layer) {
         fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
     }
-    fprintf(stderr, "%s: total VRAM used: %zu MB\n",
+    fprintf(stderr, "%s: estimated VRAM usage: %zu MB\n",
         __func__, (vram_weights + vram_scratch + vram_overhead + MB - 1) / MB); // round up
 #else
     (void) n_gpu_layers;
@@ -1468,7 +1465,9 @@ static bool falcon_eval_internal(
     // ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
     const int sizeof_wtype = ggml_type_sizef(wtype);
-    const int i_gpu_start = n_layer - n_gpu_layers;
+    // const int i_gpu_start = n_layer - n_gpu_layers;
+    const int i_gpu_start = lctx.model.i_gpu_start;
+    const int i_gpu_last = lctx.model.i_gpu_last > 0 ? lctx.model.i_gpu_last : n_layer;
     (void) i_gpu_start;
     // offload functions set the tensor output backend to GPU
@@ -1492,7 +1491,7 @@ static bool falcon_eval_internal(
         offload_func_t offload_func = llama_nop;
 #ifdef GGML_USE_CUBLAS
-        if (il >= i_gpu_start) {
+        if (il >= i_gpu_start && il < i_gpu_last) {
             offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
         }
 #endif // GGML_USE_CUBLAS
@@ -1507,7 +1506,7 @@ static bool falcon_eval_internal(
         layernorm_output = ggml_norm(ctx0, inpL);
         ggml_tensor * il_a = ggml_mul(ctx0, layernorm_output, model.layers[il].input_layernorm);
-        offload_func(il_a);
+        offload_func(il_a); // (todo: uses vram scratch)
         layernorm_output = ggml_add(ctx0,
             il_a,
@@ -1737,6 +1736,15 @@ static bool falcon_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);
+#if 0
+    // use to confirm vram_overhead is correct
+    size_t vram_total=0;
+    size_t vram_free=0;
+#if defined(GGML_USE_CUBLAS)
+    cudaMemGetInfo(&vram_free, &vram_total); // this should go in ggml-cuda.cu but I don't want to make Johannes life harder by modifying that yet
+    fprintf(stderr, "\n%s: VRAM free: %7.2f MB of %7.2f MB (in use: %7.2f MB)\n", __func__, vram_free/MB*1.0, vram_total/MB*1.0, (vram_total-vram_free)/MB*1.0);
+#endif
+#endif
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
@@ -2701,7 +2709,11 @@ struct falcon_context * falcon_init_from_file(
         llama_free(ctx);
         return nullptr;
     }
-    params.n_gpu_layers = ctx->model.n_gpu_layers; // model_load_internal() may change this
+    // model_load_internal() may change this if VRAM runs out
+    params.n_gpu_layers = ctx->model.n_gpu_layers;
+    params.i_gpu_start = ctx->model.i_gpu_start;
+    params.i_gpu_last = ctx->model.i_gpu_last;
     // reserve memory for context buffers
     if (!params.vocab_only) {

View file

@@ -75,6 +75,8 @@ extern "C" {
         int n_ctx;        // text context
         int n_batch;      // prompt processing batch size
        int n_gpu_layers; // number of layers to store in VRAM
+        int i_gpu_start;  // first gpu layer
+        int i_gpu_last;   // last gpu layer
         int main_gpu;     // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
         int seed;         // RNG seed, -1 for random