Fix KV cache size for F32

commit d0f7519338
parent 0b4e849a24
Author: Georgi Gerganov
Date:   2023-03-24 22:58:00 +02:00
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)


@@ -230,11 +230,16 @@ struct llama_context {
 static bool kv_cache_init(
         const struct llama_hparams & hparams,
-                      const size_t   mem_bytes,
              struct llama_kv_cache & cache,
                          ggml_type   wtype,
                                int   n_ctx) {
-    cache.buf.resize(mem_bytes);
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int n_mem      = n_layer*n_ctx;
+    const int n_elements = n_embd*n_mem;
+
+    cache.buf.resize(2*n_elements*ggml_type_size(wtype) + 2u*MB);
 
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size();
     params.mem_buffer = cache.buf.data();
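
The new formula sizes the buffer for the K and the V tensor together: each holds n_embd*n_layer*n_ctx elements, so the cache needs 2*n_elements elements times the per-element size of the cache type, plus 2 MB of slack (presumably headroom for ggml's bookkeeping). A standalone sketch of the arithmetic, assuming illustrative 7B-class hyperparameters (n_embd = 4096, n_layer = 32) and n_ctx = 512; none of these values are taken from this diff:

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int64_t MB = 1024*1024;

        const int64_t n_embd  = 4096; // assumed, 7B-class value
        const int64_t n_layer = 32;   // assumed, 7B-class value
        const int64_t n_ctx   = 512;  // assumed default context length

        const int64_t n_mem      = n_layer*n_ctx; // 16384
        const int64_t n_elements = n_embd*n_mem;  // 67108864

        // one K and one V tensor, so 2*n_elements, times bytes per element
        const int64_t buf_f16 = 2*n_elements*2 + 2*MB; // ggml_type_size(GGML_TYPE_F16) == 2
        const int64_t buf_f32 = 2*n_elements*4 + 2*MB; // ggml_type_size(GGML_TYPE_F32) == 4

        printf("F16 KV buffer: %.2f MB\n", buf_f16/1024.0/1024.0); // 258.00 MB
        printf("F32 KV buffer: %.2f MB\n", buf_f32/1024.0/1024.0); // 514.00 MB
        return 0;
    }

Before this change, the F32 cache was allocated with a byte count sized for F16, hence the fix in the title.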
@@ -247,12 +252,6 @@ static bool kv_cache_init(
         return false;
     }
 
-    const int n_embd  = hparams.n_embd;
-    const int n_layer = hparams.n_layer;
-
-    const int n_mem      = n_layer*n_ctx;
-    const int n_elements = n_embd*n_mem;
-
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
@@ -485,6 +484,8 @@ static bool llama_model_load(
 
     // print memory requirements
     {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
         // this is the total memory required to run the inference
         const size_t mem_required =
                 ctx_size +
@@ -494,7 +495,7 @@ static bool llama_model_load(
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
-            MEM_REQ_KV_SELF.at (model.type);
+            scale*MEM_REQ_KV_SELF.at(model.type);
 
         fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
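
The scale factor works because ggml stores F32 in 4 bytes per element and F16 in 2, so an F32 cache is exactly twice the size of an F16 one; the MEM_REQ_KV_SELF table is evidently tabulated for the F16 case. A minimal check of the reporting logic (the 256 MB entry is an assumed illustration, not read from the real table):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t kv_self_mb = 256; // assumed F16-sized table entry, in MB
        for (int f32 = 0; f32 <= 1; ++f32) {
            const size_t scale = f32 ? 2 : 1; // same rule as the diff
            printf("%s: %zu MB per state\n", f32 ? "F32" : "F16", scale*kv_self_mb);
        }
        return 0;
    }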
@@ -1634,7 +1635,7 @@ struct llama_context * llama_init_from_file(
 
     // reserve memory for context buffers
     {
-        if (!kv_cache_init(ctx->model.hparams, MEM_REQ_KV_SELF.at(ctx->model.type), ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
            llama_free(ctx);
            return nullptr;
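
With the mem_bytes parameter gone, the caller no longer needs to know how large the cache is: kv_cache_init derives the byte count from hparams and the cache type, so the size formula lives in one place. The call shape before and after, restated from the hunk above (all names come from the diff itself):

    // before: caller passed in a byte count sized for F16
    kv_cache_init(ctx->model.hparams, MEM_REQ_KV_SELF.at(ctx->model.type),
                  ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx);

    // after: size computed inside kv_cache_init from hparams and wtype
    kv_cache_init(ctx->model.hparams, ctx->model.kv_self,
                  memory_type, ctx->model.hparams.n_ctx);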