From 0084a2a8d7661da50daabad1bb7172a126ea887e Mon Sep 17 00:00:00 2001 From: joshcarp Date: Mon, 29 Apr 2024 20:00:44 -0400 Subject: [PATCH] Checkpoint --- llama.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index f27d21808..93618bf85 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10828,10 +10828,23 @@ struct llm_build_context { return attn_output, attn_weights, past_key_value * */ + // 4 == num groups + int64_t nev[GGML_MAX_DIMS] = {2*Vcur->ne[0], Vcur->ne[1], Vcur->ne[2], Vcur->ne[3]}; + struct ggml_tensor * Vcur2 = ggml_new_tensor(ctx0, Vcur->type, GGML_MAX_DIMS, nev); + Vcur2->op = GGML_OP_REPEAT; + Vcur2->grad = ggml_dup_tensor(ctx0, Vcur); + Vcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens); + + int64_t nek[GGML_MAX_DIMS] = {2*Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3]}; + struct ggml_tensor * Kcur2 = ggml_new_tensor(ctx0, Kcur->type, GGML_MAX_DIMS, nek); + Kcur2->op = GGML_OP_REPEAT; + Kcur2->grad = ggml_dup_tensor(ctx0, Vcur); + Kcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens); + cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens/2, kv_head, n_kv, 1.0f, cb, il); + Kcur2, Vcur2, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { // skip computing output for unused tokens