diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index d1b8cef11..6b9d54951 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -56,7 +56,7 @@ class Model(ABC): self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file) - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_transformer_layers"]) @property @abstractmethod @@ -2903,6 +2903,105 @@ class OlmoModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("OpenELMForCausalLM") +class OpenELM(Model): + model_arch = gguf.MODEL_ARCH.OPENELM + def set_gguf_parameters(self): + self.gguf_writer.add_name("OpenElm") + self.block_count = self.find_hparam(["num_transformer_layers"]) + self.gguf_writer.add_layer_norm_eps(1e-5) + self.gguf_writer.add_layer_norm_rms_eps(1e-6) # https://github.com/apple/corenet/blob/0333b1fbb29c31809663c4e6de2654b9ff2d27de/mlx_examples/open_elm/open_elm.py#L20 + n_embd = self.find_hparam(["model_dim"]) + self.gguf_writer.add_embedding_length(n_embd) + head_dim = self.find_hparam(["head_dim"]) + n_head = n_embd // head_dim + rot_pct = 1.0 + self.gguf_writer.add_context_length(self.find_hparam(["max_context_length"])) + self.gguf_writer.add_embedding_length(n_embd) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_head_count(n_head) + self.gguf_writer.add_head_count_kv(n_head) + self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) + self.gguf_writer.add_file_type(self.ftype) + + def set_vocab(self): + from sentencepiece import SentencePieceProcessor + tokenizer_path = self.dir_model / 'tokenizer.model' + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + added_tokens_file = self.dir_model / 'added_tokens.json' + if added_tokens_file.is_file(): + with open(added_tokens_file, "r", encoding="utf-8") as f: + added_tokens_json = json.load(f) + for key in added_tokens_json: + token_id = added_tokens_json[key] + if (token_id >= vocab_size): + print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + tokens[token_id] = key.encode("utf-8") + scores[token_id] = -1000.0 + toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + # Same as super class, but permuting q_proj, k_proj + # Copied from: LlamaModel + def write_tensors(self): + block_count = self.hparams.get("num_transformer_layers", self.hparams.get("num_hidden_layers", self.hparams.get("num_transformer_layers"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("model_dim") // self.hparams.get("head_dim") # TODO: propagate this + for name, data_torch in self.get_tensors(): + old_dtype = data_torch.dtype + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + data = data_torch.numpy() + data = data.squeeze() + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + new_name += ".weight" + n_dims = len(data.shape) + data_dtype = data.dtype + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + # 1d tensors need to be converted to float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2: + data = data.astype(np.float16) + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + self.gguf_writer.add_tensor(new_name, data) + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6d597bfd9..b6442bdb1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -138,6 +138,7 @@ class MODEL_ARCH(IntEnum): COMMAND_R = auto() DBRX = auto() OLMO = auto() + OPENELM = auto() class MODEL_TENSOR(IntEnum): @@ -215,6 +216,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OPENELM: "openelm", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -725,6 +727,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.OPENELM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_NORM, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5750d419..080b9c7ef 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -24,6 +24,7 @@ class TensorNameMap: "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf "transformer.in_out_embed", # Grok + "transformer.token_embeddings.weight", # openelm ), # Token type embeddings @@ -36,6 +37,7 @@ class TensorNameMap: "word_embeddings_layernorm", # bloom "embeddings.LayerNorm", # bert "emb_ln", # nomic-bert + "transformer.norm.weight", # openelm ), # Position embeddings @@ -68,6 +70,7 @@ class TensorNameMap: "model.norm_f", # mamba-qbert "backbone.norm_f", # mamba "transformer.rms_norm", # Grok + "transformer.norm.weight" # openelm ), # Rope frequencies @@ -97,6 +100,7 @@ class TensorNameMap: "backbone.layers.{bid}.norm", # mamba "transformer.decoder_layer.{bid}.rms_norm", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx + "transformer.layers.{bid}.attn_norm.weight" # openelm ), # Attention norm 2 @@ -117,7 +121,8 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert - "model.layers.{bid}.self_attn.qkv_proj" # phi3 + "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "transformer.layers.{bid}.attn.qkv_proj.weight" # openelm ), # Attention query @@ -128,7 +133,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok ), # Attention key @@ -139,7 +144,8 @@ class TensorNameMap: "transformer.h.{bid}.attn.k_proj", # gpt-j "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok + "transformer.layers.{bid}.attn.k_norm.weight" # openelm ), # Attention value @@ -150,7 +156,8 @@ class TensorNameMap: "transformer.h.{bid}.attn.v_proj", # gpt-j "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok + ), # Attention output @@ -173,6 +180,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.out_proj", # nomic-bert "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx + "transformer.layers.{bid}.attn.out_proj.weight" # openelm ), # Attention output norm @@ -204,6 +212,7 @@ class TensorNameMap: "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "transformer.layers.{bid}.ffn_norm.weight", # openelm ), MODEL_TENSOR.FFN_GATE_INP: ( @@ -240,6 +249,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert "model.layers.{bid}.mlp.c_fc", # starcoder2 + "transformer.layers.{bid}.ffn.proj_1.weight" # openelm ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -299,6 +309,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.w2", # internlm2 "encoder.layers.{bid}.mlp.fc2", # nomic-bert "model.layers.{bid}.mlp.c_proj", # starcoder2 + "transformer.layers.{bid}.ffn.proj_2.weight" # openelm ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -317,6 +328,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_layernorm", # persimmon "model.layers.{bid}.self_attn.q_norm", # cohere "transformer.blocks.{bid}.attn.q_ln", # sea-lion + "transformer.layers.{bid}.attn.q_norm.weight" # openelm ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -324,6 +336,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_layernorm", # persimmon "model.layers.{bid}.self_attn.k_norm", # cohere "transformer.blocks.{bid}.attn.k_ln", # sea-lion + "transformer.layers.{bid}.attn.k_norm.weight" ), MODEL_TENSOR.ROPE_FREQS: ( diff --git a/llama.cpp b/llama.cpp index 72c10ffc2..0d7b921b5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -225,6 +225,7 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_OPENELM, LLM_ARCH_UNKNOWN, }; @@ -261,6 +262,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OPENELM, "openelm" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1028,6 +1030,21 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPENELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1775,8 +1792,10 @@ enum e_model { MODEL_22M, MODEL_33M, MODEL_109M, + MODEL_270M, MODEL_137M, MODEL_335M, + MODEL_450M, MODEL_0_5B, MODEL_1B, MODEL_2B, @@ -4188,6 +4207,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_270M; break; + case 20: model.type = e_model::MODEL_450M; break; + case 28: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4675,6 +4705,22 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } } +float make_divisible( + double v, + int divisor = 8, + float min_value = 0.0 +) { + if (min_value == 0.0) { + min_value = divisor; + } + float rounded_v = int((v + divisor / 2) / divisor) * divisor; + float new_v = (min_value > rounded_v)? min_value : rounded_v; + if (new_v < 0.9 * v) { + new_v += divisor; + } + return new_v; +} + // Returns false if cancelled by progress_callback static bool llm_load_tensors( llama_model_loader & ml, @@ -5924,6 +5970,48 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_OPENELM: + { + std::vector num_kv_heads = {3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5}; + std::vector num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20}; + std::vector ffn_multipliers = {0.5, 0.73, 0.97, 1.2, 1.43, 1.67, 1.9, 2.13, 2.37, 2.6, 2.83, 3.07, 3.3, 3.53, 3.77, 4.0}; + llama_hparams modified_hparams(hparams); + const int64_t n_embd_head = hparams.n_embd_head_v; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_k = num_kv_heads[i]; + const int64_t n_head_v = num_kv_heads[i]; + const int64_t n_head_kv = n_head_k+n_head_v; + const int64_t n_head = n_head_kv+ num_query_heads[i]; + const int64_t n_kv = (num_kv_heads[i]+num_kv_heads[i])*n_embd_head; + modified_hparams.n_head = n_head; + modified_hparams.n_head_kv = n_head_kv; + const int64_t n_embd_gqa = n_embd_head * n_head; + const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa(); + const int64_t ffn_inter = make_divisible(n_embd*ffn_multipliers[i], 256); + + + ggml_context* ctx_layer = ctx_for_layer(i); + ggml_context* ctx_split = ctx_for_layer_split(i); + auto& layer = model.layers[i]; + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head }); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd_head*n_head }); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head_kv*n_embd_head*2, n_embd }); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * ffn_inter }); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_inter, n_embd }); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6849,6 +6937,17 @@ struct llm_build_context { return lctx.inp_KQ_mask; } + struct ggml_tensor * build_inp_KQ_mask2(int64_t n_kv, bool causal = true) { + if (causal) { + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); + } else { + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + } + cb(lctx.inp_KQ_mask, "KQ_mask", -1); + ggml_set_input(lctx.inp_KQ_mask); + return lctx.inp_KQ_mask; + } + struct ggml_tensor * build_inp_KQ_pos() { lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); @@ -10610,6 +10709,139 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_openelm() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + const int64_t n_embd_head = hparams.n_embd_head_v; + // TODO: get this from config + std::vector num_kv_heads = {3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5}; + std::vector num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20}; + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * inp_pos = build_inp_pos(); + llama_hparams modified_hparams(hparams); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + + for (int il = 0; il < n_layer; ++il) { + auto residual = inpL; + // TODO: Want the offsets to be calculated with the num heads at layer level + // This doesn't work at the moment, comment out to test + const int64_t n_head_k = num_kv_heads[il]; + const int64_t n_head_v = num_kv_heads[il]; + const int64_t n_head_kv = n_head_k+n_head_v; + const int64_t n_head = n_head_kv+ num_query_heads[il]; + const int64_t n_kv = (num_kv_heads[il]+num_kv_heads[il])*n_embd_head; + modified_hparams.n_head = n_head; + modified_hparams.n_head_kv = n_head_kv; + const int64_t n_embd_gqa = n_embd_head * n_head; + const int64_t n_embd_k_gqa = modified_hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = modified_hparams.n_embd_v_gqa(); + + // self-attention + { + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, modified_hparams, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(attn_norm_output, "attn_norm", il); + + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_k_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_v_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_k_gqa))); + // Q/K Layernorm + Qcur = llm_build_norm(ctx0, Qcur, modified_hparams, + model.layers[il].attn_q_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur", il); + + Kcur = llm_build_norm(ctx0, Kcur, modified_hparams, + model.layers[il].attn_k_norm, + NULL, + LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur", il); + + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + struct ggml_tensor * KQ_mask = build_inp_KQ_mask2(n_kv); + + Qcur = ggml_rope_custom( + ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + cur = llm_build_kv(ctx0, model, modified_hparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor* inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + cur = ggml_add(ctx0, cur, residual); + residual = cur; + cur = llm_build_norm(ctx0, cur, modified_hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + // FF + { + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, residual, cur); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = llm_build_norm(ctx0, inpL, modified_hparams, + model.output_norm, + NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10823,6 +11055,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_OPENELM: + { + result = llm.build_openelm(); + } break; default: GGML_ASSERT(false); } @@ -15679,6 +15915,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: + case LLM_ARCH_OPENELM: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/sgemm.cpp b/sgemm.cpp index 4e0159804..39ca855a3 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -818,7 +818,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda assert(m >= 0); assert(n >= 0); assert(k >= 0); - assert(lda >= k); + assert(lda >= k); // Currently failing here assert(ldb >= k); assert(ldc >= m); assert(nth > 0);