minicpmv works but missing uhd slices
This commit is contained in:
parent
ba489b4743
commit
c0d93dd509
11 changed files with 423 additions and 281 deletions
|
@ -2141,60 +2141,19 @@ class DbrxModel(Model):
|
|||
return n_dims > 1
|
||||
|
||||
|
||||
@Model.register("MiniCPMForCausalLM", "MiniCPMV")
|
||||
@Model.register("MiniCPMForCausalLM")
|
||||
class MiniCPMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINICPM
|
||||
proj_type: gguf.constants.CLIPProjectorType | None
|
||||
resampler_n_embd = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
model_type = self.hparams.get("model_type", None)
|
||||
|
||||
# only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
|
||||
if "vision_config" in self.hparams and model_type == "minicpmv":
|
||||
self.vparams = self.hparams["vision_config"]
|
||||
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
|
||||
self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
|
||||
version = str(self.hparams.get("version", "unknown"))
|
||||
if version == "2.5":
|
||||
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
|
||||
elif version == "2.6":
|
||||
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
|
||||
else:
|
||||
raise ValueError(f"Unsupported MiniCPM-V version: {version}")
|
||||
# TODO: how to do this without reading the whole safetensor file?
|
||||
for tname, tensor in self.get_tensors():
|
||||
if tname == "resampler.ln_post.bias":
|
||||
self.resampler_n_embd = tensor.shape[0]
|
||||
if self.resampler_n_embd < 2:
|
||||
raise ValueError("Failed to detect resampler embedding size")
|
||||
|
||||
if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
|
||||
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
|
||||
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
|
||||
self.hparams["vision_feature_layer"] = 0
|
||||
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
# scale_emb
|
||||
embedding_scale = float(self.hparams.get("scale_emb", 1.0))
|
||||
embedding_scale = float(self.hparams["scale_emb"])
|
||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||
logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
|
||||
# scale_depth
|
||||
if "scale_depth" in self.hparams:
|
||||
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
|
||||
else:
|
||||
residual_scale = 1.0
|
||||
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
|
||||
self.gguf_writer.add_residual_scale(residual_scale)
|
||||
logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
|
||||
# logit_scale
|
||||
if "dim_model_base" in self.hparams:
|
||||
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
|
||||
else:
|
||||
logit_scale = 1.0
|
||||
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
|
||||
self.gguf_writer.add_logit_scale(logit_scale)
|
||||
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
|
||||
if self.hparams.get("rope_scaling") is not None:
|
||||
|
@ -2202,15 +2161,6 @@ class MiniCPMModel(Model):
|
|||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
|
||||
logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
|
||||
|
||||
# For vision model
|
||||
if self.vparams is not None and self.proj_type is not None:
|
||||
self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
|
||||
self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
|
||||
self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
|
||||
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
|
||||
self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
|
||||
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
|
@ -2228,118 +2178,22 @@ class MiniCPMModel(Model):
|
|||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
|
||||
yield (
|
||||
self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
|
||||
torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
|
||||
)
|
||||
|
||||
def set_vocab(self):
|
||||
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
|
||||
# undocumented anywhere, I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf
|
||||
self._set_vocab_gpt2()
|
||||
else:
|
||||
self._set_vocab_sentencepiece()
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# For vision model
|
||||
if name.startswith("llm."):
|
||||
name = name.replace("llm.", "")
|
||||
|
||||
# split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
|
||||
if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
|
||||
assert data_torch.shape[0] == 3 * self.resampler_n_embd
|
||||
split_tensor = data_torch.chunk(3, dim=0)
|
||||
name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias)
|
||||
name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias)
|
||||
name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias)
|
||||
return [
|
||||
(self.map_tensor_name(name_q), split_tensor[0]),
|
||||
(self.map_tensor_name(name_k), split_tensor[1]),
|
||||
(self.map_tensor_name(name_v), split_tensor[2]),
|
||||
]
|
||||
|
||||
if name == "resampler.proj" or name == "resampler.query":
|
||||
name += ".weight"
|
||||
|
||||
if "post_layernorm" in name:
|
||||
return [] # skip post_layernorm
|
||||
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
|
||||
# HF models permute some of the tensors, so we need to undo that
|
||||
if not name.startswith("vpm") and name.endswith(("q_proj.weight")):
|
||||
if name.endswith(("q_proj.weight")):
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
|
||||
if not name.startswith("vpm") and name.endswith(("k_proj.weight")):
|
||||
if name.endswith(("k_proj.weight")):
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, bid # unused
|
||||
if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if "v.resmpl." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
|
||||
return False
|
||||
|
||||
# utils to work with MiniCPM-V resampler
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_h_size, grid_w_size = grid_size, grid_size
|
||||
else:
|
||||
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
||||
|
||||
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
||||
pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
@Model.register("MiniCPM3ForCausalLM")
|
||||
|
@ -2479,6 +2333,155 @@ class Qwen2VLModel(Model):
|
|||
yield name, data
|
||||
|
||||
|
||||
@Model.register("MiniCPMV")
|
||||
class MiniCPMVModel(Qwen2Model):
|
||||
# based on minicpmv-surgery.py, not sure why it is Qwen2Model instead of MiniCPMModel
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
proj_type: gguf.constants.CLIPProjectorType | None
|
||||
resampler_n_embd = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
model_type = self.hparams.get("model_type", None)
|
||||
|
||||
# only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
|
||||
if "vision_config" in self.hparams and model_type == "minicpmv":
|
||||
self.vparams = self.hparams["vision_config"]
|
||||
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
|
||||
self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
|
||||
version = str(self.hparams.get("version", "unknown"))
|
||||
if version == "2.5":
|
||||
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
|
||||
elif version == "2.6":
|
||||
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
|
||||
else:
|
||||
raise ValueError(f"Unsupported MiniCPM-V version: {version}")
|
||||
# TODO: how to do this without reading the whole safetensor file?
|
||||
for tname, tensor in self.get_tensors():
|
||||
if tname == "resampler.ln_post.bias":
|
||||
self.resampler_n_embd = tensor.shape[0]
|
||||
if self.resampler_n_embd < 2:
|
||||
raise ValueError("Failed to detect resampler embedding size")
|
||||
else:
|
||||
raise ValueError("Expected vision_config, but not found")
|
||||
|
||||
if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
|
||||
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
|
||||
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
|
||||
self.hparams["vision_feature_layer"] = 0
|
||||
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
# For vision model
|
||||
if self.vparams is not None and self.proj_type is not None:
|
||||
self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
|
||||
self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
|
||||
self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
|
||||
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
|
||||
self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
|
||||
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
yield (
|
||||
self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
|
||||
torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
|
||||
)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
# for language part
|
||||
if name.startswith("llm."):
|
||||
return [(self.map_tensor_name(name.replace("llm.", "")), data_torch)]
|
||||
|
||||
# split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
|
||||
if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
|
||||
assert data_torch.shape[0] == 3 * self.resampler_n_embd
|
||||
split_tensor = data_torch.chunk(3, dim=0)
|
||||
name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias)
|
||||
name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias)
|
||||
name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias)
|
||||
return [
|
||||
(self.map_tensor_name(name_q), split_tensor[0]),
|
||||
(self.map_tensor_name(name_k), split_tensor[1]),
|
||||
(self.map_tensor_name(name_v), split_tensor[2]),
|
||||
]
|
||||
|
||||
# append .weight to these tensors
|
||||
if name == "resampler.proj" or name == "resampler.query":
|
||||
name += ".weight"
|
||||
|
||||
if "post_layernorm" in name:
|
||||
return [] # skip post_layernorm
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del name, bid # unused
|
||||
if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if "v.resmpl." in new_name:
|
||||
return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
|
||||
return False
|
||||
|
||||
# utils to work with MiniCPM-V resampler
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_h_size, grid_w_size = grid_size, grid_size
|
||||
else:
|
||||
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
|
||||
|
||||
grid_h = np.arange(grid_h_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_w_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
|
||||
pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token:
|
||||
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
@Model.register("WavTokenizerDec")
|
||||
class WavTokenizerDecModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue