Fix KV cache size for F32
parent 0b4e849a24
commit d0f7519338
1 changed file with 11 additions and 10 deletions
llama.cpp (+11 −10)
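In words: kv_cache_init() used to take a precomputed byte budget (mem_bytes, fed from the MEM_REQ_KV_SELF table at the call site) and resize the cache buffer to exactly that amount. The scale factor added below implies those table values are sized for F16 storage, so an F32 cache, whose elements are twice as wide, got a buffer half as large as it needed. The fix computes the buffer size inside kv_cache_init() from the actual element type (wtype) and doubles the reported per-state memory when memory_type is GGML_TYPE_F32.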
@@ -230,11 +230,16 @@ struct llama_context {
 
 static bool kv_cache_init(
         const struct llama_hparams & hparams,
-                        const size_t mem_bytes,
              struct llama_kv_cache & cache,
                           ggml_type   wtype,
                                 int   n_ctx) {
-    cache.buf.resize(mem_bytes);
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int n_mem      = n_layer*n_ctx;
+    const int n_elements = n_embd*n_mem;
+
+    cache.buf.resize(2*n_elements*ggml_type_size(wtype) + 2u*MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size();
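The new sizing arithmetic is easy to check in isolation. Below is a minimal standalone sketch of the same formula; the kv_buf_size helper, the type_size stand-in, and the example hyperparameters are illustrative assumptions, not part of the commit:

    #include <cstddef>
    #include <cstdio>

    // Illustrative stand-ins for the ggml types used in the diff above.
    enum ggml_type { GGML_TYPE_F16, GGML_TYPE_F32 };

    static size_t type_size(ggml_type t) {
        return t == GGML_TYPE_F32 ? 4 : 2; // bytes per element
    }

    static const size_t MB = 1024*1024;

    // Same formula as the new cache.buf.resize() call: two tensors (K and V),
    // n_layer*n_ctx cached positions, n_embd values per position, plus 2 MB
    // of slack for ggml context overhead.
    static size_t kv_buf_size(int n_embd, int n_layer, int n_ctx, ggml_type wtype) {
        const size_t n_mem      = (size_t) n_layer*n_ctx;
        const size_t n_elements = (size_t) n_embd*n_mem;
        return 2*n_elements*type_size(wtype) + 2u*MB;
    }

    int main() {
        // Assumed 7B-class settings: n_embd = 4096, n_layer = 32, n_ctx = 512.
        std::printf("f16: %zu MB\n", kv_buf_size(4096, 32, 512, GGML_TYPE_F16)/MB); // -> 258
        std::printf("f32: %zu MB\n", kv_buf_size(4096, 32, 512, GGML_TYPE_F32)/MB); // -> 514
    }

The next hunk is the matching code motion: the n_embd/n_layer/n_mem/n_elements declarations move up so the size is known before the buffer is allocated, while n_elements stays in scope for the later ggml_new_tensor_1d calls.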
@@ -247,12 +252,6 @@ static bool kv_cache_init(
         return false;
     }
 
-    const int n_embd  = hparams.n_embd;
-    const int n_layer = hparams.n_layer;
-
-    const int n_mem      = n_layer*n_ctx;
-    const int n_elements = n_embd*n_mem;
-
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
@@ -485,6 +484,8 @@ static bool llama_model_load(
 
     // print memory requirements
     {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
         // this is the total memory required to run the inference
         const size_t mem_required =
             ctx_size +
@@ -494,7 +495,7 @@ static bool llama_model_load(
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
-            MEM_REQ_KV_SELF.at (model.type);
+            scale*MEM_REQ_KV_SELF.at(model.type);
 
         fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
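The scale factor is needed only for reporting: MEM_REQ_KV_SELF apparently stores per-model KV cache sizes measured for F16 elements, and an F32 cache holds the same element count at twice the bytes. A minimal sketch of the reporting logic, with a placeholder table (the real entries and the key type live in llama.cpp):

    #include <cstddef>
    #include <cstdio>
    #include <map>
    #include <string>

    // Placeholder for llama.cpp's MEM_REQ_KV_SELF table (the real keys are a
    // model-size enum; the values here are illustrative, not the real ones).
    static const std::map<std::string, size_t> MEM_REQ_KV_SELF = {
        { "7B", 256u*1024*1024 },
    };

    int main() {
        const bool f32_cache = true;              // i.e. memory_type == GGML_TYPE_F32
        const size_t scale   = f32_cache ? 2 : 1; // F32 elements are twice as wide as F16
        const size_t mem_required_state = scale*MEM_REQ_KV_SELF.at("7B");
        std::printf("mem per state = %7.2f MB\n", mem_required_state / 1024.0 / 1024.0);
    }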
@@ -1634,7 +1635,7 @@ struct llama_context * llama_init_from_file(
 
     // reserve memory for context buffers
     {
-        if (!kv_cache_init(ctx->model.hparams, MEM_REQ_KV_SELF.at(ctx->model.type), ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
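With the buffer sized from wtype inside kv_cache_init(), the caller no longer supplies a byte budget, so the MEM_REQ_KV_SELF.at(...) argument drops out of the call site. Before the change, requesting an F32 cache presumably meant the K and V tensors were allocated from an arena sized for F16 data; after it, the same call sizes the arena correctly for either element type.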