diff --git a/examples/llava/README.md b/examples/llava/README.md
index 645dc8c3e..f4820b9b9 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -127,11 +127,12 @@ python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
 4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
 ```console
 // under phi3-fun/phi-vision dir
-mkdir vit 
+mkdir vit
 cp llava.clip vit/pytorch_model.bin
 cp llava.projector vit/
 curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
 ```
+Then set `mm_projector_type` to `mlp_phi` in `config.json`.
 
 5) Create the visual gguf model:
 ```console
@@ -151,7 +152,6 @@ python convert-hf-to-gguf.py phi3-fun/phi3-base
 ```
 
 8) Invoke
-(recompile llama.cpp first)
 ```console
 ./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
 ```
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 95fbe3d02..3fdfe45b9 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -130,12 +130,14 @@ enum projector_type {
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_LDPV2,
+    PROJECTOR_TYPE_MLP_PHI,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MLP, "mlp" },
     { PROJECTOR_TYPE_LDP, "ldp" },
     { PROJECTOR_TYPE_LDPV2, "ldpv2"},
+    { PROJECTOR_TYPE_MLP_PHI, "mlp_phi" },
 };
 
 
@@ -698,8 +700,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     // ne is whcn, ne = [1024, 576, 1, 1]
     embeddings = ggml_get_rows(ctx0, embeddings, patches);
 
-    // print_tensor_info(embeddings, "embeddings");
-
     // llava projector
     if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -709,7 +709,24 @@
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
 
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
+        // needs to be reworked, see https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
+        // line 204 onwards
+        struct ggml_tensor * embeddings_ = embeddings;
+        // [1024, 576, 1, 1] -> [4096, 576, 1, 1]
+        embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
+        embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
+        embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+        embeddings = ggml_gelu(ctx0, embeddings);
+        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
         embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
         // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
@@ -1208,7 +1225,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
 
         // LLaVA projection
-        if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM || new_clip->proj_type == PROJECTOR_TYPE_MLP_PHI) {
             vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
             vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
             try {
@@ -2069,6 +2086,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
         return ctx->vision_model.mm_2_b->ne[0];
     }
+    if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
+        return ctx->vision_model.mm_2_b->ne[0];
+    }
     if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
         return ctx->vision_model.mm_3_b->ne[0];
     }
diff --git a/examples/llava/convert-image-encoder-to-gguf.py b/examples/llava/convert-image-encoder-to-gguf.py
index b00bf7c6d..766b65e78 100644
--- a/examples/llava/convert-image-encoder-to-gguf.py
+++ b/examples/llava/convert-image-encoder-to-gguf.py
@@ -86,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
 ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                 help="The clip model is from openclip (for ViT-SO400M type))")
 ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
-ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
+ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2, mlp_phi", choices=["mlp", "ldp", "ldpv2", "mlp_phi"], default="mlp_phi")
 ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
 # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
@@ -206,39 +206,39 @@ if has_vision_encoder:
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
     block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
-    #  /**
-    #  "image_grid_pinpoints": [
-    #      [
-    #        336,
-    #        672
-    #      ],
-    #      [
-    #        672,
-    #        336
-    #      ],
-    #      [
-    #        672,
-    #        672
-    #      ],
-    #      [
-    #        1008,
-    #        336
-    #      ],
-    #      [
-    #        336,
-    #        1008
-    #      ]
-    #  ],
-    #  Flattened:
-    #  [
-    #      336, 672,
-    #      672, 336,
-    #      672, 672,
-    #      1008, 336,
-    #      336, 1008
-    #  ]
-    #  *
-    #  */
+    # /**
+    # "image_grid_pinpoints": [
+    #     [
+    #       336,
+    #       672
+    #     ],
+    #     [
+    #       672,
+    #       336
+    #     ],
+    #     [
+    #       672,
+    #       672
+    #     ],
+    #     [
+    #       1008,
+    #       336
+    #     ],
+    #     [
+    #       336,
+    #       1008
+    #     ]
+    # ],
+    # Flattened:
+    # [
+    #     336, 672,
+    #     672, 336,
+    #     672, 672,
+    #     1008, 336,
+    #     336, 1008
+    # ]
+    # *
+    # */
     if "image_grid_pinpoints" in v_hparams:
         # flatten it
         image_grid_pinpoints = []
@@ -257,7 +257,6 @@ if has_vision_encoder:
 
     if "mm_projector_type" in v_hparams:
         fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])
-
     if processor is not None:
         image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
         image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
diff --git a/examples/llava/phi3-weight-transfer.py b/examples/llava/phi3-weight-transfer.py
index 18737934c..136639d83 100644
--- a/examples/llava/phi3-weight-transfer.py
+++ b/examples/llava/phi3-weight-transfer.py
@@ -11,19 +11,19 @@ def main(args):
 
     # https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another
 
-    phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,\
-                                            device_map="auto",\
-                                            trust_remote_code=True,\
-                                            torch_dtype=torch.float16,\
-                                            _attn_implementation='eager')
+    phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,
+                                                       device_map="auto",
+                                                       trust_remote_code=True,
+                                                       torch_dtype=torch.float16,
+                                                       _attn_implementation='eager')
 
     print("PHI3 VISION LOADED IN MEMORY")
 
-    phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,\
-                                            device_map="auto",\
-                                            trust_remote_code=True,\
-                                            torch_dtype=torch.float16,\
-                                            _attn_implementation='eager')
+    phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,
+                                                     device_map="auto",
+                                                     trust_remote_code=True,
+                                                     torch_dtype=torch.float16,
+                                                     _attn_implementation='eager')
 
     print("PHI3 BASE LOADED IN MEMORY")
 
@@ -34,21 +34,21 @@ def main(args):
 
     print("----------------------------------------------------")
     print("before transfer")
-    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \
-          dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
+    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
+          == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
     print("----------------------------------------------------")
 
     for part in parts:
-        phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)
+        phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)  # target (phi3_base) <- source (phi3_vision)
 
     print("----------------------------------------------------")
     print("after transfer")
-    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \
-          dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
+    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
+          == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
     print("----------------------------------------------------")
 
 
-    #  save updated model weights
+    # save updated model weights
     outfile = "phi3-instruct-vision-weight-transfer.safetensors"
     outpath = os.path.join(args.phi3_instruct_base_path, outfile)
     save_file(phi3_base_layers, outpath)
@@ -59,7 +59,7 @@ def main(args):
 
     with open(weight_index_path, "r") as f:
         index_data = json.load(f)
-        
+
     for k,v in index_data["weight_map"].items():
         if v != "phi3-instruct-vision-weight-transfer.safetensors":
             index_data["weight_map"][k] = outfile
@@ -69,8 +69,9 @@ def main(args):
     print(f"hf saftensor mapping updated!")
 
 
+
 if __name__ == '__main__':
-    
+
     parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct")
     parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct",
                         help="model path or model card for PHI3-instruct")
diff --git a/ggml-metal.m b/ggml-metal.m
index 74b53c4e4..fddc44f78 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -779,7 +779,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-if (op->src[1]->type != GGML_TYPE_F16) {
+            if (op->src[1]->type != GGML_TYPE_F16) {
                 return false;
             }
             if (op->src[2]->type != GGML_TYPE_F16) {
@@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_MUL_MAT:
                 {
-                    // GGML_ASSERT(ne00 == ne10);
+                    GGML_ASSERT(ne00 == ne10);
 
-                    // GGML_ASSERT(ne12 % ne02 == 0);
-                    // GGML_ASSERT(ne13 % ne03 == 0);
+                    GGML_ASSERT(ne12 % ne02 == 0);
+                    GGML_ASSERT(ne13 % ne03 == 0);
 
                     const uint r2 = ne12/ne02;
                     const uint r3 = ne13/ne03;
diff --git a/ggml.c b/ggml.c
index a255061d5..f479dc3e1 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    // GGML_ASSERT(ggml_can_mul_mat(a, b));
-    // GGML_ASSERT(!ggml_is_transposed(a));
+    GGML_ASSERT(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
 
     bool is_node = false;
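
A small helper for the README step that flips the projector type. This is only a sketch; it assumes the config fetched by the `curl` command in step 4 was saved to `vit/config.json`, as the README instructs:

```python
# Sketch: set mm_projector_type in the vit config so that
# convert-image-encoder-to-gguf.py can pick it up via v_hparams["mm_projector_type"].
import json

with open("vit/config.json", "r") as f:
    cfg = json.load(f)

cfg["mm_projector_type"] = "mlp_phi"

with open("vit/config.json", "w") as f:
    json.dump(cfg, f, indent=2)
```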
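For readers of the `PROJECTOR_TYPE_MLP_PHI` branch in `clip.cpp`: the graph repeats each 1024-dim CLIP patch embedding four times (the three `ggml_concat` calls along dim 0) and pushes the result through a two-layer MLP with GELU. Below is a minimal PyTorch sketch of that same computation; the function and argument names are illustrative, not part of the patch. Per the in-code comment this branch is a placeholder, since the referenced `image_embedding_phi3_v.py` appears to build its 4096-wide inputs by merging neighbouring patches rather than by duplicating a single feature.

```python
# Sketch only: mirrors the mlp_phi ggml graph above, not the reference
# Phi-3-vision embedding code. Tensor names (mm_0_*, mm_2_*) follow the
# LLaVA projector tensors loaded in clip_model_load.
import torch
import torch.nn.functional as F

def mlp_phi_projector(patches: torch.Tensor,
                      mm_0_w: torch.Tensor, mm_0_b: torch.Tensor,
                      mm_2_w: torch.Tensor, mm_2_b: torch.Tensor) -> torch.Tensor:
    # patches: [576, 1024] CLIP patch embeddings (ggml ne = [1024, 576, 1, 1])
    x = torch.cat([patches] * 4, dim=-1)   # [576, 4096], like the three ggml_concat calls
    x = F.gelu(x @ mm_0_w.T + mm_0_b)      # mm_0 matmul + bias, then GELU
    return x @ mm_2_w.T + mm_2_b           # mm_2 matmul + bias
```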
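The before/after checks in `phi3-weight-transfer.py` print a full element-wise boolean tensor. A more compact verification, as a sketch reusing the same layer name that the script already inspects:

```python
# Sketch: compact equality check for one transferred layer.
import torch

def layers_match(model_a, model_b, name="model.layers.19.mlp.gate_up_proj.weight") -> bool:
    a = dict(model_a.named_parameters())[name]
    b = dict(model_b.named_parameters())[name]
    return torch.equal(a, b)

# e.g. layers_match(phi3_vision, phi3_base) should be False before the copy
# loop and True afterwards.
```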