Add comment

parent 5ba2143c3c
commit 7c3c3eb256

1 changed file with 49 additions and 10 deletions:

    llama.cpp | 59 (+49, -10)
@@ -10734,7 +10734,7 @@ struct llm_build_context {
                 const int64_t n_head = n_head_kv + num_query_heads[il];
                 const int64_t n_kv = (num_kv_heads[il] + num_kv_heads[il])*n_embd_head;
                 modified_hparams.n_head = n_head;
-                modified_hparams.n_head_kv = n_head_kv;
+                modified_hparams.n_head_kv = n_head_kv; // TODO: testing out setting this to the total number of heads
                 const int64_t n_embd_gqa = n_embd_head * n_head;
                 const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa();
                 const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa();
@@ -10751,12 +10751,15 @@ struct llm_build_context {
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;

-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); // model.layers[il].wqkv -> might not be all 3 of q/k/v
                 cb(cur, "wqkv", il);
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_k_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_v_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_k_gqa)));
+                // model.layers[il].wqkv has dimensionality of
+                // [model_dim][(n_head_k + n_head_v + n_head_q)*head_dim]
+                // In most other impls, this is [model_dim][3*model_dim]
+                // This matches up with the dimensions of the huggingface version
+                Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, num_query_heads[il], cur->nb[1], cur->nb[2], 0 * sizeof(float) * (n_embd_head)));
+                Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, n_head_k, cur->nb[1], cur->nb[2], 1 * sizeof(float) * (n_embd_head)));
+                Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_tokens, n_head_k, cur->nb[1], cur->nb[2], 1 * sizeof(float) * (n_embd_head + n_embd_head)));
                 // Q/K Layernorm
                 Qcur = llm_build_norm(ctx0, Qcur, modified_hparams,
                         model.layers[il].attn_q_norm,
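Side note on the slicing above: the added comments describe a fused wqkv output whose rows pack the heads as [q_heads | k_heads | v_heads]*head_dim rather than the usual [model_dim][3*model_dim]. Below is a minimal sketch (not from the commit) of how such a row could be split with plain ggml_view_2d, assuming cur is a contiguous F32 tensor of shape [(n_q + n_k + n_v)*head_dim, n_tokens]; the helper name slice_fused_qkv and its signature are invented for illustration.

    #include "ggml.h"

    // Hypothetical helper (illustrative only): split a fused QKV projection whose
    // rows are laid out as [q_heads | k_heads | v_heads], each head contributing
    // head_dim floats, into three strided 2-D views.
    static void slice_fused_qkv(
            struct ggml_context * ctx,
            struct ggml_tensor  * cur,   // F32, shape [(n_q + n_k + n_v)*head_dim, n_tokens]
            int64_t n_q, int64_t n_k, int64_t n_v, int64_t head_dim, int64_t n_tokens,
            struct ggml_tensor ** q, struct ggml_tensor ** k, struct ggml_tensor ** v) {
        const size_t row = cur->nb[1]; // byte stride between consecutive token rows

        // each slice starts (in bytes) where the previous group of heads ends
        *q = ggml_view_2d(ctx, cur, n_q*head_dim, n_tokens, row, 0);
        *k = ggml_view_2d(ctx, cur, n_k*head_dim, n_tokens, row, sizeof(float)* n_q        *head_dim);
        *v = ggml_view_2d(ctx, cur, n_v*head_dim, n_tokens, row, sizeof(float)*(n_q + n_k)*head_dim);
    }

With the first-layer numbers mentioned in the comments (12 query heads, 3 key heads, 3 value heads, head_dim 64), the K slice would start 12*64 floats into each row and the V slice (12 + 3)*64 floats in.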
@@ -10771,9 +10774,10 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);

                 cb(Vcur, "Vcur", il);
-                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                // reshape, Qcur -> [64][12 (first layer)][n_tokens]
+                // reshape, Kcur -> [64][3 (first layer)][n_tokens]
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, num_query_heads[il], n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_k, n_tokens);

                 struct ggml_tensor * KQ_mask = build_inp_KQ_mask2(n_kv);

                 Qcur = ggml_rope_custom(
@@ -10789,10 +10793,45 @@ struct llm_build_context {
                         ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
                         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );

+                // So because our original wo matrix wasn't 3x, the function below fails because there aren't enough elements in it.
+                // Got:  [head_dim][n_tokens][n_head_v]
+                // Want: [n_embd_v_gqa(384)][n_tokens]
+                // I guess this means that I need to be able to repeat them.
+                // Assertion failed: (v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens), function llm_build_kv_store, file llama.cpp, line 6309.
+                // In the Python version it does this:
+                /*
+                if self.num_groups != 1:
+                    # GQA
+                    # [B, k_h, S, h] --> [B, q_h, S, h]   // so, k=3 -> q=12
+                    keys = keys.repeat_interleave(self.num_groups, dim=1)
+                    # [B, v_h, S, h] --> [B, q_h, S, h]   // so, v=3 -> q=12
+                    values = values.repeat_interleave(self.num_groups, dim=1)
+
+                ...
+
+                attn_output = F.scaled_dot_product_attention(
+                    queries,
+                    keys,
+                    values,
+                    attn_mask=causal_mask,
+                    dropout_p=0,
+                )
+
+                attn_output = attn_output.transpose(1, 2).contiguous()
+                attn_output = attn_output.reshape(
+                    batch_size, seq_length, self.num_q_heads * self.head_dim
+                )
+                attn_output = self.out_proj(attn_output)
+                if not output_attentions:
+                    attn_weights = None
+                return attn_output, attn_weights, past_key_value
+                *
+                */
                 cb(Kcur, "Kcur", il);
                 cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens/2, kv_head, n_kv, 1.0f, cb, il);
             }

             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
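The Python block quoted in the comment repeats each key/value head self.num_groups times so that K and V end up with as many heads as Q before scaled-dot-product attention. Below is a hedged ggml sketch of that repeat_interleave step (not from the commit; repeat_kv_heads and its signature are invented for illustration), assuming a contiguous F32 tensor of shape [head_dim, n_tokens, n_head_kv].

    #include "ggml.h"

    // Hypothetical helper (illustrative only): ggml equivalent of
    // keys.repeat_interleave(n_groups, dim=1) over the head dimension.
    // Input  kv: F32, contiguous, shape [head_dim, n_tokens, n_head_kv]
    // Output   : shape [head_dim, n_tokens, n_head_kv * n_groups], with the
    //            copies of each head adjacent (kv0, kv0, ..., kv1, kv1, ...)
    static struct ggml_tensor * repeat_kv_heads(
            struct ggml_context * ctx, struct ggml_tensor * kv, int64_t n_groups) {
        const int64_t head_dim  = kv->ne[0];
        const int64_t n_tokens  = kv->ne[1];
        const int64_t n_head_kv = kv->ne[2];

        // give every kv head its own slot in the outermost dimension
        struct ggml_tensor * t = ggml_reshape_4d(ctx, kv, head_dim, n_tokens, 1, n_head_kv);

        // tile each head n_groups times along dim 2 (the new tensor is only a shape template)
        struct ggml_tensor * shape = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,
                head_dim, n_tokens, n_groups, n_head_kv);
        t = ggml_repeat(ctx, t, shape);

        // collapse back to 3-D: the copies of head i stay next to each other,
        // which matches repeat_interleave along the head dimension
        return ggml_reshape_3d(ctx, t, head_dim, n_tokens, n_head_kv * n_groups);
    }

With the first-layer figures from the comments (k=3 -> q=12, so n_groups = 4), this would turn a [64, n_tokens, 3] K or V tensor into [64, n_tokens, 12]. Per the assertion quoted above, llm_build_kv_store still expects v_cur as a 2-D [n_embd_v_gqa, n_tokens] tensor, so a repeated V would have to be flattened back to 2-D (e.g. with ggml_reshape_2d) before that call.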