falcon : copy-paste self-attention from LLaMA
This commit is contained in:
parent
af4bbcc873
commit
b34ab74094
1 changed files with 52 additions and 65 deletions
117
llama.cpp
117
llama.cpp
|
@ -2201,10 +2201,7 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
Qcur,
|
||||
0, 2, 1, 3);
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
offload_func_kq(Q);
|
||||
ggml_set_name(Q, "Q");
|
||||
|
||||
|
@ -2381,7 +2378,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
//const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
|
@ -2441,6 +2438,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
struct ggml_tensor * attn_norm;
|
||||
|
||||
// self-attention
|
||||
// TODO: refactor into common function (shared with LLaMA)
|
||||
{
|
||||
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
|
@ -2473,115 +2471,104 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
|
||||
const size_t wsize = ggml_type_size(cur->type);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_view_3d(
|
||||
struct ggml_tensor * tmpq = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_view_3d(
|
||||
struct ggml_tensor * tmpk = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head);
|
||||
wsize * n_embd_head * n_head);
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_view_3d(
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * (n_head + n_head_kv));
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
struct ggml_tensor* k = ggml_view_1d(
|
||||
ctx0, kv_self.k, N * n_head_kv * n_embd_head,
|
||||
(ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) *
|
||||
(il * n_ctx + n_past));
|
||||
struct ggml_tensor* v = ggml_view_1d(
|
||||
ctx0, kv_self.v, N * n_head_kv * n_embd_head,
|
||||
(ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) *
|
||||
(il * n_ctx + n_past));
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
|
||||
ggml_set_name(Vcur, "Vcur");
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * K = ggml_permute(
|
||||
ctx0,
|
||||
ggml_reshape_3d(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head,
|
||||
il * n_ctx *
|
||||
ggml_element_size(kv_self.k) *
|
||||
n_head_kv *
|
||||
n_embd_head),
|
||||
n_embd_head, n_head_kv, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
ggml_set_name(Q, "Q");
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_past + N, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
||||
ggml_set_name(K, "K");
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
struct ggml_tensor* V = ggml_permute(
|
||||
ctx0,
|
||||
ggml_reshape_3d(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head,
|
||||
il * n_ctx *
|
||||
ggml_element_size(kv_self.v) *
|
||||
n_head_kv *
|
||||
n_embd_head),
|
||||
n_embd_head, n_head_kv, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_past + N, n_embd_head, n_head_kv,
|
||||
ggml_element_size(kv_self.v)*n_ctx,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
||||
ggml_set_name(V, "V");
|
||||
|
||||
V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
ggml_set_name(KQV, "KQV");
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
ggml_set_name(KQV_merged, "KQV_merged");
|
||||
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
|
||||
// projection
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].wo,
|
||||
cur);
|
||||
}
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
struct ggml_tensor * inpFF = attn_norm;
|
||||
struct ggml_tensor * attn_out = ggml_cpy(
|
||||
ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
}
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue