diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 27bf2c1f2..b3c2ce2b1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1734,17 +1734,18 @@ class LlamaModel(Model): n_kv_head = self.hparams.get("num_key_value_heads") is_vision_tensor = "vision_tower" in name or "vision_model" in name - # For vision model - if name.startswith("language_model"): - name = name.replace("language_model.", "") - if name.startswith("model.text_model"): - name = name.replace("text_model.", "") # for SmolVLM - else: - name = name.replace("model.vision_tower.", "") - if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3: - return [] # skip post_layernorm + if is_vision_tensor: + if name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM + else: + name = name.replace("model.vision_tower.", "") + if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3: + return [] # skip post_layernorm if not is_vision_tensor: + if name.startswith("language_model"): + # language model tensors, remove the prefix + name = name.replace("language_model.", "") if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")):