From 0ec53cfff77f01be078e77e682716e04ba4d610f Mon Sep 17 00:00:00 2001 From: Ashish <1856117+ashishdatta@users.noreply.github.com> Date: Sat, 13 Apr 2024 21:12:30 -0700 Subject: [PATCH] Converge StableLM and StableLM2 code to simplify graph construction --- convert-hf-to-gguf.py | 14 --- gguf-py/gguf/constants.py | 16 --- llama.cpp | 251 +++++++++----------------------------- 3 files changed, 56 insertions(+), 225 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 07a8a8d3b..42cffab77 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1186,20 +1186,6 @@ class PersimmonModel(Model): @Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") class StableLMModel(Model): model_arch = gguf.MODEL_ARCH.STABLELM - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.model_arch = ( - gguf.MODEL_ARCH.STABLELM - if self.hparams["num_hidden_layers"] < 40 - else gguf.MODEL_ARCH.STABLELM2 - ) - self.gguf_writer = gguf.GGUFWriter( - self.fname_out, - gguf.MODEL_ARCH_NAMES[self.model_arch], - endianess=self.endianess, - use_temp_file=False, - ) - def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index bff325c69..9da13dfc9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -113,7 +113,6 @@ class MODEL_ARCH(IntEnum): NOMIC_BERT = auto() BLOOM = auto() STABLELM = auto() - STABLELM2 = auto() QWEN = auto() QWEN2 = auto() PHI2 = auto() @@ -184,7 +183,6 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.NOMIC_BERT: "nomic-bert", MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.STABLELM2: "stablelm2", MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.PHI2: "phi2", @@ -441,20 +439,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, - ], - MODEL_ARCH.STABLELM2: [ - MODEL_TENSOR.TOKEN_EMBD, - MODEL_TENSOR.OUTPUT_NORM, - MODEL_TENSOR.OUTPUT, - MODEL_TENSOR.ROPE_FREQS, - MODEL_TENSOR.ATTN_NORM, - MODEL_TENSOR.ATTN_Q, - MODEL_TENSOR.ATTN_K, - MODEL_TENSOR.ATTN_V, - MODEL_TENSOR.ATTN_OUT, - MODEL_TENSOR.FFN_GATE, - MODEL_TENSOR.FFN_DOWN, - MODEL_TENSOR.FFN_UP, MODEL_TENSOR.ATTN_Q_NORM, MODEL_TENSOR.ATTN_K_NORM, ], diff --git a/llama.cpp b/llama.cpp index e5a85ad98..25c2b6bd6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -207,7 +207,6 @@ enum llm_arch { LLM_ARCH_NOMIC_BERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, - LLM_ARCH_STABLELM2, LLM_ARCH_QWEN, LLM_ARCH_QWEN2, LLM_ARCH_PHI2, @@ -241,7 +240,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_STABLELM2, "stablelm2" }, { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_PHI2, "phi2" }, @@ -704,23 +702,6 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_STABLELM2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, @@ -3876,14 +3857,6 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_STABLELM2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { case 40: model.type = e_model::MODEL_12B; break; default: model.type = e_model::MODEL_UNKNOWN; } @@ -5079,42 +5052,16 @@ static bool llm_load_tensors( layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + if (n_layer >= 40) { + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM,"weight", i), {hparams.n_embd_head_k, hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM,"weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - } - } break; - case LLM_ARCH_STABLELM2: - { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - - // output - { - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - } - - for (int i = 0; i < n_layer; ++i) { - ggml_context * ctx_layer = ctx_for_layer(i); - ggml_context * ctx_split = ctx_for_layer_split(i); - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM,"weight", i), {hparams.n_embd_head_k, hparams.n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM,"weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); + } + if (n_layer < 40) { + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + } layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -8121,6 +8068,8 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + struct ggml_tensor * ffn_inp; + struct ggml_tensor * attn_out = cur; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); @@ -8139,6 +8088,9 @@ struct llm_build_context { model.layers[il].attn_norm_b, LLM_NORM, cb, il); cb(cur, "attn_norm", il); + if(n_layer >= 40) { + ffn_inp = cur; + } // self-attention { @@ -8163,127 +8115,6 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); } - - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur", il); - - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Kcur, "Kcur", il); - - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - } - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, - model.output_norm_b, - LLM_NORM, cb, -1); - cb(cur, "result_norm", -1); - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - return gf; - } - - struct ggml_cgraph * build_stablelm2() { - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = build_inp_pos(); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - - for (int il = 0; il < n_layer; ++il) { - // norm - cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, cb, il); - cb(cur, "attn_norm", il); - struct ggml_tensor * ffn_inp = cur; - - // self-attention - { - // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - if (model.layers[il].attn_q_norm) { Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur) * n_embd_head, @@ -8317,6 +8148,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); + Kcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, @@ -8329,31 +8161,64 @@ struct llm_build_context { Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } - if (il == n_layer - 1) { + if (il == n_layer - 1 && n_layer < 40) { // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + else if (il == n_layer - 1 && n_layer >= 40){ struct ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); } - struct ggml_tensor * attn_out = cur; + if (n_layer < 40) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } + else { + attn_out = cur; + } // feed-forward network { - cur = llm_build_ffn(ctx0, ffn_inp, + if (n_layer < 40) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + else { + cur = llm_build_ffn(ctx0, ffn_inp, model.layers[il].ffn_up, NULL, model.layers[il].ffn_gate, NULL, model.layers[il].ffn_down, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); + cb(cur, "ffn_out", il); + } } - // add together residual + FFN + self-attention - cur = ggml_add(ctx0, cur, inpL); - cur = ggml_add(ctx0, cur, attn_out); - cb(cur, "l_out", il); + if (n_layer < 40) { + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + } + else { + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + cb(cur, "l_out", il); + } // input for next layer inpL = cur; @@ -8375,6 +8240,7 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_qwen() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -10078,10 +9944,6 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_stablelm(); } break; - case LLM_ARCH_STABLELM2: - { - result = llm.build_stablelm2(); - } break; case LLM_ARCH_QWEN: { result = llm.build_qwen(); @@ -15000,7 +14862,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: - case LLM_ARCH_STABLELM2: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_PHI2: