diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e164e9a07..8b8d3988b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -204,9 +204,10 @@ class Model: f"Missing tensors: {missing}\n" f"Extra tensors: {extra}") - def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: - if key not in gguf.MODEL_TENSORS[self.model_arch]: - raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight", is_vision = False) -> str: + arch = self.vision_arch if is_vision and self.vision_arch is not None else self.model_arch + if key not in gguf.MODEL_TENSORS[arch]: + raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {arch!r}") name: str = gguf.TENSOR_NAMES[key] if "{bid}" in name: assert bid is not None @@ -2144,6 +2145,7 @@ class DbrxModel(Model): class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM proj_type: gguf.constants.CLIPProjectorType | None + resampler_n_embd = 0 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2162,6 +2164,12 @@ class MiniCPMModel(Model): self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6 else: raise ValueError(f"Unsupported MiniCPM-V version: {version}") + # TODO: how to do this without reading the whole safetensor file? + for tname, tensor in self.get_tensors(): + if tname == "resampler.ln_post.bias": + self.resampler_n_embd = tensor.shape[0] + if self.resampler_n_embd < 2: + raise ValueError("Failed to detect resampler embedding size") if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None: self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5] @@ -2220,6 +2228,12 @@ class MiniCPMModel(Model): yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV: + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True), + torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70))) + ) + def set_vocab(self): if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV: # undocumented anywhere, I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf @@ -2233,11 +2247,23 @@ class MiniCPMModel(Model): # For vision model if name.startswith("llm."): name = name.replace("llm.", "") - # attention, someone mess up and use underscore instead of dot - if name.endswith("in_proj_weight"): - name = name.replace("_weight", ".weight") - if name.endswith("in_proj_bias"): - name = name.replace("_bias", ".bias") + + # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v + if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"): + assert data_torch.shape[0] == 3 * self.resampler_n_embd + split_tensor = data_torch.chunk(3, dim=0) + name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias) + name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias) + name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias) + return [ + (self.map_tensor_name(name_q), split_tensor[0]), + (self.map_tensor_name(name_k), split_tensor[1]), + (self.map_tensor_name(name_v), split_tensor[2]), + ] + + if name == "resampler.proj" or name == "resampler.query": + name += ".weight" + if "post_layernorm" in name: return [] # skip post_layernorm @@ -2251,6 +2277,69 @@ class MiniCPMModel(Model): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) return [(self.map_tensor_name(name), data_torch)] + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + del name, bid # unused + if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name: + return gguf.GGMLQuantizationType.F32 + if "v.resmpl." in new_name: + return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16 + return False + + # utils to work with MiniCPM-V resampler + + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 + def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb @Model.register("MiniCPM3ForCausalLM")