llama : minor styling

Author: Georgi Gerganov, 2024-07-01 18:26:24 +03:00
Parent: ed5496fb32
Commit: ce711f6eae

llama.cpp

@@ -2087,6 +2087,7 @@ struct llama_hparams {
     uint32_t n_head_kv;
     uint32_t n_layer;
     uint32_t n_rot;
+    uint32_t n_swa = 0; // sliding window attention (SWA)
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_ff;
@@ -2101,7 +2102,6 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float expert_weights_scale = 0.0;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
 
     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2142,6 +2142,7 @@ struct llama_hparams {
         if (this->n_head_kv != other.n_head_kv) return true;
         if (this->n_layer != other.n_layer) return true;
         if (this->n_rot != other.n_rot) return true;
+        if (this->n_swa != other.n_swa) return true;
         if (this->n_embd_head_k != other.n_embd_head_k) return true;
         if (this->n_embd_head_v != other.n_embd_head_v) return true;
         if (this->n_ff != other.n_ff) return true;
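For context, n_swa is the sliding-window size used by sliding window attention (SWA): a causal query at position q_pos may only attend to keys whose position lies at most n_swa behind it. A minimal standalone sketch of that visibility rule follows; the helper and its exact comparison are assumptions for illustration, since in llama.cpp the rule is baked into the KQ mask rather than expressed as a predicate like this.

```cpp
#include <cstdint>

// Hypothetical helper illustrating the SWA visibility rule sketched above.
// q_pos: position of the query token, k_pos: position of the key token.
static bool swa_allowed(int64_t q_pos, int64_t k_pos, uint32_t n_swa) {
    if (k_pos > q_pos) {
        return false;                           // causal: no attending to future tokens
    }
    if (n_swa == 0) {
        return true;                            // n_swa == 0: sliding window not used
    }
    return (q_pos - k_pos) <= (int64_t) n_swa;  // key must lie inside the window
}
```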
@@ -2657,6 +2658,7 @@ struct llama_context {
     struct ggml_tensor * inp_pos; // I32 [n_batch]
     struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
     struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
+    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
     struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls; // I32 [n_batch]
@@ -2664,9 +2666,6 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
 
-    // KQ mask per layer, used by sliding window attention (gemma 2)
-    struct ggml_tensor * inp_KQ_mask_swa;
-
     // control vectors
     struct llama_control_vector cvec;
 };
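Both inp_KQ_mask and inp_KQ_mask_swa are additive attention biases: 0.0f where a token may attend to a KV cell and -INFINITY where it may not, so they can be added to the KQ scores before the softmax. A simplified, ggml-free sketch of how such a mask is consumed for one head; the function name and flat [token][kv] layout are assumptions for illustration only.

```cpp
#include <vector>

// Sketch: add a KQ mask to raw attention scores for one head.
// scores and mask both hold n_tokens * n_kv floats, indexed as [token][kv].
// Entries carrying -INFINITY drop to zero probability after the softmax.
static void apply_kq_mask(std::vector<float> & scores,
                          const std::vector<float> & mask,
                          int n_tokens, int n_kv) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int k = 0; k < n_kv; ++k) {
            scores[t*n_kv + k] += mask[t*n_kv + k];
        }
    }
}
```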
@@ -5427,6 +5426,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
+    LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
     LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
     LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
     LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
@@ -7788,13 +7788,13 @@ struct llm_build_context {
         lctx.inp_pos = nullptr;
         lctx.inp_out_ids = nullptr;
         lctx.inp_KQ_mask = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean = nullptr;
         lctx.inp_cls = nullptr;
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
     }
 
     void free() {
@@ -7813,7 +7813,6 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
@@ -7947,18 +7946,26 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }
 
-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
-        struct ggml_tensor * KQ_mask = causal
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
+        lctx.inp_KQ_mask = causal
             ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
             : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        cb(KQ_mask, "KQ_mask", -1);
-        ggml_set_input(KQ_mask);
-        if (sliding_window) {
-            lctx.inp_KQ_mask_swa = KQ_mask;
-        } else {
-            lctx.inp_KQ_mask = KQ_mask;
-        }
-        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
+        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
+        ggml_set_input(lctx.inp_KQ_mask);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+    }
+
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
     }
 
     struct ggml_tensor * build_inp_mean() {
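The two build_inp_KQ_mask* helpers above only allocate and register the input tensors; the mask values themselves are written when the batch inputs are set. A rough, ggml-free sketch of the values the two masks would end up holding in the causal case, assuming hypothetical arrays kv_pos[] (position stored in each KV cell) and tok_pos[] (position of each batch token):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: fill a full causal mask and a sliding-window (SWA) mask, both
// n_kv x n_tokens, with 0.0f (visible) or -INFINITY (masked).
static void fill_masks(std::vector<float> & mask,
                       std::vector<float> & mask_swa,
                       const std::vector<int32_t> & kv_pos,
                       const std::vector<int32_t> & tok_pos,
                       uint32_t n_swa) {
    const size_t n_kv     = kv_pos.size();
    const size_t n_tokens = tok_pos.size();

    mask.assign(n_kv*n_tokens, 0.0f);
    mask_swa.assign(n_kv*n_tokens, 0.0f);

    for (size_t t = 0; t < n_tokens; ++t) {
        for (size_t i = 0; i < n_kv; ++i) {
            float f = 0.0f;
            if (kv_pos[i] > tok_pos[t]) {
                f = -INFINITY;                              // causal mask
            }
            mask[t*n_kv + i] = f;

            if (tok_pos[t] - kv_pos[i] > (int32_t) n_swa) {
                f = -INFINITY;                              // outside the sliding window
            }
            mask_swa[t*n_kv + i] = f;
        }
    }
}
```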
@@ -11042,12 +11049,12 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
-        struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask(true, true);
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
 
         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
-            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
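As the (il % 2) comment indicates, gemma 2 interleaves the two attention types: with this selection rule the even-numbered layers take the sliding-window mask and the odd-numbered layers take the full causal mask. A trivial, standalone illustration of that per-layer choice; the layer count is made up.

```cpp
#include <cstdio>

int main() {
    const int n_layer = 6;                   // hypothetical layer count
    for (int il = 0; il < n_layer; ++il) {
        const bool use_swa = (il % 2 == 0);  // even layers use SWA, as in the gemma 2 graph above
        std::printf("layer %d uses %s\n", il, use_swa ? "KQ_mask_swa" : "KQ_mask");
    }
    return 0;
}
```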
@@ -11084,7 +11091,7 @@ struct llm_build_context {
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             cur = llm_build_norm(ctx0, cur, hparams,