llama : simplify llama_build_kv_store
ggml-ci
parent 9ca869876e
commit 74d57f9513

1 changed file with 14 additions and 20 deletions
llama.cpp (34 changed lines)
@@ -5963,29 +5963,27 @@ static void llm_build_kv_store(
             (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
-    // important: storing RoPE-ed version of K in the KV cache!
+    // note: storing RoPE-ed version of K in the KV cache
     ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
 
-    if (cparams.flash_attn) {
-        // NOTE: the V cache is not transposed when using FLASH attention !!
-        struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
-                (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
-        cb(v_cache_view, "v_cache_view", il);
-
-        ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
-    } else {
-        // compute the transposed [n_tokens, n_embd] V matrix
-        //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
-        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur);
-        cb(v_cur_t, "v_cur_t", il);
 
-        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
+    struct ggml_tensor * v_cache_view = nullptr;
+
+    if (cparams.flash_attn) {
+        v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
+                (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
                 (  n_ctx)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
 
-        ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+        v_cur = ggml_transpose(ctx, v_cur);
     }
+    cb(v_cache_view, "v_cache_view", il);
+
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
 static struct ggml_tensor * llm_build_norm(
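For reference, a minimal sketch of the offset arithmetic behind the two V-cache views above, assuming an unquantized F32 cache (so ggml_row_size(type, n) reduces to n*sizeof(float) and ggml_element_size to sizeof(float)); the sizes below are illustrative, not taken from any model:

// Sketch: byte offsets selected for the V cache in llm_build_kv_store,
// assuming an F32 (unquantized) cache. Illustrative sizes only.
#include <cstdio>
#include <cstddef>

int main() {
    const size_t elt          = sizeof(float); // ggml_element_size for F32
    const size_t n_embd_v_gqa = 4096;          // V values per token (illustrative)
    const size_t n_ctx        = 8192;          // KV cache size in cells
    const size_t kv_head      = 1000;          // first cell written by this batch
    const size_t n_tokens     = 32;            // tokens in the current batch

    // flash attention: V stored in its natural [n_embd_v_gqa, n_ctx] layout,
    // so the batch is one contiguous 1d slice starting at cell kv_head
    const size_t fa_offset = kv_head * n_embd_v_gqa * elt; // (kv_head)*ggml_row_size(...)
    const size_t fa_bytes  = n_tokens * n_embd_v_gqa * elt;

    // no flash attention: V stored transposed as [n_ctx, n_embd_v_gqa], so the
    // batch is a strided 2d view: row stride n_ctx*elt, starting column kv_head
    const size_t t_offset = kv_head * elt; // (kv_head)*ggml_element_size(...)
    const size_t t_stride = n_ctx   * elt; // (  n_ctx)*ggml_element_size(...)

    std::printf("flash-attn view: offset=%zu bytes, size=%zu bytes (contiguous)\n",
            fa_offset, fa_bytes);
    std::printf("transposed view: offset=%zu bytes, row stride=%zu bytes, %zu x %zu elements\n",
            t_offset, t_stride, n_tokens, n_embd_v_gqa);
    return 0;
}

The transposed layout is what the non-flash-attention KQV matmul consumes directly, while ggml_flash_attn_ext takes V in its natural layout, which is why v_cur is transposed only in the else branch before the copy.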
@@ -6169,11 +6167,6 @@ static struct ggml_tensor * llm_build_kqv(
         if (model.arch == LLM_ARCH_PHI2) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
         }
-        //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
-        //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
-        //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
-        //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
-        //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
     } else {
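The removed //printf lines passed ggml_tensor::ne values (which are int64_t) to a %4d format. If similar ad-hoc shape dumps are ever needed while debugging llm_build_kqv, a small helper along these lines avoids that format mismatch; this is a hypothetical sketch, not part of this commit:

// hypothetical debugging helper, not part of the patch: dump a tensor's shape
// (ggml_tensor::ne is int64_t, so cast explicitly for printf)
#include "ggml.h"
#include <cstdio>

static void print_shape(const char * name, const struct ggml_tensor * t) {
    std::printf("%s: %4lld %4lld %4lld %4lld\n", name,
            (long long) t->ne[0], (long long) t->ne[1],
            (long long) t->ne[2], (long long) t->ne[3]);
}

Called as, e.g., print_shape("q", q) next to the flash-attention branch, it prints the same dimensions the removed lines did.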
@@ -14879,6 +14872,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
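The new log line reports cparams.flash_attn, the flag that selects between the two KV-store paths above. A minimal usage sketch for enabling it at context creation, assuming the llama.h API on this branch and a placeholder model path:

// Sketch: turning on flash attention for a context; "model.gguf" is a placeholder path.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true; // selects the non-transposed V-cache path in llm_build_kv_store

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... run inference ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}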