[WIP] add qwen2vl arch

This commit is contained in:
HimariO 2024-09-26 00:45:08 +08:00
parent 7c6f793492
commit b24bd89e77
5 changed files with 10004 additions and 11 deletions

View file

@ -1978,7 +1978,7 @@ class Qwen2Model(Model):
@Model.register("Qwen2VLForConditionalGeneration") @Model.register("Qwen2VLForConditionalGeneration")
class Qwen2VLModel(Model): class Qwen2VLModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN2 model_arch = gguf.MODEL_ARCH.QWEN2VL
def set_vocab(self): def set_vocab(self):
try: try:
@ -1986,15 +1986,6 @@ class Qwen2VLModel(Model):
except FileNotFoundError: except FileNotFoundError:
self._set_vocab_gpt2() self._set_vocab_gpt2()
# def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
# new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
# if name.startswith("visual."):
# breakpoint()
# return ""
# if new_name is None:
# raise ValueError(f"Can not map tensor {name!r}")
# return new_name
def get_tensors(self) -> Iterator[tuple[str, Tensor]]: def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for name, data in super().get_tensors(): for name, data in super().get_tensors():
if name.startswith("visual."): if name.startswith("visual."):

View file

@ -0,0 +1,159 @@
import argparse
import glob
import os
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from typing import Any, ContextManager, cast
# Function to determine if file is a SafeTensor file
def is_safetensor_file(file_path):
return file_path.endswith('.safetensors')
# Unified loading function
def load_model(file_path):
if is_safetensor_file(file_path):
tensors = {}
with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
for key in f.keys():
tensors[key] = f.get_tensor(key).clone()
# output shape
print(f"{key} : {tensors[key].shape}")
return tensors, 'safetensor'
else:
return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch'
# Unified saving function
def save_model(model, file_path, file_type):
if file_type == 'safetensor':
# safe_save(model, file_path)
save_file(model, file_path)
else:
torch.save(model, file_path)
# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):
checkpoint, file_type = load_model(checkpoint_path)
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
# Adapted for file type
clip_path = os.path.join(model_path, "llava.clip")
if os.path.exists(clip_path):
print(f"Loading existing llava.clip from {clip_path}")
existing_clip, _ = load_model(clip_path)
else:
print(f"Creating new llava.clip at {clip_path}")
existing_clip = {}
# Update existing_clip with new tensors, avoid duplicates
for name in clip_tensors:
simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name
print(f"Adding {simple_name} to llava.clip")
if simple_name not in existing_clip:
existing_clip[simple_name] = checkpoint[name]
# Save the updated clip tensors back to llava.clip
save_model(existing_clip, clip_path, 'pytorch')
# Remove the tensors from the original checkpoint
for name in clip_tensors:
del checkpoint[name]
checkpoint_path = checkpoint_path
return True
return False
def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
newline_checkpoint_path = None
projector_checkpoint_path = None
for path in checkpoint_paths:
checkpoint, _ = load_model(path)
if newline_criteria(checkpoint) and newline_checkpoint_path is None:
newline_checkpoint_path = path
if projector(checkpoint):
projector_checkpoint_path = path
return newline_checkpoint_path, projector_checkpoint_path
def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())
def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
# Command-line interface setup
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model")
ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files")
args = ap.parse_args()
if args.clean_vision_tower:
# Generalized to handle both PyTorch and SafeTensors models
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
# checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))]
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
for projector_checkpoint_path in checkpoint_paths:
print(f"Cleaning {projector_checkpoint_path}")
if not clean_vision_tower_from_checkpoint(projector_checkpoint_path):
print(f"No vision tower found in {projector_checkpoint_path}")
# we break once none is found, so far all models append them at the end
# break
print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.")
# Now we look for the projector in the last checkpoint
model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True)
checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])]
# last_checkpoint_path = checkpoint_paths[0]
# first_checkpoint_path = checkpoint_paths[-1]
newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria)
print(f"Taking projector from {projector_checkpoint_path}")
first_mm_tensors = []
first_checkpoint = None
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
if len(mm_tensors) == 0:
if last_checkpoint is not None:
for k, v in last_checkpoint.items():
print(k)
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
print("No tensors found. Is this a LLaVA model?")
exit()
print(f"Found {len(mm_tensors)} tensors to extract.")
print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
projector = {}
for name in mm_tensors:
assert last_checkpoint is not None
projector[name] = last_checkpoint[name].float()
for name in first_mm_tensors:
assert first_checkpoint is not None
projector[name] = first_checkpoint[name].float()
if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')
print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

View file

@ -1445,6 +1445,21 @@ extern "C" {
float beta_fast, float beta_fast,
float beta_slow); float beta_slow);
GGML_API struct ggml_tensor * ggml_mrope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// in-place, returns view(a) // in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_ext_inplace( GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,

File diff suppressed because it is too large Load diff

View file

@ -163,6 +163,7 @@ enum llm_arch {
LLM_ARCH_QWEN, LLM_ARCH_QWEN,
LLM_ARCH_QWEN2, LLM_ARCH_QWEN2,
LLM_ARCH_QWEN2MOE, LLM_ARCH_QWEN2MOE,
LLM_ARCH_QWEN2VL,
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
LLM_ARCH_PHI3, LLM_ARCH_PHI3,
LLM_ARCH_PLAMO, LLM_ARCH_PLAMO,
@ -217,6 +218,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2, "qwen2" },
{ LLM_ARCH_QWEN2MOE, "qwen2moe" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" },
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO, "plamo" },
@ -898,6 +900,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_QWEN2VL,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_QWEN2MOE, LLM_ARCH_QWEN2MOE,
{ {
@ -3329,6 +3348,8 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_pos_w; // I32 [n_batch] second-dimension of m-rope position index
struct ggml_tensor * inp_pos_h; // I32 [n_batch] third-dimension of m-rope position index
struct ggml_tensor * inp_out_ids; // I32 [n_outputs] struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
@ -5686,6 +5707,7 @@ static void llm_load_hparams(
} }
} break; } break;
case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2VL:
{ {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) { switch (hparams.n_layer) {
@ -8096,6 +8118,7 @@ static bool llm_load_tensors(
} }
} break; } break;
case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2VL:
{ {
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -12484,6 +12507,123 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_qwen2vl() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
// struct ggml_tensor * inp_pos = build_inp_pos();
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 3);
cb(lctx.inp_pos, "inp_pos", -1);
ggml_set_input(lctx.inp_pos);
struct ggml_tensor * inp_pos = lctx.inp_pos;
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_mrope_ext(
ctx0,
ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_mrope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_qwen2moe() { struct ggml_cgraph * build_qwen2moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@ -16732,6 +16872,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_qwen2(); result = llm.build_qwen2();
} break; } break;
case LLM_ARCH_QWEN2VL:
{
result = llm.build_qwen2vl();
} break;
case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN2MOE:
{ {
result = llm.build_qwen2moe(); result = llm.build_qwen2moe();
@ -20088,6 +20232,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_BITNET: case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN: case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2VL:
case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_OLMO2: case LLM_ARCH_OLMO2:
case LLM_ARCH_OLMOE: case LLM_ARCH_OLMOE: