add phi3v projection handling in clip.cpp

farris 2024-06-05 11:55:41 -07:00
parent efeaeaf79f
commit 5175117a09
6 changed files with 85 additions and 65 deletions


@@ -127,11 +127,12 @@ python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
```console
# under the phi3-fun/phi3-vision dir
mkdir vit
cp llava.clip vit/pytorch_model.bin
cp llava.projector vit/
curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
```
Set `mm_projector_type` to `mlp_phi` in the `vit/config.json` created above.
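A minimal sketch of that edit, assuming the `phi3-fun/phi3-vision/vit/` layout from step 4 (editing the JSON by hand works just as well):
```python
# Sketch: add the projector type key to the vit config created in step 4.
# The path is an assumption based on the directory layout used above.
import json

cfg_path = "phi3-fun/phi3-vision/vit/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["mm_projector_type"] = "mlp_phi"  # picked up as clip.vision.mm_projector_type at conversion time

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```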
5) Create the visual gguf model:
```console
@@ -151,7 +152,6 @@ python convert-hf-to-gguf.py phi3-fun/phi3-base
```
8) Invoke
(recompile llama.cpp first)
```console
./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
```


@@ -130,12 +130,14 @@ enum projector_type {
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_UNKNOWN,
PROJECTOR_TYPE_MLP_PHI
};
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
{ PROJECTOR_TYPE_MLP_PHI, "mlp_phi" }
};
@@ -698,8 +700,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// ne is whcn, ne = [1024, 576, 1, 1]
embeddings = ggml_get_rows(ctx0, embeddings, patches);
// print_tensor_info(embeddings, "embeddings");
// llava projector
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -709,7 +709,24 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
// needs to be reworked, see https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
// line 204 onwards
struct ggml_tensor * embeddings_ = embeddings;
// [1024, 576, 1, 1] -> [4096, 576, 1, 1]
embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
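The "needs to be reworked" note above points at the HD transform in the linked `image_embedding_phi3_v.py`: as far as I can tell from that file, the reference merges each 2x2 block of neighbouring CLIP patches channel-wise (1024 -> 4096 channels, 576 -> 144 tokens per crop) before the two-layer GELU MLP, whereas the graph above reaches the 4096-wide input of `mm_0_w` by repeating each 1024-d embedding four times. A rough Python sketch of that reference behaviour, under those assumptions (separator tokens and multi-crop handling omitted):
```python
# Sketch of the reference Phi-3-vision projection (an interpretation of
# image_embedding_phi3_v.py, not the ggml graph above). Assumes a 24x24 patch
# grid per crop, C = 1024, and a row-major 2x2 merge order.
import torch
import torch.nn.functional as F

def hd_merge_2x2(patches: torch.Tensor) -> torch.Tensor:
    """[576, 1024] CLIP patch embeddings -> [144, 4096] by concatenating the
    channels of each 2x2 block of neighbouring patches."""
    n, c = patches.shape                          # 576, 1024
    h = w = int(n ** 0.5)                         # 24 x 24 patch grid
    x = patches.reshape(h // 2, 2, w // 2, 2, c)  # split rows/cols into 2x2 blocks
    x = x.permute(0, 2, 1, 3, 4)                  # [12, 12, 2, 2, 1024]
    return x.reshape(-1, 4 * c)                   # [144, 4096]

def project(patches, mm_0_w, mm_0_b, mm_2_w, mm_2_b):
    """Two-layer MLP with GELU, mirroring the mm_0_* / mm_2_* tensors used above
    (weights assumed to be laid out [out_features, in_features], as in nn.Linear)."""
    x = hd_merge_2x2(patches)
    x = F.gelu(x @ mm_0_w.T + mm_0_b)
    return x @ mm_2_w.T + mm_2_b
```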
@@ -1208,7 +1225,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
}
// LLaVA projection
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM || new_clip->proj_type == PROJECTOR_TYPE_MLP_PHI) {
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
try {
@@ -2069,6 +2086,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
return ctx->vision_model.mm_2_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
return ctx->vision_model.mm_2_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
return ctx->vision_model.mm_3_b->ne[0];
}


@@ -86,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type)")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2, mlp_phi", choices=["mlp", "ldp", "ldpv2", "mlp_phi"], default="mlp_phi")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
@@ -206,39 +206,39 @@ if has_vision_encoder:
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [
# [
# 336,
# 672
# ],
# [
# 672,
# 336
# ],
# [
# 672,
# 672
# ],
# [
# 1008,
# 336
# ],
# [
# 336,
# 1008
# ]
# ],
# Flattened:
# [
# 336, 672,
# 672, 336,
# 672, 672,
# 1008, 336,
# 336, 1008
# ]
# *
# */
if "image_grid_pinpoints" in v_hparams:
# flatten it
image_grid_pinpoints = []
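The transformation the comment block above describes is just a flattening of the nested coordinate pairs; a minimal standalone illustration (the script's actual loop is cut off by this hunk):
```python
# Flatten the nested pinpoints from the comment above into the form written as
# clip.vision.image_grid_pinpoints.
pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
flattened = [coord for pair in pinpoints for coord in pair]
print(flattened)  # [336, 672, 672, 336, 672, 672, 1008, 336, 336, 1008]
```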
@@ -257,7 +257,6 @@ if has_vision_encoder:
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std


@@ -11,19 +11,19 @@ def main(args):
# https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another
phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
_attn_implementation='eager')
print("PHI3 VISION LOADED IN MEMORY")
phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
_attn_implementation='eager')
print("PHI3 BASE LOADED IN MEMORY")
@@ -34,21 +34,21 @@ def main(args):
print("----------------------------------------------------")
print("before transfer")
print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
== dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
print("----------------------------------------------------")
for part in parts:
phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)
# target # source
print("----------------------------------------------------")
print("after transfer")
print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
== dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
print("----------------------------------------------------")
# save updated model weights
outfile = "phi3-instruct-vision-weight-transfer.safetensors"
outpath = os.path.join(args.phi3_instruct_base_path, outfile)
save_file(phi3_base_layers, outpath)
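`parts`, `phi3_vision_layers`, and `phi3_base_layers` are defined earlier in the script, outside this hunk; a hypothetical reconstruction of that setup, just to make the copy loop above easier to follow (the selection rule is an assumption, not the script's actual code):
```python
# Hypothetical setup for the transfer loop above (an assumption, not the script's
# own code): collect the parameters both checkpoints share, i.e. the language-model
# weights, since the CLIP tower and img_projection exist only in Phi-3-vision.
phi3_vision_layers = dict(phi3_vision.named_parameters())
phi3_base_layers = dict(phi3_base.named_parameters())
parts = [name for name in phi3_base_layers if name in phi3_vision_layers]
```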
@@ -59,7 +59,7 @@ def main(args):
with open(weight_index_path, "r") as f:
index_data = json.load(f)
for k,v in index_data["weight_map"].items():
if v != "phi3-instruct-vision-weight-transfer.safetensors":
index_data["weight_map"][k] = outfile
@@ -69,8 +69,9 @@ def main(args):
print(f"hf safetensors mapping updated!")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct")
parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct")


@@ -779,7 +779,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
@@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_MUL_MAT:
{
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
const uint r2 = ne12/ne02;
const uint r3 = ne13/ne03;

ggml.c

@@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_can_mul_mat(a, b));
GGML_ASSERT(!ggml_is_transposed(a));
bool is_node = false;