Removed unnecessary tensor operations.

This commit is contained in:
Stanisław Szymczyk 2024-05-18 11:38:07 +02:00
parent b24c9ed551
commit 039896407a

View file

@ -10939,13 +10939,9 @@ struct llm_build_context {
// split into {n_head * qk_nope_head_dim, n_tokens}
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
cb(q_nope, "q_nope", il);
// and {n_head * qk_rope_head_dim, n_tokens}
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, qk_rope_head_dim, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * qk_nope_head_dim);
q_nope = ggml_cont(ctx0, q_nope);
cb(q_nope, "q_nope", il);
q_pe = ggml_cont(ctx0, q_pe);
cb(q_pe, "q_pe", il);
// {n_embd, kv_lora_rank + qk_rope_head_dim} * {n_embd, n_tokens} -> {kv_lora_rank + qk_rope_head_dim, n_tokens}
@ -10954,10 +10950,9 @@ struct llm_build_context {
// split into {kv_lora_rank, n_tokens}
struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
cb(compressed_kv, "compressed_kv", il);
// and {qk_rope_head_dim, n_tokens}
struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, qk_rope_head_dim, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
k_pe = ggml_cont(ctx0, k_pe);
cb(k_pe, "k_pe", il);
compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
@ -10971,16 +10966,20 @@ struct llm_build_context {
// split into {n_head * qk_nope_head_dim, n_tokens}
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, qk_nope_head_dim, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), 0);
cb(k_nope, "k_nope", il);
// and {n_head * n_embd_head_v, n_tokens}
struct ggml_tensor * value_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (qk_nope_head_dim + hparams.n_embd_head_v), ggml_element_size(kv) * qk_nope_head_dim);
value_states = ggml_dup(ctx0, value_states);
cb(value_states, "value_states", il);
value_states = ggml_reshape_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens);
value_states = ggml_cont(ctx0, value_states);
cb(value_states, "value_states", il);
value_states = ggml_view_2d(ctx0, value_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
cb(value_states, "value_states", il);
q_pe = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, q_pe, qk_rope_head_dim, n_head, n_tokens), inp_pos,
ctx0, q_pe, inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -10988,7 +10987,7 @@ struct llm_build_context {
// shared RoPE key
k_pe = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens), inp_pos,
ctx0, ggml_view_3d(ctx0, k_pe, qk_rope_head_dim, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
@ -11007,20 +11006,6 @@ struct llm_build_context {
key_states = ggml_set_inplace(ctx0, key_states, k_nope, key_states->nb[1], key_states->nb[2], key_states->nb[3], 0);
key_states = ggml_set_inplace(ctx0, key_states, k_pe, key_states->nb[1], key_states->nb[2], key_states->nb[3], ggml_element_size(key_states) * qk_nope_head_dim);
// TODO see if we can avoid these operations by permuting
// rows/columns of some model tensors during model conversion
query_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, query_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens)));
cb(query_states, "query_states", il);
query_states = ggml_reshape_3d(ctx0, query_states, hparams.n_embd_head_k, n_head, n_tokens);
cb(query_states, "query_states", il);
key_states = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, key_states, 2, hparams.n_embd_head_k / 2, n_head, n_tokens)));
cb(key_states, "key_states", il);
key_states = ggml_reshape_3d(ctx0, key_states, hparams.n_embd_head_k, n_head, n_tokens);
cb(key_states, "key_states", il);
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
key_states, value_states, query_states, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(hparams.n_embd_head_k)), cb, il);