From 553b09ba8fb9ebf8db0b28700693d43f8fa15f71 Mon Sep 17 00:00:00 2001 From: S Date: Fri, 5 Apr 2024 10:35:19 +0100 Subject: [PATCH] Fix embedding layer based on Noeda's example --- convert-hf-to-gguf.py | 50 +------------------------------------------ llama.cpp | 43 ++++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 59 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e52ebe4b8..91730e0d8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2351,55 +2351,7 @@ class CommandR2Model(Model): super().set_gguf_parameters() self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) - - def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) - tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): - continue - - #Convert Q norm and K norm to 1d so they are exported in float32 and not quantized - if name.endswith((".q_norm.weight")) or name.endswith((".k_norm.weight")): - data_torch = data_torch.flatten() - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - data = data_torch.squeeze().numpy() - - # map tensor names - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() - - n_dims = len(data.shape) - data_dtype = data.dtype - - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16) - - print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") - - self.gguf_writer.add_tensor(new_name, data) - - - - + ###### CONVERSION LOGIC ###### diff --git a/llama.cpp b/llama.cpp index c8ef61aee..53cf86dc0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5405,13 +5405,13 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - - if(n_layer >= 64) + + if (n_layer >= 64) { - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k * hparams.n_head_kv}); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}); } - + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); @@ -9460,19 +9460,31 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - if(model.layers[il].attn_q_norm) + if (model.layers[il].attn_q_norm) { - Qcur = llm_build_norm(ctx0, Qcur, hparams, + + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur) * n_embd_head, + ggml_element_size(Qcur) * n_embd_head * n_head, + 0); + cb(Qcur, "Qcur", il); + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, + ggml_element_size(Kcur) * n_embd_head, + ggml_element_size(Kcur) * n_embd_head * n_head_kv, + 0); + cb(Kcur, "Kcur", il); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM, cb, il); cb(Qcur, "Qcur", il); - + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM, cb, il); - cb(Kcur, "Kcur", il); + cb(Kcur, "Kcur", il); } Qcur = ggml_rope_custom( @@ -13085,9 +13097,15 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return std::make_pair(i_layer, n_layer); }; + // Command-R+ has such a large embedding weight tensor it overflows + // 32-bit signed integers. This is band-aid until quants can deal with + // that. + if (name == "token_embd.weight" && arch == LLM_ARCH_COMMAND_R && qs.model.hparams.n_layer >= 64) { + new_type = GGML_TYPE_F16; + } // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + else if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -13119,6 +13137,11 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ3_S; } } + + } else if ((arch == LLM_ARCH_COMMAND_R) && + (name.find("q_norm") != std::string::npos || + name.find("k_norm") != std::string::npos)) { + new_type = GGML_TYPE_F32; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { if (name.find("attn_v.weight") != std::string::npos) {