py: a bit cleaner

This commit is contained in:
Xuan Son Nguyen 2025-01-23 23:07:08 +01:00
parent c3a654c0fb
commit b986af80de

View file

@ -1734,9 +1734,7 @@ class LlamaModel(Model):
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name or "vision_model" in name
# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
if is_vision_tensor:
if name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
else:
@ -1745,6 +1743,9 @@ class LlamaModel(Model):
return [] # skip post_layernorm
if not is_vision_tensor:
if name.startswith("language_model"):
# language model tensors, remove the prefix
name = name.replace("language_model.", "")
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):