add back convert hf to gguf

This commit is contained in:
Xuan Son Nguyen 2025-01-18 22:56:04 +01:00
parent 0a81051ae2
commit 6cabdda0df
7 changed files with 266 additions and 6 deletions

View file

@ -17,6 +17,7 @@ from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
from transformers import AutoConfig
import math
import numpy as np
import torch
@ -66,6 +67,12 @@ class Model:
metadata_override: Path | None
dir_model_card: Path
# for vision model
preprocessor_config: dict[str, Any] | None = None
vparams: dict[str, Any] | None = None
v_tensor_map: gguf.TensorNameMap
v_tensor_names: set[str] | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -95,6 +102,7 @@ class Model:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
@ -210,9 +218,13 @@ class Model:
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is not None:
return new_name
elif new_name_vision is not None:
return new_name_vision
else:
raise ValueError(f"Can not map tensor {name!r}")
return new_name
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
@ -466,7 +478,24 @@ class Model:
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
hparams = json.load(f)
if "text_config" in hparams:
text_config = hparams["text_config"]
# for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID
if "_name_or_path" in text_config:
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams
@staticmethod
def load_preprocessor_config(dir_model: Path):
# TODO: this varies vastly among models, need to handle more cases in the future
file_path = dir_model / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
return None
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@ -1557,10 +1586,17 @@ class StableLMModel(Model):
raise ValueError(f"Unprocessed norms: {norms}")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "vision_config" in self.hparams:
self.vparams = self.hparams["vision_config"]
if self.vparams is not None:
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@ -1594,6 +1630,26 @@ class LlamaModel(Model):
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
# For vision model
if self.vparams is not None and self.preprocessor_config is not None:
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture("llava")
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
# TODO: should not hardcode these, but they are currently missing from config.json
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
@ -1624,6 +1680,12 @@ class LlamaModel(Model):
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
if "post_layernorm" in name:
return [] # skip post_layernorm
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):