add mobilevlm

This commit is contained in:
Xuan Son Nguyen 2025-01-19 16:29:20 +01:00
parent 6cabdda0df
commit d0068ef0ed
9 changed files with 210 additions and 61 deletions

View file

@ -17,7 +17,7 @@ from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
from transformers import AutoConfig
from transformers import AutoConfig, AutoImageProcessor
import math
import numpy as np
import torch
@ -68,9 +68,10 @@ class Model:
dir_model_card: Path
# for vision model
vision_arch: gguf.MODEL_ARCH | None = None
preprocessor_config: dict[str, Any] | None = None
vparams: dict[str, Any] | None = None
v_tensor_map: gguf.TensorNameMap
v_tensor_map: gguf.TensorNameMap | None = None
v_tensor_names: set[str] | None
# subclasses should define this!
@ -102,7 +103,6 @@ class Model:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
@ -218,7 +218,7 @@ class Model:
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None
if new_name is not None:
return new_name
elif new_name_vision is not None:
@ -488,14 +488,17 @@ class Model:
return hparams
@staticmethod
def load_preprocessor_config(dir_model: Path):
def load_preprocessor_config(dir_or_model_id: Path | str):
# TODO: this varies vastly among models, need to handle more cases in the future
file_path = dir_model / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
if isinstance(dir_or_model_id, Path):
file_path = dir_or_model_id / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
raise Exception(f"Preprocessor config not found at {file_path}")
else:
return None
return AutoImageProcessor.from_pretrained(dir_or_model_id).to_dict()
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@ -1586,16 +1589,31 @@ class StableLMModel(Model):
raise ValueError(f"Unprocessed norms: {norms}")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if "vision_config" in self.hparams:
model_type = self.hparams.get("model_type", None)
self.vision_arch = None
# only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf
if "vision_config" in self.hparams and model_type == "llava":
self.vparams = self.hparams["vision_config"]
if self.vparams is not None:
self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA
# only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
vision_model_id = self.hparams["mm_vision_tower"]
self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(vision_model_id)
self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM
if self.vparams is not None and self.vision_arch is not None:
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
def set_vocab(self):
try:
@ -1631,23 +1649,31 @@ class LlamaModel(Model):
self.gguf_writer.add_add_bos_token(False)
# For vision model
if self.vparams is not None and self.preprocessor_config is not None:
if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture("llava")
self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
if "vision_feature_layer" in self.hparams:
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
elif "mm_vision_select_layer" in self.hparams:
self.gguf_writer.add_vision_clip_select_layer(self.hparams["mm_vision_select_layer"])
else:
raise ValueError("gguf: can not find vision_feature_layer parameter.")
# TODO: should not hardcode these, but they are currently missing from config.json
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
def set_gguf_parameters(self):
@ -1683,6 +1709,8 @@ class LlamaModel(Model):
# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
else:
name = name.replace("model.vision_tower.", "")
if "post_layernorm" in name:
return [] # skip post_layernorm
@ -2101,7 +2129,7 @@ class DbrxModel(Model):
return n_dims > 1
@Model.register("MiniCPMForCausalLM")
@Model.register("MiniCPMForCausalLM", "MiniCPMV")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM