reuse llm_build_kv

Eddie-Wang1120 2024-06-21 16:12:48 +08:00
parent c6ddfa7e37
commit 55a57a5063

llama.cpp

@@ -6701,9 +6701,6 @@ static bool llm_load_tensors(
                         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     }
 
-                    const uint32_t n_ff = hparams.n_ff;
-                    model.layers.resize(n_layer);
-
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
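Note: the two deleted statements are redundant rather than lost. As far as I can tell, llm_load_tensors already sets up the common dimensions and the per-layer storage once, ahead of the per-architecture switch, so the BitNet case can rely on them. A paraphrased sketch of that shared prologue (not the verbatim upstream code):

    // paraphrased sketch of the llm_load_tensors prologue that makes the
    // deleted lines redundant (names match upstream, layout abridged)
    const int64_t n_layer = hparams.n_layer;
    const int64_t n_ff    = hparams.n_ff;      // FFN width, shared by all cases

    model.layers.resize(n_layer);              // one resize covers every architecture

    switch (model.arch) {
        case LLM_ARCH_BITNET:
            // the per-layer loop below reads n_ff and writes model.layers[i] directly
            break;
        // ...
    }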
@@ -6714,23 +6711,23 @@ static bool llm_load_tensors(
                         layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
 
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                        layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
+                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                        layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
+                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
+                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});
+                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
 
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
                         layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
+                        layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
                     }
                 } break;
             default:
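Why the scales move from ctx_split to ctx_layer: ctx_for_layer_split(i) targets a buffer type that may split large 2-D weight matrices row-wise across devices, while ctx_for_layer(i) keeps a tensor whole; a one-element {1} quantization scale has no rows to distribute, so it belongs in the per-layer context. That rationale is my reading of the change, not a statement from the commit. At graph-build time each scale is consumed as a plain elementwise multiply right after its matmul, as the last hunk of this commit shows for wo/wo_scale:

    // sketch: how a {1}-element scale is applied, mirroring the
    // wo/wo_scale pattern visible in the last hunk of this commit
    struct ggml_tensor * t = ggml_mul_mat(ctx0, model.layers[il].wq, cur); // quantized weight matmul
    t = ggml_mul(ctx0, t, model.layers[il].wq_scale);                      // restore magnitude per tensor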
@@ -7373,7 +7370,10 @@ static struct ggml_tensor * llm_build_kqv(
 
     ggml_build_forward_expand(graph, cur);
 
-    cur = ggml_mul_mat(ctx, wo, cur);
+    if (wo) {
+        cur = ggml_mul_mat(ctx, wo, cur);
+    }
+
     if (wo_b) {
         cb(cur, "kqv_wo", il);
     }
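The new null check turns wo into an optional argument, and that is what lets BitNet reuse the shared helper below: BitNet must run its attn_sub_norm between the attention output and the wo projection, so it calls the helper with wo = nullptr and applies the projection (plus its quantization scale) itself. Roughly the two call patterns, with the first taken from how other architectures in this file call the helper:

    // typical architectures let the helper apply the output projection:
    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
            model.layers[il].wo, model.layers[il].bo,
            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);

    // BitNet (next hunk) passes wo = nullptr, then does
    // sub-norm -> wo matmul -> wo_scale multiply on its own:
    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
            nullptr, model.layers[il].bo,
            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);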
@@ -11835,82 +11835,18 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
-
-                const int64_t n_ctx         = cparams.n_ctx;
-                const int64_t n_head        = hparams.n_head;
-                const int64_t n_head_kv     = hparams.n_head_kv;
-                const int64_t n_embd_head_k = hparams.n_embd_head_k;
-                const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-                const int64_t n_embd_head_v = hparams.n_embd_head_v;
-                const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-
-                float kq_scale = 1.0f/sqrtf(float(n_embd_head));
-                struct ggml_tensor * cur_attn;
-
-                struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-                cb(q, "q", il);
-
-                struct ggml_tensor * k =
-                    ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_kv, n_head_kv,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            0);
-                cb(k, "k", il);
-
-                if (cparams.flash_attn) {
-
-                    // split cached v into n_head heads (not transposed)
-                    struct ggml_tensor * v =
-                        ggml_view_3d(ctx0, kv_self.v_l[il],
-                                n_embd_head_v, n_kv, n_head_kv,
-                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                                0);
-                    cb(v, "v", il);
-
-                    cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-
-                    cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-                } else {
-                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-                    cb(kq, "kq", il);
-
-                    kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
-                    cb(kq, "kq_soft_max_ext", il);
-
-                    GGML_ASSERT(kv_self.size == n_ctx);
-
-                    // split cached v into n_head heads
-                    struct ggml_tensor * v =
-                        ggml_view_3d(ctx0, kv_self.v_l[il],
-                                n_kv, n_embd_head_v, n_head_kv,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx,
-                                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                                0);
-                    cb(v, "v", il);
-
-                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-                    cb(kqv, "kqv", il);
-
-                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-                    cb(kqv_merged, "kqv_merged", il);
-
-                    cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-                    cb(cur_attn, "kqv_merged_cont", il);
-                }
-
-                cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        nullptr, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+                cur = llm_build_norm(ctx0, cur, hparams,
                         model.layers[il].attn_sub_norm, NULL,
                         LLM_NORM_RMS, cb, il);
-                cb(cur_attn, "attn_sub_norm", il);
-
-                ggml_build_forward_expand(gf, cur_attn);
-
-                cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
+                cb(cur, "attn_sub_norm", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
                 cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
-                cb(cur, "attn_o_out", il);
+                cb(cur, "kqv_out", il);
             }
 
             if (il == n_layer - 1) {
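Taken together, the hand-rolled attention body collapses into one llm_build_kv call. For readers without the helper in front of them, its shape at this point in history is roughly the following; an abridged sketch from the llama.cpp of this era, not a verbatim copy (upstream also force-expands Qcur/Kcur/Vcur into the graph first, to reduce the number of graph splits):

    // abridged sketch of the reused helper; it chains the two steps the
    // removed code called or duplicated by hand
    static struct ggml_tensor * llm_build_kv(
            struct ggml_context * ctx, const llama_model & model,
            const llama_hparams & hparams, const llama_cparams & cparams,
            const llama_kv_cache & kv, struct ggml_cgraph * graph,
            struct ggml_tensor * wo, struct ggml_tensor * wo_b,
            struct ggml_tensor * k_cur, struct ggml_tensor * v_cur,
            struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask,
            int32_t n_tokens, int32_t kv_head, int32_t n_kv,
            float kq_scale, const llm_build_cb & cb, int il) {
        // 1) append this batch's K/V rows to the cache for layer il
        llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);

        // 2) attend over the cache (flash-attention or mul_mat+softmax path),
        //    then apply the now-optional wo/wo_b output projection
        struct ggml_tensor * cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph,
                wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
        cb(cur, "kqv_out", il);

        return cur;
    }

Besides removing roughly sixty duplicated lines, the reuse also drops an apparent bug in the deleted flash-attention branch, which reshaped cur (not yet assigned at that point) instead of cur_attn.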