Simpler scratch buffer usage

Author: Georgi Gerganov
Date:   2023-03-24 21:41:47 +02:00
commit ea60d2193a (parent 9330ff0f35)

llama.cpp (104 lines changed)

@@ -48,45 +48,17 @@ static const size_t MB = 1024*1024;
 // needs modifications in ggml
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_7B,  128ull*MB },
-    { MODEL_13B, 128ull*MB },
-    { MODEL_30B, 128ull*MB },
-    { MODEL_65B, 128ull*MB },
+    { MODEL_7B,  512ull*MB },
+    { MODEL_13B, 512ull*MB },
+    { MODEL_30B, 512ull*MB },
+    { MODEL_65B, 512ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-    { MODEL_7B,  128ull*MB },
-    { MODEL_13B, 128ull*MB },
-    { MODEL_30B, 128ull*MB },
-    { MODEL_65B, 128ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-    { MODEL_7B,  32ull*MB },
-    { MODEL_13B, 32ull*MB },
-    { MODEL_30B, 32ull*MB },
-    { MODEL_65B, 32ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-    { MODEL_7B,  32ull*MB },
-    { MODEL_13B, 32ull*MB },
-    { MODEL_30B, 32ull*MB },
-    { MODEL_65B, 32ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH4 = {
-    { MODEL_7B,  128ull*MB },
-    { MODEL_13B, 128ull*MB },
-    { MODEL_30B, 128ull*MB },
-    { MODEL_65B, 128ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH5 = {
-    { MODEL_7B,  4ull*MB },
-    { MODEL_13B, 4ull*MB },
-    { MODEL_30B, 4ull*MB },
-    { MODEL_65B, 4ull*MB },
+    { MODEL_7B,  512ull*MB },
+    { MODEL_13B, 512ull*MB },
+    { MODEL_30B, 512ull*MB },
+    { MODEL_65B, 512ull*MB },
 };
 
 // 2*n_embd*n_ctx*n_layer*sizeof(float16)
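After this change only two scratch regions remain, and both are sized at a generous 512 MB for every model instead of six differently sized buffers. For orientation: a ggml scratch buffer is nothing more than a caller-owned block of memory that ggml places intermediate tensor data into; it is described by the small ggml_scratch struct and installed with ggml_set_scratch(). The struct is roughly (as in the ggml.h of this period; field comments added here):

    struct ggml_scratch {
        size_t offs; // current write offset inside the region
        size_t size; // total size of the caller-owned region
        void * data; // the caller-owned region itself, e.g. one of the 512 MB buffers above
    };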
@@ -98,10 +70,10 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
 };
 
 static const std::map<e_model, size_t> MEM_REQ_EVAL = {
-    { MODEL_7B,  256ull*MB },
-    { MODEL_13B, 256ull*MB },
-    { MODEL_30B, 256ull*MB },
-    { MODEL_65B, 256ull*MB },
+    { MODEL_7B,  128ull*MB },
+    { MODEL_13B, 128ull*MB },
+    { MODEL_30B, 128ull*MB },
+    { MODEL_65B, 128ull*MB },
 };
 
 // default hparams (LLaMA 7B)
@@ -516,10 +488,6 @@ static bool llama_model_load(
         ctx_size +
         MEM_REQ_SCRATCH0.at(model.type) +
         MEM_REQ_SCRATCH1.at(model.type) +
-        MEM_REQ_SCRATCH2.at(model.type) +
-        MEM_REQ_SCRATCH3.at(model.type) +
-        MEM_REQ_SCRATCH4.at(model.type) +
-        MEM_REQ_SCRATCH5.at(model.type) +
         MEM_REQ_EVAL.at    (model.type);
 
     // this is the memory required by one llama_state
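In terms of these constants alone, the per-context reservation changes from 128 + 128 + 32 + 32 + 128 + 4 = 452 MB of scratch plus 256 MB for MEM_REQ_EVAL (about 708 MB) to 512 + 512 = 1024 MB of scratch plus 128 MB (about 1152 MB), on top of ctx_size and the KV cache in both cases. The up-front reservation grows, but the accounting is reduced to two buffers sized for the worst case.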
@@ -864,8 +832,6 @@ static bool llama_eval_internal(
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
-    lctx.use_buf(ctx0, 3);
-
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
 
     for (int il = 0; il < n_layer; ++il) {
@@ -873,10 +839,10 @@
         struct ggml_tensor * cur;
 
+        lctx.use_buf(ctx0, 0);
+
         // norm
         {
-            lctx.use_buf(ctx0, 0);
-
             cur = ggml_rms_norm(ctx0, inpL);
 
             // cur = attention_norm*cur
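lctx.use_buf() is defined elsewhere in llama.cpp and is not part of this diff. Judging from how it is called here and from get_buf_max_mem() in the debug print further down, it selects one of the pre-allocated scratch buffers via ggml_set_scratch() and tracks how much of the previously selected buffer was actually used. A sketch of that presumed behaviour follows; the scratch_state helper struct and the assumption that ggml_set_scratch() returns the number of bytes used in the previously installed region are illustrations, not the actual llama.cpp code.

    // Sketch of what use_buf(ctx, i) presumably does; scratch_state stands in for
    // the relevant llama_context members and is not the real type.
    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    struct scratch_state {
        std::vector<uint8_t> buf[2];     // the two pre-allocated 512 MB regions
        size_t max_used[2] = { 0, 0 };   // high-water marks, cf. get_buf_max_mem()
        int    last        = -1;         // buffer selected by the previous call
    };

    static void use_buf(scratch_state & s, struct ggml_context * ctx, int i) {
        size_t last_size = 0;
        if (i == -1) {
            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr }); // switch scratch off
        } else {
            last_size = ggml_set_scratch(ctx, { 0, s.buf[i].size(), s.buf[i].data() });
        }
        if (s.last >= 0) {
            // assumed: the return value is how much of the previous region was used
            s.max_used[s.last] = std::max(s.max_used[s.last], last_size);
        }
        s.last = i;
    }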
@@ -887,9 +853,6 @@
         // self-attention
         {
-            // needed due to ggml_rope creating a "parameters" tensor
-            lctx.use_buf(ctx0, 4);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
@@ -926,8 +889,6 @@
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            lctx.use_buf(ctx0, 1);
-
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             struct ggml_tensor * KQ_scaled =
                 ggml_scale(ctx0,
@@ -935,16 +896,12 @@
                     ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
                     );
 
-            lctx.use_buf(ctx0, 0);
-
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            lctx.use_buf(ctx0, 1);
-
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor * V_trans =
                 ggml_cpy(ctx0,
@@ -958,8 +915,6 @@
             // KQV = transpose(V) * KQ_soft_max
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            lctx.use_buf(ctx0, 0);
-
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -968,15 +923,13 @@
                     KQV_merged,
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
 
-            lctx.use_buf(ctx0, 1);
-
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].wo,
                     cur);
         }
 
-        lctx.use_buf(ctx0, 2);
+        lctx.use_buf(ctx0, 1);
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
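The pattern this commit converges on: rather than assigning a specific buffer to each operation, the graph flips between scratch 0 and scratch 1 at a few coarse points (start of the layer, after the attention projection, inside the feed-forward block), so that an operation's inputs, allocated while the other buffer was active, live in a different region than its own output. A small self-contained sketch of that ping-pong with the ggml API of this era (ggml_init, ggml_set_scratch, ggml_mul_mat); the sizes and tensors are arbitrary and purely illustrative:

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "ggml.h"

    int main() {
        std::vector<uint8_t> compute (16u*1024*1024); // holds tensor metadata + non-scratch data
        std::vector<uint8_t> scratch0(16u*1024*1024);
        std::vector<uint8_t> scratch1(16u*1024*1024);

        struct ggml_init_params params = { compute.size(), compute.data() };
        struct ggml_context * ctx = ggml_init(params);

        // tensors created before any scratch is installed ("weights") use the main buffer
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

        // the intermediate c is allocated out of scratch 0 ...
        ggml_set_scratch(ctx, { 0, scratch0.size(), scratch0.data() });
        struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

        // ... and d out of scratch 1, so its input c keeps its own region
        ggml_set_scratch(ctx, { 0, scratch1.size(), scratch1.data() });
        struct ggml_tensor * d = ggml_mul_mat(ctx, a, c);
        (void) d;

        printf("main buffer used: %zu bytes\n", ggml_used_mem(ctx));

        ggml_free(ctx);
        return 0;
    }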
@@ -984,20 +937,14 @@
         {
             // norm
             {
-                lctx.use_buf(ctx0, 0);
-
                 cur = ggml_rms_norm(ctx0, inpFF);
 
-                lctx.use_buf(ctx0, 1);
-
                 // cur = ffn_norm*cur
                 cur = ggml_mul(ctx0,
                         ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                         cur);
             }
 
-            lctx.use_buf(ctx0, 0);
-
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,
                     cur);
@@ -1009,8 +956,6 @@
             // SILU activation
             cur = ggml_silu(ctx0, cur);
 
-            lctx.use_buf(ctx0, 1);
-
             cur = ggml_mul(ctx0, cur, tmp);
 
             cur = ggml_mul_mat(ctx0,
@@ -1018,25 +963,22 @@
                     cur);
         }
 
-        lctx.use_buf(ctx0, 3);
-
         cur = ggml_add(ctx0, cur, inpFF);
 
         // input for next layer
         inpL = cur;
     }
 
+    lctx.use_buf(ctx0, 0);
+
     // used at the end to optionally extract the embeddings
     struct ggml_tensor * embeddings = NULL;
 
     // norm
     {
-        lctx.use_buf(ctx0, 0);
-
         inpL = ggml_rms_norm(ctx0, inpL);
 
-        lctx.use_buf(ctx0, 1);
-
         // inpL = norm*inpL
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm, inpL),
@@ -1045,8 +987,6 @@ static bool llama_eval_internal(
         embeddings = inpL;
     }
 
-    lctx.use_buf(ctx0, 0);
-
     // lm_head
     inpL = ggml_mul_mat(ctx0, model.output, inpL);
@@ -1097,11 +1037,7 @@ static bool llama_eval_internal(
     printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB, %.3f MB %.3f MB %.3f %.3f %.3f MB\n", __func__,
             ggml_used_mem(ctx0)/1024.0/1024.0,
             lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.get_buf_max_mem(2)/1024.0/1024.0,
-            lctx.get_buf_max_mem(3)/1024.0/1024.0,
-            lctx.get_buf_max_mem(4)/1024.0/1024.0,
-            lctx.get_buf_max_mem(5)/1024.0/1024.0);
+            lctx.get_buf_max_mem(1)/1024.0/1024.0);
 #endif
 
     ggml_free(ctx0);
@@ -1722,10 +1658,6 @@ struct llama_context * llama_init_from_file(
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
-        ctx->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
-        ctx->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
-        ctx->buf_scratch[4].resize(MEM_REQ_SCRATCH4.at(ctx->model.type));
-        ctx->buf_scratch[5].resize(MEM_REQ_SCRATCH5.at(ctx->model.type));
     }
 
     return ctx;
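The two remaining buffers are allocated once at init time by the resize() calls above; scratch slots that are never resized cost nothing, assuming buf_scratch holds a plain allocate-on-resize byte buffer. A minimal sketch of such a type, in the spirit of (but not necessarily identical to) llama.cpp's llama_buffer:

    #include <cstddef>
    #include <cstdint>

    // Minimal allocate-on-resize buffer: memory is only claimed when resize() is called,
    // so scratch slots that are never resized stay empty.
    struct scratch_buffer {
        uint8_t * addr = nullptr;
        size_t    size = 0;

        void resize(size_t new_size) {
            delete[] addr;
            addr = new uint8_t[new_size];
            size = new_size;
        }

        ~scratch_buffer() { delete[] addr; }
    };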