Checkpoint

This commit is contained in:
joshcarp 2024-04-29 20:00:44 -04:00
parent 7c3c3eb256
commit 0084a2a8d7

View file

@ -10828,10 +10828,23 @@ struct llm_build_context {
return attn_output, attn_weights, past_key_value
*
*/
// 4 == num groups
int64_t nev[GGML_MAX_DIMS] = {2*Vcur->ne[0], Vcur->ne[1], Vcur->ne[2], Vcur->ne[3]};
struct ggml_tensor * Vcur2 = ggml_new_tensor(ctx0, Vcur->type, GGML_MAX_DIMS, nev);
Vcur2->op = GGML_OP_REPEAT;
Vcur2->grad = ggml_dup_tensor(ctx0, Vcur);
Vcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens);
int64_t nek[GGML_MAX_DIMS] = {2*Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3]};
struct ggml_tensor * Kcur2 = ggml_new_tensor(ctx0, Kcur->type, GGML_MAX_DIMS, nek);
Kcur2->op = GGML_OP_REPEAT;
Kcur2->grad = ggml_dup_tensor(ctx0, Vcur);
Kcur2 = ggml_reshape_2d(ctx0, Vcur2, modified_hparams.n_embd_k_gqa(), n_tokens);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens/2, kv_head, n_kv, 1.0f, cb, il);
Kcur2, Vcur2, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens