llama : reuse n_rot from the build context

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-24 10:44:59 +02:00
parent 2b9a9bff2b
commit 5f5b1b57ca
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -5129,14 +5129,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -5300,12 +5300,12 @@ struct llm_build_context {
case MODEL_7B:
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
break;
@ -5428,13 +5428,13 @@ struct llm_build_context {
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -5661,7 +5661,7 @@ struct llm_build_context {
// RoPE the first n_rot of q/k, pass the other half, and concat.
struct ggml_tensor * qrot = ggml_view_3d(
ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
ctx0, tmpq, n_rot, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head,
0
@ -5669,7 +5669,7 @@ struct llm_build_context {
cb(qrot, "qrot", il);
struct ggml_tensor * krot = ggml_view_3d(
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
ctx0, tmpk, n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head,
0
@ -5678,29 +5678,29 @@ struct llm_build_context {
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
struct ggml_tensor * qpass = ggml_view_3d(
ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
ctx0, tmpq, n_rot, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head,
ggml_element_size(tmpq) * hparams.n_rot
ggml_element_size(tmpq) * n_rot
);
cb(qpass, "qpass", il);
struct ggml_tensor * kpass = ggml_view_3d(
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
ctx0, tmpk, n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head,
ggml_element_size(tmpk) * hparams.n_rot
ggml_element_size(tmpk) * n_rot
);
cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom(
ctx0, qrot, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, qrot, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(qrotated, "qrotated", il);
struct ggml_tensor * krotated = ggml_rope_custom(
ctx0, krot, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, krot, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(krotated, "krotated", il);
@ -5952,14 +5952,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -6284,14 +6284,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -6395,13 +6395,13 @@ struct llm_build_context {
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -6510,14 +6510,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -6628,7 +6628,7 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
@ -6639,7 +6639,7 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, hparams.n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -6731,13 +6731,13 @@ struct llm_build_context {
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, hparams.n_rot, n_head, n_tokens), inp_pos,
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, hparams.n_rot, n_head_kv, n_tokens), inp_pos,
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
@ -6933,14 +6933,14 @@ struct llm_build_context {
struct ggml_tensor * Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7046,14 +7046,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7160,14 +7160,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7287,14 +7287,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);