fix kv_self gradients for training
use ggml_set instead of ggml_cpy to set kv_self cache with properly propagating gradients
This commit is contained in:
parent
47561de7d8
commit
956511b248
1 changed files with 23 additions and 13 deletions
|
@ -330,6 +330,9 @@ struct ggml_tensor * forward(
|
|||
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens));
|
||||
|
||||
struct ggml_tensor * kc = kv_self.k;
|
||||
struct ggml_tensor * vc = kv_self.v;
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
@ -365,20 +368,27 @@ struct ggml_tensor * forward(
|
|||
// compute the transposed [N, n_embd] V matrix
|
||||
// wv shape [n_embd, n_embd, 1, 1]
|
||||
// Vcur shape [n_embd, N, 1, 1]
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N)));
|
||||
|
||||
// kv_self.k shape [n_embd * n_ctx * n_layer, 1]
|
||||
// kv_self.v shape [n_embd * n_ctx * n_layer, 1]
|
||||
// k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0]
|
||||
// v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0]
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||
|
||||
/* {
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
} //*/
|
||||
|
||||
kc = ggml_set_1d(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
vc = ggml_set_2d(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||
}
|
||||
|
||||
// Qcur shape [n_embd/n_head, n_head, N, 1]
|
||||
|
@ -393,7 +403,7 @@ struct ggml_tensor * forward(
|
|||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||
ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
|
@ -420,11 +430,11 @@ struct ggml_tensor * forward(
|
|||
//// V shape [n_past + N, n_embd/n_head, n_head, 1]
|
||||
// V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1]
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
ggml_view_3d(ctx0, vc,
|
||||
n_past + N, n_embd/n_head, n_head,
|
||||
n_ctx*ggml_element_size(kv_self.v),
|
||||
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
|
||||
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
|
||||
n_ctx*ggml_element_size(vc),
|
||||
n_ctx*ggml_element_size(vc)*n_embd/n_head,
|
||||
il*n_ctx*ggml_element_size(vc)*n_embd);
|
||||
|
||||
// KQV shape [n_embd/n_head, N, n_head, 1]
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue