dynamic estimate of required memory usage
parent ea10d3ded2
commit 424281a4fb
2 changed files with 30 additions and 9 deletions
llama.cpp (29 lines changed)
@@ -97,7 +97,9 @@ struct llama_context {
     llama_model model;
     llama_vocab vocab;

-    size_t mem_per_token = 0;
+    // used to estimate memory requirements experimentally
+    size_t mem_at_token0 = 0; // first time
+    size_t mem_at_token1 = 0; // second time

     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -626,14 +628,24 @@ static bool llama_eval_internal(
     const int n_vocab = hparams.n_vocab;
     const int n_rot   = hparams.n_embd/hparams.n_head;

-    auto & mem_per_token = lctx.mem_per_token;
+    auto & mem_at_token0 = lctx.mem_at_token0;
+    auto & mem_at_token1 = lctx.mem_at_token1;

     // TODO: fix this hardcoded size
-    static size_t buf_size = 512u*1024*1024;
+    static size_t buf_size = size_t(n_ctx)*1024*1024;
     static void * buf = malloc(buf_size);

-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
+    const size_t C0 = mem_at_token0; // ~base
+    const int64_t C1 = mem_at_token1 - mem_at_token0; // delta 0,1
+
+    // TODO(Green-Sky): determine relation to N (batch size)
+    //const size_t size_estimate = C0 + size_t(C1 * (n_past + N));
+    const size_t size_estimate = C0 + C1 * n_past;
+
+    //fprintf(stderr, "\n%s: size_estimate %zu bytes (%zu | %zu)\n", __func__, size_estimate, mem_per_token0, mem_per_token1);
+
+    if (mem_at_token0 > 0 && mem_at_token1 > 0 && size_estimate > buf_size) {
+        const size_t buf_size_new = 1.1*size_estimate; // just grow by 10%
         //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

         // reallocate
@@ -830,10 +842,13 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }

-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
+    if (mem_at_token0 == 0) {
+        mem_at_token0 = ggml_used_mem(ctx0);
+    } else if (mem_at_token1 == 0) {
+        mem_at_token1 = ggml_used_mem(ctx0);
     }
     //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+    //fprintf(stderr, "estimate/used_mem = %f\n", double(size_estimate) / ggml_used_mem(ctx0));

     ggml_free(ctx0);
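For context, the change above is a simple two-point linear model: the first warm-up eval yields a base reading (mem_at_token0), the second yields base plus growth (mem_at_token1), and later evals extrapolate from the delta before deciding whether to regrow the scratch buffer. The sketch below only illustrates that arithmetic and is not part of the commit; the byte values and n_past are made up, and the exact relation of the delta to the batch size N is still an open TODO in the hunk above.

    // standalone sketch of the two-point memory estimate (illustrative values only)
    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical ggml_used_mem() readings taken after the two warm-up evals
        const size_t mem_at_token0 =  96u * 1024 * 1024; // after eval #1 (n_past = 0)
        const size_t mem_at_token1 = 100u * 1024 * 1024; // after eval #2 (n_past = n_batch)

        const size_t  C0 = mem_at_token0;                                    // ~base
        const int64_t C1 = (int64_t)mem_at_token1 - (int64_t)mem_at_token0;  // delta 0,1

        // extrapolate for a later eval, mirroring size_estimate = C0 + C1 * n_past
        const int    n_past        = 512;
        const size_t size_estimate = C0 + (size_t)(C1 * n_past);

        // once the estimate exceeds the current buffer, the buffer grows by 10%
        const size_t buf_size_new = (size_t)(1.1 * size_estimate);

        printf("size_estimate = %zu bytes -> grow buffer to %zu bytes\n",
               size_estimate, buf_size_new);
        return 0;
    }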
main.cpp (10 lines changed)
@@ -216,10 +216,16 @@ int main(int argc, char ** argv) {
     }

     // determine the required inference memory per token:
+    // (fill in mem_at_token0 and mem_at_token1)
     // TODO: better way to do that
+    // TODO(Green-Sky): move to internal and detect first time usage
     {
-        const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+        // we make 2 evals, of batchsize to take 2 measurements, to determine base and growth
+        std::vector<llama_token> tmp(params.n_batch*2, 2);
+        tmp[0] = llama_token_bos();
+
+        llama_eval(ctx, tmp.data(), params.n_batch, 0, params.n_threads);
+        llama_eval(ctx, tmp.data()+params.n_batch, params.n_batch, params.n_batch, params.n_threads);
     }

     if (params.perplexity) {
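This warm-up is what feeds the estimate in llama_eval_internal: two back-to-back evals of n_batch dummy tokens, the first at n_past = 0 (recorded as mem_at_token0) and the second at n_past = n_batch (recorded as mem_at_token1). A minimal sketch of the same pattern follows, wrapped in a hypothetical helper for readability; warm_up_memory_estimate is not part of the commit, while the llama_eval and llama_token_bos calls are exactly those used above.

    #include <vector>
    #include "llama.h"

    // hypothetical helper: runs the two measurement evals right after context creation
    static void warm_up_memory_estimate(llama_context * ctx, int n_batch, int n_threads) {
        // two batches of filler tokens; token id 2 is an arbitrary placeholder, as in the commit
        std::vector<llama_token> tmp(n_batch*2, 2);
        tmp[0] = llama_token_bos();

        // eval #1 at n_past = 0       -> fills mem_at_token0 (base)
        llama_eval(ctx, tmp.data(),           n_batch, 0,       n_threads);
        // eval #2 at n_past = n_batch -> fills mem_at_token1 (base + growth)
        llama_eval(ctx, tmp.data() + n_batch, n_batch, n_batch, n_threads);
    }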