minicpmv works but missing uhd slices
Parent: ba489b4743
Commit: c0d93dd509
11 changed files with 423 additions and 281 deletions
@@ -2141,60 +2141,19 @@ class DbrxModel(Model):
         return n_dims > 1


-@Model.register("MiniCPMForCausalLM", "MiniCPMV")
+@Model.register("MiniCPMForCausalLM")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM

-    proj_type: gguf.constants.CLIPProjectorType | None
-    resampler_n_embd = 0
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        model_type = self.hparams.get("model_type", None)
-
-        # only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
-        if "vision_config" in self.hparams and model_type == "minicpmv":
-            self.vparams = self.hparams["vision_config"]
-            self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
-            self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
-            version = str(self.hparams.get("version", "unknown"))
-            if version == "2.5":
-                self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
-            elif version == "2.6":
-                self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
-            else:
-                raise ValueError(f"Unsupported MiniCPM-V version: {version}")
-            # TODO: how to do this without reading the whole safetensor file?
-            for tname, tensor in self.get_tensors():
-                if tname == "resampler.ln_post.bias":
-                    self.resampler_n_embd = tensor.shape[0]
-            if self.resampler_n_embd < 2:
-                raise ValueError("Failed to detect resampler embedding size")
-
-        if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
-            self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
-            self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
-            self.hparams["vision_feature_layer"] = 0
-            self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
-
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        # scale_emb
-        embedding_scale = float(self.hparams["scale_emb"])
+        embedding_scale = float(self.hparams.get("scale_emb", 1.0))
         self.gguf_writer.add_embedding_scale(embedding_scale)
         logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
-        # scale_depth
-        residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
+        if "scale_depth" in self.hparams:
+            residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
+        else:
+            residual_scale = 1.0
         self.gguf_writer.add_residual_scale(residual_scale)
         logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
-        # logit_scale
-        logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
+        if "dim_model_base" in self.hparams:
+            logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
+        else:
+            logit_scale = 1.0
         self.gguf_writer.add_logit_scale(logit_scale)
         logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
         if self.hparams.get("rope_scaling") is not None:
@@ -2202,15 +2161,6 @@ class MiniCPMModel(Model):
             self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
             logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")

-        # For vision model
-        if self.vparams is not None and self.proj_type is not None:
-            self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
-            self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
-            self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
-            max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
-            self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
-
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
@@ -2228,119 +2178,23 @@ class MiniCPMModel(Model):
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))

-        if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
-            yield (
-                self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
-                torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
-            )
-
     def set_vocab(self):
-        if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
-            # undocumented anywhere, I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf
-            self._set_vocab_gpt2()
-        else:
-            self._set_vocab_sentencepiece()
+        self._set_vocab_sentencepiece()

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        # For vision model
-        if name.startswith("llm."):
-            name = name.replace("llm.", "")
-
-        # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
-        if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
-            assert data_torch.shape[0] == 3 * self.resampler_n_embd
-            split_tensor = data_torch.chunk(3, dim=0)
-            name_q = name.replace("in_proj_", "in_proj_q.")  # in_proj_q.(weight|bias)
-            name_k = name.replace("in_proj_", "in_proj_k.")  # in_proj_k.(weight|bias)
-            name_v = name.replace("in_proj_", "in_proj_v.")  # in_proj_v.(weight|bias)
-            return [
-                (self.map_tensor_name(name_q), split_tensor[0]),
-                (self.map_tensor_name(name_k), split_tensor[1]),
-                (self.map_tensor_name(name_v), split_tensor[2]),
-            ]
-
-        if name == "resampler.proj" or name == "resampler.query":
-            name += ".weight"
-
-        if "post_layernorm" in name:
-            return []  # skip post_layernorm
-
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")

         # HF models permute some of the tensors, so we need to undo that
-        if not name.startswith("vpm") and name.endswith(("q_proj.weight")):
+        if name.endswith(("q_proj.weight")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
-        if not name.startswith("vpm") and name.endswith(("k_proj.weight")):
+        if name.endswith(("k_proj.weight")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

         return [(self.map_tensor_name(name), data_torch)]

-    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        del name, bid  # unused
-        if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
-            return gguf.GGMLQuantizationType.F32
-        if "v.resmpl." in new_name:
-            return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
-        return False
-
-    # utils to work with MiniCPM-V resampler
-
-    # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
-    def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
-        """
-        grid_size: int of the grid height and width
-        return:
-        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-        """
-        if isinstance(grid_size, int):
-            grid_h_size, grid_w_size = grid_size, grid_size
-        else:
-            grid_h_size, grid_w_size = grid_size[0], grid_size[1]
-
-        grid_h = np.arange(grid_h_size, dtype=np.float32)
-        grid_w = np.arange(grid_w_size, dtype=np.float32)
-        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-        grid = np.stack(grid, axis=0)
-
-        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
-        pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-        if cls_token:
-            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-        return pos_embed
-
-    def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
-        assert embed_dim % 2 == 0
-
-        # use half of dimensions to encode grid_h
-        emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-        emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
-        return emb
-
-    def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
-        """
-        embed_dim: output dimension for each position
-        pos: a list of positions to be encoded: size (M,)
-        out: (M, D)
-        """
-        assert embed_dim % 2 == 0
-        omega = np.arange(embed_dim // 2, dtype=np.float32)
-        omega /= embed_dim / 2.
-        omega = 1. / 10000 ** omega  # (D/2,)
-
-        pos = pos.reshape(-1)  # (M,)
-        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-        emb_sin = np.sin(out)  # (M, D/2)
-        emb_cos = np.cos(out)  # (M, D/2)
-
-        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-        return emb
-
-
 @Model.register("MiniCPM3ForCausalLM")
 class MiniCPM3Model(Model):
@@ -2479,6 +2333,155 @@ class Qwen2VLModel(Model):
             yield name, data


+@Model.register("MiniCPMV")
+class MiniCPMVModel(Qwen2Model):
+    # based on minicpmv-surgery.py, not sure why it is Qwen2Model instead of MiniCPMModel
+    model_arch = gguf.MODEL_ARCH.QWEN2
+    proj_type: gguf.constants.CLIPProjectorType | None
+    resampler_n_embd = 0
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        model_type = self.hparams.get("model_type", None)
+
+        # only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
+        if "vision_config" in self.hparams and model_type == "minicpmv":
+            self.vparams = self.hparams["vision_config"]
+            self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+            self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
+            version = str(self.hparams.get("version", "unknown"))
+            if version == "2.5":
+                self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
+            elif version == "2.6":
+                self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
+            else:
+                raise ValueError(f"Unsupported MiniCPM-V version: {version}")
+            # TODO: how to do this without reading the whole safetensor file?
+            for tname, tensor in self.get_tensors():
+                if tname == "resampler.ln_post.bias":
+                    self.resampler_n_embd = tensor.shape[0]
+            if self.resampler_n_embd < 2:
+                raise ValueError("Failed to detect resampler embedding size")
+        else:
+            raise ValueError("Expected vision_config, but not found")
+
+        if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
+            self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
+            self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
+            self.hparams["vision_feature_layer"] = 0
+            self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        # For vision model
+        if self.vparams is not None and self.proj_type is not None:
+            self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
+            self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
+            self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
+            max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
+            self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        yield (
+            self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
+            torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
+        )
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # for language part
+        if name.startswith("llm."):
+            return [(self.map_tensor_name(name.replace("llm.", "")), data_torch)]
+
+        # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
+        if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
+            assert data_torch.shape[0] == 3 * self.resampler_n_embd
+            split_tensor = data_torch.chunk(3, dim=0)
+            name_q = name.replace("in_proj_", "in_proj_q.")  # in_proj_q.(weight|bias)
+            name_k = name.replace("in_proj_", "in_proj_k.")  # in_proj_k.(weight|bias)
+            name_v = name.replace("in_proj_", "in_proj_v.")  # in_proj_v.(weight|bias)
+            return [
+                (self.map_tensor_name(name_q), split_tensor[0]),
+                (self.map_tensor_name(name_k), split_tensor[1]),
+                (self.map_tensor_name(name_v), split_tensor[2]),
+            ]
+
+        # append .weight to these tensors
+        if name == "resampler.proj" or name == "resampler.query":
+            name += ".weight"
+
+        if "post_layernorm" in name:
+            return []  # skip post_layernorm
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+        del name, bid  # unused
+        if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
+            return gguf.GGMLQuantizationType.F32
+        if "v.resmpl." in new_name:
+            return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
+        return False
+
+    # utils to work with MiniCPM-V resampler
+
+    # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+    def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
+        """
+        grid_size: int of the grid height and width
+        return:
+        pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+        """
+        if isinstance(grid_size, int):
+            grid_h_size, grid_w_size = grid_size, grid_size
+        else:
+            grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+        grid_h = np.arange(grid_h_size, dtype=np.float32)
+        grid_w = np.arange(grid_w_size, dtype=np.float32)
+        grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+        grid = np.stack(grid, axis=0)
+
+        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+        pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+        if cls_token:
+            pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+        return pos_embed
+
+    def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
+        assert embed_dim % 2 == 0
+
+        # use half of dimensions to encode grid_h
+        emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+        emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+        return emb
+
+    def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
+        """
+        embed_dim: output dimension for each position
+        pos: a list of positions to be encoded: size (M,)
+        out: (M, D)
+        """
+        assert embed_dim % 2 == 0
+        omega = np.arange(embed_dim // 2, dtype=np.float32)
+        omega /= embed_dim / 2.
+        omega = 1. / 10000 ** omega  # (D/2,)
+
+        pos = pos.reshape(-1)  # (M,)
+        out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+        emb_sin = np.sin(out)  # (M, D/2)
+        emb_cos = np.cos(out)  # (M, D/2)
+
+        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+        return emb
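Note: a minimal standalone sketch of the table the _get_2d_sincos_pos_embed helpers above produce, assuming the MiniCPM-V 2.6 values (resampler_n_embd = 3584, fixed 70x70 grid); the names below are illustrative and not part of the conversion script:

    import numpy as np

    def sincos_1d(dim: int, pos: np.ndarray) -> np.ndarray:
        # 1D sin-cos table: (M, dim) for M flattened positions
        omega = 1.0 / 10000 ** (np.arange(dim // 2, dtype=np.float32) / (dim / 2))
        out = np.einsum("m,d->md", pos.reshape(-1), omega)          # outer product (M, dim/2)
        return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, dim)

    gx, gy = np.meshgrid(np.arange(70, dtype=np.float32), np.arange(70, dtype=np.float32))
    pos_embed = np.concatenate([sincos_1d(3584 // 2, gx), sincos_1d(3584 // 2, gy)], axis=1)
    print(pos_embed.shape)  # (4900, 3584) -- exported as v.resmpl.pos_embd_k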
@Model.register("WavTokenizerDec")
|
@Model.register("WavTokenizerDec")
|
||||||
class WavTokenizerDecModel(Model):
|
class WavTokenizerDecModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
|
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
|
||||||
|
|
|
@@ -98,8 +98,9 @@ int main(int argc, char ** argv) {
     common_params params;

     // default prompt for llava 1.5
-    params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
-        "USER:<img_placement>\nwhat did you see?\nASSISTANT:";
+    //params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:<img_placement>\nwhat did you see?\nASSISTANT:";
+    // default prompt for minicpmv 2.6
+    params.prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nwhat did you see?\n<image><img_placement></image><|im_end|>\n<|im_start|>assistant\n";
     params.n_predict = 64;
     params.n_batch = 2048;
     params.n_ubatch = 1024;
@@ -457,12 +457,14 @@ class MODEL_TENSOR(IntEnum):
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
     V_RESMPL_POS_EMBD_K = auto() # minicpmv
-    V_RESMPL_ATTN_IN = auto() # minicpmv
+    V_RESMPL_ATTN_Q = auto() # minicpmv
+    V_RESMPL_ATTN_K = auto() # minicpmv
+    V_RESMPL_ATTN_V = auto() # minicpmv
     V_RESMPL_ATTN_OUT = auto() # minicpmv
-    V_RESMPL_KV_PROJ = auto() # minicpmv
-    V_RESMPL_NORM_POST = auto() # minicpmv
-    V_RESMPL_NORM_KV = auto() # minicpmv
-    V_RESMPL_NORM_Q = auto() # minicpmv
+    V_RESMPL_KV = auto() # minicpmv
+    V_RESMPL_KV_NORM = auto() # minicpmv
+    V_RESMPL_POST_NORM = auto() # minicpmv
+    V_RESMPL_Q_NORM = auto() # minicpmv
     V_RESMPL_PROJ = auto() # minicpmv
     V_RESMPL_QUERY = auto() # minicpmv
@@ -674,12 +676,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_PRE_NORM: "v.pre_norm",
     MODEL_TENSOR.V_POST_NORM: "v.post_norm",
     MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "v.resmpl.pos_embd_k",
-    MODEL_TENSOR.V_RESMPL_ATTN_IN: "v.resmpl.attn_in",
+    MODEL_TENSOR.V_RESMPL_ATTN_Q: "v.resmpl.attn_q",
+    MODEL_TENSOR.V_RESMPL_ATTN_K: "v.resmpl.attn_k",
+    MODEL_TENSOR.V_RESMPL_ATTN_V: "v.resmpl.attn_v",
     MODEL_TENSOR.V_RESMPL_ATTN_OUT: "v.resmpl.attn_out",
-    MODEL_TENSOR.V_RESMPL_KV_PROJ: "v.resmpl.kv_proj",
-    MODEL_TENSOR.V_RESMPL_NORM_POST: "v.resmpl.norm_post",
-    MODEL_TENSOR.V_RESMPL_NORM_KV: "v.resmpl.norm_kv",
-    MODEL_TENSOR.V_RESMPL_NORM_Q: "v.resmpl.norm_q",
+    MODEL_TENSOR.V_RESMPL_KV: "v.resmpl.kv",
+    MODEL_TENSOR.V_RESMPL_KV_NORM: "v.resmpl.kv_norm",
+    MODEL_TENSOR.V_RESMPL_POST_NORM: "v.resmpl.post_norm",
+    MODEL_TENSOR.V_RESMPL_Q_NORM: "v.resmpl.q_norm",
     MODEL_TENSOR.V_RESMPL_PROJ: "v.resmpl.proj",
     MODEL_TENSOR.V_RESMPL_QUERY: "v.resmpl.query",
 }
@@ -1667,12 +1671,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_OUTPUT_NORM,
         MODEL_TENSOR.V_ENC_FFN_UP,
         MODEL_TENSOR.V_ENC_FFN_DOWN,
-        MODEL_TENSOR.V_RESMPL_ATTN_IN,
+        MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
+        MODEL_TENSOR.V_RESMPL_ATTN_Q,
+        MODEL_TENSOR.V_RESMPL_ATTN_K,
+        MODEL_TENSOR.V_RESMPL_ATTN_V,
         MODEL_TENSOR.V_RESMPL_ATTN_OUT,
-        MODEL_TENSOR.V_RESMPL_KV_PROJ,
-        MODEL_TENSOR.V_RESMPL_NORM_POST,
-        MODEL_TENSOR.V_RESMPL_NORM_KV,
-        MODEL_TENSOR.V_RESMPL_NORM_Q,
+        MODEL_TENSOR.V_RESMPL_KV,
+        MODEL_TENSOR.V_RESMPL_KV_NORM,
+        MODEL_TENSOR.V_RESMPL_POST_NORM,
+        MODEL_TENSOR.V_RESMPL_Q_NORM,
         MODEL_TENSOR.V_RESMPL_PROJ,
         MODEL_TENSOR.V_RESMPL_QUERY,
     ],
@@ -868,27 +868,35 @@ class TensorNameMap:
             "resampler.pos_embed_k",
         ),

-        MODEL_TENSOR.V_RESMPL_ATTN_IN: (
-            "resampler.attn.in_proj",
+        MODEL_TENSOR.V_RESMPL_ATTN_Q: (
+            "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_K: (
+            "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_V: (
+            "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
         ),

         MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
             "resampler.attn.out_proj",
         ),

-        MODEL_TENSOR.V_RESMPL_KV_PROJ: (
+        MODEL_TENSOR.V_RESMPL_KV: (
             "resampler.kv_proj",
         ),

-        MODEL_TENSOR.V_RESMPL_NORM_POST: (
+        MODEL_TENSOR.V_RESMPL_POST_NORM: (
             "resampler.ln_post",
         ),

-        MODEL_TENSOR.V_RESMPL_NORM_KV: (
+        MODEL_TENSOR.V_RESMPL_KV_NORM: (
             "resampler.ln_kv",
         ),

-        MODEL_TENSOR.V_RESMPL_NORM_Q: (
+        MODEL_TENSOR.V_RESMPL_Q_NORM: (
             "resampler.ln_q",
         ),
@@ -1372,12 +1372,14 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
        { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
        { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
        { LLM_TENSOR_V_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" },
-       { LLM_TENSOR_V_RESMPL_ATTN_IN, "v.resmpl.attn_in" },
+       { LLM_TENSOR_V_RESMPL_ATTN_Q, "v.resmpl.attn_q" },
+       { LLM_TENSOR_V_RESMPL_ATTN_K, "v.resmpl.attn_k" },
+       { LLM_TENSOR_V_RESMPL_ATTN_V, "v.resmpl.attn_v" },
        { LLM_TENSOR_V_RESMPL_ATTN_OUT, "v.resmpl.attn_out" },
-       { LLM_TENSOR_V_RESMPL_KV_PROJ, "v.resmpl.kv_proj" },
-       { LLM_TENSOR_V_RESMPL_NORM_POST, "v.resmpl.norm_post" },
-       { LLM_TENSOR_V_RESMPL_NORM_KV, "v.resmpl.norm_kv" },
-       { LLM_TENSOR_V_RESMPL_NORM_Q, "v.resmpl.norm_q" },
+       { LLM_TENSOR_V_RESMPL_KV, "v.resmpl.kv" },
+       { LLM_TENSOR_V_RESMPL_KV_NORM, "v.resmpl.kv_norm" },
+       { LLM_TENSOR_V_RESMPL_POST_NORM, "v.resmpl.post_norm" },
+       { LLM_TENSOR_V_RESMPL_Q_NORM, "v.resmpl.q_norm" },
        { LLM_TENSOR_V_RESMPL_PROJ, "v.resmpl.proj" },
        { LLM_TENSOR_V_RESMPL_QUERY, "v.resmpl.query" },
    }
@@ -1531,6 +1533,24 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
    {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    // vision
+    {LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_EMBD_CLS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
+    {LLM_TENSOR_V_ENC_EMBD_PATCH, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
+    {LLM_TENSOR_V_ENC_EMBD_POS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}},
+    {LLM_TENSOR_V_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_INPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_V_ENC_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    // TODO: add minicpmv resampler tensors
 };

 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@@ -371,12 +371,14 @@ enum llm_tensor {
    LLM_TENSOR_V_POST_NORM,
    // vision - minicpmv
    LLM_TENSOR_V_RESMPL_POS_EMBD_K,
-   LLM_TENSOR_V_RESMPL_ATTN_IN,
+   LLM_TENSOR_V_RESMPL_ATTN_Q,
+   LLM_TENSOR_V_RESMPL_ATTN_K,
+   LLM_TENSOR_V_RESMPL_ATTN_V,
    LLM_TENSOR_V_RESMPL_ATTN_OUT,
-   LLM_TENSOR_V_RESMPL_KV_PROJ,
-   LLM_TENSOR_V_RESMPL_NORM_POST,
-   LLM_TENSOR_V_RESMPL_NORM_KV,
-   LLM_TENSOR_V_RESMPL_NORM_Q,
+   LLM_TENSOR_V_RESMPL_KV,
+   LLM_TENSOR_V_RESMPL_KV_NORM,
+   LLM_TENSOR_V_RESMPL_POST_NORM,
+   LLM_TENSOR_V_RESMPL_Q_NORM,
    LLM_TENSOR_V_RESMPL_PROJ,
    LLM_TENSOR_V_RESMPL_QUERY,
 };
@@ -1248,7 +1248,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
    hparams.rope_type = llama_model_rope_type(this);

    // vision model
-   auto & vparams = clip.hparams;
+   auto & vparams = vit.hparams;
    std::string vision_type;
    ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
    if (vision_type == "vit") {
@@ -3451,10 +3451,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
        }
    }
-    }

    // load tensors for vision model
-   auto & vparams = clip.hparams;
+   auto & vparams = vit.hparams;
    if (has_vision) {
        // language params
        const int64_t n_embd = hparams.n_embd;
|
||||||
|
@ -3467,101 +3466,122 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||||
const int64_t patch_size = vparams.patch_size;
|
const int64_t patch_size = vparams.patch_size;
|
||||||
const auto tn = LLM_TN(vparams.arch);
|
const auto tn = LLM_TN(vparams.arch);
|
||||||
|
|
||||||
// clip is CPU-only for now
|
// TODO: vit is cpu only for now
|
||||||
clip.buft = ggml_backend_cpu_buffer_type();
|
vit.buft = ggml_backend_cpu_buffer_type();
|
||||||
ggml_context * ctx_vision = ctx_map.at(clip.buft);
|
ggml_context * ctx_vision = ctx_map.at(vit.buft);
|
||||||
clip.layers.resize(n_vlayer);
|
vit.layers.resize(n_vlayer);
|
||||||
|
|
||||||
switch (vparams.arch) {
|
switch (vparams.arch) {
|
||||||
case LLM_ARCH_VISION_LLAVA:
|
case LLM_ARCH_VISION_LLAVA:
|
||||||
case LLM_ARCH_VISION_MOBILEVLM:
|
case LLM_ARCH_VISION_MOBILEVLM:
|
||||||
{
|
{
|
||||||
if (vparams.arch == LLM_ARCH_VISION_LLAVA) {
|
if (vparams.arch == LLM_ARCH_VISION_LLAVA) {
|
||||||
clip.mm_1_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff});
|
vit.mm_1_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff});
|
||||||
clip.mm_1_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff});
|
vit.mm_1_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff});
|
||||||
clip.mm_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff});
|
vit.mm_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff});
|
||||||
clip.mm_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff});
|
vit.mm_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff});
|
||||||
} else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) {
|
} else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) {
|
||||||
clip.mm_model_mlp_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd});
|
vit.mm_model_mlp_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd});
|
||||||
clip.mm_model_mlp_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd});
|
vit.mm_model_mlp_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd});
|
||||||
clip.mm_model_mlp_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd});
|
vit.mm_model_mlp_2_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd});
|
||||||
clip.mm_model_mlp_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd});
|
vit.mm_model_mlp_2_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd});
|
||||||
clip.mm_model_peg_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd});
|
vit.mm_model_peg_0_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd});
|
||||||
clip.mm_model_peg_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd});
|
vit.mm_model_peg_0_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd});
|
||||||
}
|
}
|
||||||
|
|
||||||
clip.class_embedding = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd});
|
vit.class_embedding = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd});
|
||||||
clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
|
vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
|
||||||
clip.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
|
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
|
||||||
|
|
||||||
clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd});
|
vit.pre_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd});
|
||||||
clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd});
|
vit.pre_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd});
|
||||||
clip.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
vit.post_norm_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
clip.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
vit.post_norm_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||||
|
|
||||||
for (int i = 0; i < n_vlayer; ++i) {
|
for (int i = 0; i < n_vlayer; ++i) {
|
||||||
auto & layer = clip.layers[i];
|
auto & layer = vit.layers[i];
|
||||||
|
|
||||||
layer.k_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd});
|
layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.k_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd});
|
layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0);
|
||||||
layer.v_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd});
|
layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.v_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd});
|
layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0);
|
||||||
layer.q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd});
|
layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd});
|
layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff});
|
layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0);
|
||||||
layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff});
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0);
|
||||||
layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd});
|
layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0);
|
||||||
layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd});
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.norm_in_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd});
|
layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0);
|
||||||
layer.norm_in_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd});
|
layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0);
|
||||||
layer.norm_out_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd});
|
layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0);
|
||||||
layer.norm_out_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd});
|
layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.output_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd});
|
layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.output_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd});
|
layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_VISION_MINICPMV:
|
case LLM_ARCH_VISION_MINICPMV:
|
||||||
{
|
{
|
||||||
clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
|
vit.patch_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd});
|
||||||
clip.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
|
vit.patch_bias = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd});
|
||||||
|
vit.position_embeddings = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd});
|
||||||
|
|
||||||
// TODO: load all resampler tensors
|
// resampler
|
||||||
|
int rs_n_embd = llama_vision_n_mmproj_embd(vit);
|
||||||
|
vit.mm_model_pos_embed_k = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight"), {rs_n_embd, max_pos_embd});
|
||||||
|
vit.mm_model_query = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_QUERY, "weight"), {rs_n_embd, 64}); // why 64?
|
||||||
|
vit.mm_model_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_PROJ, "weight"), {rs_n_embd, rs_n_embd});
|
||||||
|
vit.mm_model_kv_proj = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV, "weight"), {n_vembd, rs_n_embd});
|
||||||
|
vit.mm_model_attn_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight"), {rs_n_embd, rs_n_embd});
|
||||||
|
vit.mm_model_attn_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_attn_k_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight"), {rs_n_embd, rs_n_embd});
|
||||||
|
vit.mm_model_attn_k_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_attn_v_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight"), {rs_n_embd, rs_n_embd});
|
||||||
|
vit.mm_model_attn_v_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_attn_o_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight"), {rs_n_embd, rs_n_embd});
|
||||||
|
vit.mm_model_attn_o_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight"), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_kv_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight"), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_kv_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" ), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_post_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight"), {rs_n_embd});
|
||||||
|
vit.mm_model_ln_post_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" ), {rs_n_embd});
|
||||||
|
|
||||||
for (int i = 0; i < n_vlayer; ++i) {
|
for (int i = 0; i < n_vlayer; ++i) {
|
||||||
auto & layer = clip.layers[i];
|
auto & layer = vit.layers[i];
|
||||||
|
|
||||||
layer.k_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd});
|
layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.k_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd});
|
layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0);
|
||||||
layer.v_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd});
|
layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.v_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd});
|
layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0);
|
||||||
layer.q_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd});
|
layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.q_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd});
|
layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.ffn_up_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff});
|
layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0);
|
||||||
layer.ffn_up_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff});
|
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0);
|
||||||
layer.ffn_down_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd});
|
layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0);
|
||||||
layer.ffn_down_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd});
|
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.norm_in_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd});
|
layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0);
|
||||||
layer.norm_in_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd});
|
layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0);
|
||||||
layer.norm_out_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd});
|
layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0);
|
||||||
layer.norm_out_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd});
|
layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0);
|
||||||
|
|
||||||
layer.output_w = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd});
|
layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0);
|
||||||
layer.output_b = ml.create_tensor(ctx_vision, tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd});
|
layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown vision architecture");
|
throw std::runtime_error("unknown vision architecture");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vision_n_mmproj_embd(clip) != hparams.n_embd) {
|
if (llama_vision_n_mmproj_embd(vit) != hparams.n_embd) {
|
||||||
std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
|
std::runtime_error("model has vision, but n_mmproj_embd != n_embd");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ml.done_getting_tensors();
|
ml.done_getting_tensors();
|
||||||
|
|
||||||
|
|
|
@@ -365,7 +365,7 @@ struct llama_model {

    // vision
    bool has_vision = false;
-   llama_vision_model clip;
+   llama_vision_model vit;

 private:
    struct impl;
@@ -19,8 +19,6 @@ struct img_size;
 static int bmp_export(const struct llama_image_u8 &img, const std::string &location);
 #endif

-#define VISION_GRAPH_MAX_NODE 1024
-
 struct img_size {
     int width;
     int height;
@@ -48,9 +46,9 @@ uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) {
    } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) {
        return vmodel.mm_model_peg_0_b->ne[0];
    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
-       return 4096;
+       return 4096; // resampler
    } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) {
-       return 3584;
+       return 3584; // resampler
    } else {
        GGML_ASSERT(false && "invalid proj type");
    }
@@ -761,16 +759,21 @@ struct llama_vision_graph_builder {
        return cur;
    }

-   // graph for each vision arch
-
-   struct ggml_cgraph * build_llava() {
-       struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
+   struct ggml_tensor * build_vit() {
        struct ggml_tensor * cur = build_inp();
        cur = build_pre_norm(cur);
        for (int il = 0; il < n_layers; il++) {
            cur = build_layer(cur, il);
        }
        cur = build_post_norm(cur);
+       return cur;
+   }
+
+   // graph for each vision arch
+
+   struct ggml_cgraph * build_llava() {
+       struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
+       struct ggml_tensor * cur = build_vit();

        // llava projector
        {
@@ -825,6 +828,78 @@ struct llama_vision_graph_builder {
        return gf;
    }

+   struct ggml_cgraph * build_minicpmv() {
+       struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false);
+       struct ggml_tensor * cur = build_vit();
+
+       // minicpmv resampler projector
+       {
+           int hidden_size = llama_vision_n_mmproj_embd(*ctx.model);
+           struct ggml_tensor * q = model.mm_model_query;
+           // layernorm
+           {
+               q = ggml_norm(ctx0, q, eps);
+               q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+           }
+
+           struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, cur);
+           // layernorm
+           {
+               v = ggml_norm(ctx0, v, eps);
+               v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+           }
+
+           // position
+           struct ggml_tensor * k = ggml_add(ctx0, v, model.mm_model_pos_embed_k);
+
+           // attention
+           {
+               const int d_head = 128;
+               int n_head = hidden_size/d_head;
+               int num_query = -1;
+               if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
+                   num_query = 96;
+               } else if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) {
+                   num_query = 64;
+               }
+
+               struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+               Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+               struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+               struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+               // permute
+               Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+               Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // TODO: do this when converting the model
+               Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+               K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+               K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // TODO: do this when converting the model
+               K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+               V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+               V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // TODO: do this when converting the model
+               V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+               struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+               KQ = ggml_soft_max_inplace(ctx0, KQ);
+               struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+               KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+               KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // TODO: do this when converting the model
+               KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+
+               cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+           }
+           // layernorm
+           {
+               cur = ggml_norm(ctx0, cur, eps);
+               cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+           }
+           cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+       }
+
+       ggml_set_name(cur, "output");
+       ggml_build_forward_expand(gf, cur);
+
+       return gf;
+   }
 };
||||||
|
|
||||||
static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
|
static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
|
||||||
|
@@ -852,8 +927,11 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
        case LLM_ARCH_VISION_MOBILEVLM:
            gf = builder.build_llava();
            break;
+       case LLM_ARCH_VISION_MINICPMV:
+           gf = builder.build_minicpmv();
+           break;
        default:
-           GGML_ASSERT(false && "unsupported arch");
+           GGML_ASSERT(false && "unsupported vision arch");
    }

    // alloc memory for graph
@@ -903,8 +981,8 @@ static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) {
        free(positions_data);
    }

-   {
        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "inp_patches");
+       if (patches) {
        int* patches_data = (int*)malloc(ggml_nbytes(patches));
        for (int i = 0; i < num_patches; i++) {
            patches_data[i] = i + 1;
@@ -962,7 +1040,8 @@ struct llama_vision_tokens * llama_vision_tokenize(
        case LLM_ARCH_VISION_MOBILEVLM:
            return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
        case LLM_ARCH_VISION_MINICPMV:
-           return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp));
+           //return new llama_vision_tokens(llama_vision_processor_uhd(vctx).tokenize(*bmp));
+           return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp));
        default:
            GGML_ASSERT(false && "unsupported arch");
    }
@@ -7,6 +7,8 @@
 #include <vector>
 #include <array>

+#define VISION_GRAPH_MAX_NODE 2048
+
 enum vision_projector_type {
     VISION_PROJECTOR_TYPE_UNKNOWN,
     VISION_PROJECTOR_TYPE_MLP,
@@ -108,24 +110,24 @@ struct llama_vision_model {
    struct ggml_tensor * mm_model_peg_0_b = nullptr;

    // MINICPMV projection
-   struct ggml_tensor * mm_model_pos_embed_k;
-   struct ggml_tensor * mm_model_query;
-   struct ggml_tensor * mm_model_proj;
-   struct ggml_tensor * mm_model_kv_proj;
-   struct ggml_tensor * mm_model_attn_q_w;
-   struct ggml_tensor * mm_model_attn_q_b;
-   struct ggml_tensor * mm_model_attn_k_w;
-   struct ggml_tensor * mm_model_attn_k_b;
-   struct ggml_tensor * mm_model_attn_v_w;
-   struct ggml_tensor * mm_model_attn_v_b;
-   struct ggml_tensor * mm_model_attn_o_w;
-   struct ggml_tensor * mm_model_attn_o_b;
-   struct ggml_tensor * mm_model_ln_q_w;
-   struct ggml_tensor * mm_model_ln_q_b;
-   struct ggml_tensor * mm_model_ln_kv_w;
-   struct ggml_tensor * mm_model_ln_kv_b;
-   struct ggml_tensor * mm_model_ln_post_w;
-   struct ggml_tensor * mm_model_ln_post_b;
+   struct ggml_tensor * mm_model_pos_embed_k = nullptr;
+   struct ggml_tensor * mm_model_query = nullptr;
+   struct ggml_tensor * mm_model_proj = nullptr;
+   struct ggml_tensor * mm_model_kv_proj = nullptr;
+   struct ggml_tensor * mm_model_attn_q_w = nullptr;
+   struct ggml_tensor * mm_model_attn_q_b = nullptr;
+   struct ggml_tensor * mm_model_attn_k_w = nullptr;
+   struct ggml_tensor * mm_model_attn_k_b = nullptr;
+   struct ggml_tensor * mm_model_attn_v_w = nullptr;
+   struct ggml_tensor * mm_model_attn_v_b = nullptr;
+   struct ggml_tensor * mm_model_attn_o_w = nullptr;
+   struct ggml_tensor * mm_model_attn_o_b = nullptr;
+   struct ggml_tensor * mm_model_ln_q_w = nullptr;
+   struct ggml_tensor * mm_model_ln_q_b = nullptr;
+   struct ggml_tensor * mm_model_ln_kv_w = nullptr;
+   struct ggml_tensor * mm_model_ln_kv_b = nullptr;
+   struct ggml_tensor * mm_model_ln_post_w = nullptr;
+   struct ggml_tensor * mm_model_ln_post_b = nullptr;

    struct ggml_tensor * image_newline = nullptr;
 };
@@ -9838,9 +9838,9 @@ struct llama_context * llama_init_from_model(
    }

    if (model->has_vision) {
-       ctx->vctx.model = &model->clip;
+       ctx->vctx.model = &model->vit;
        ctx->vctx.sched = ctx->sched.get();
-       const size_t max_nodes = 1024;
+       const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic
        ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
    }