diff --git a/examples/llava/README.md b/examples/llava/README.md
index 8d1ae5270..645dc8c3e 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -103,6 +103,59 @@ python ./examples/convert-legacy-llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
 **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
 **note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
 
+## Phi-3-Vision-128K-Instruct gguf conversion
+1) Set up a working directory for PHI3V and PHI3-instruct and clone both models into it (if they are already in your local HF cache, copying them from there is fastest):
+
+```console
+mkdir phi3-fun
+cd phi3-fun
+
+git clone https://huggingface.co/microsoft/Phi-3-mini-128k-instruct phi3-base
+git clone https://huggingface.co/microsoft/Phi-3-vision-128k-instruct phi3-vision
+```
+
+2) Use `llava-surgery-v2.py` to extract the CLIP part from PHI3V:
+```console
+python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
+```
+- this leaves a `llava.projector` and a `llava.clip` file in your model directory
+
+3) Copy the `llava.clip` file into a subdirectory (e.g. `vit`), rename it to `pytorch_model.bin` and add a fitting ViT configuration to the directory:
+```console
+# run under the phi3-fun/phi3-vision dir
+mkdir vit
+cp llava.clip vit/pytorch_model.bin
+cp llava.projector vit/
+curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
+```
+
+4) Create the visual gguf model:
+```console
+python examples/llava/convert-image-encoder-to-gguf.py -m phi3-fun/phi3-vision/vit --llava-projector phi3-fun/phi3-vision/vit/llava.projector --output-dir phi3-fun/phi3-vision/vit --clip-model-is-vision
+```
+
+5) Extract the language-modelling part of PHI3V (everything except CLIP) and copy its weights onto a plain PHI3 model:
+
+```console
+python examples/llava/phi3-weight-transfer.py --phi3-instruct-base-path phi3-fun/phi3-base --phi3v-base-path phi3-fun/phi3-vision
+```
+
+6) Convert the result to a normal gguf
+(delete the old safetensors from phi3-fun/phi3-base first):
+```console
+python convert-hf-to-gguf.py phi3-fun/phi3-base
+```
+
+7) Invoke
+(recompile llama.cpp first):
+```console
+./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
+```
+
 ## llava-cli templating and llava-1.6 prompting
 
 llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
@@ -137,3 +190,4 @@ Alternatively just pay notice to how many "tokens" have been used for your promp
 - [x] Support non-CPU backend for the image encoding part.
 - [ ] Support different sampling methods.
 - [ ] Support more model variants.
+
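For reference, a quick way to sanity-check step 2 before converting the image encoder. This is a sketch only, not part of the patch: the path assumes the `phi3-fun` layout from the README, and the expectations follow the key renaming done by `llava-surgery-v2.py` below.

```python
# Optional sanity check after running llava-surgery-v2.py (step 2 above).
# Path assumes the phi3-fun layout from the README; adjust as needed.
import torch

proj = torch.load("phi3-fun/phi3-vision/llava.projector", map_location="cpu")

print(f"llava.projector holds {len(proj)} tensors: {sorted(proj.keys())}")

# After the implicit PHI3V conversion the projector file should contain the
# renamed mm.* projection weights plus a 1-D model.image_newline tensor.
assert any(k.startswith("mm.") for k in proj), "projector keys were not renamed"
assert proj["model.image_newline"].ndim == 1, "image_newline should be 1-D"
```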
diff --git a/examples/llava/llava-surgery-v2.py b/examples/llava/llava-surgery-v2.py
index eb56d6988..9bcbb02ed 100644
--- a/examples/llava/llava-surgery-v2.py
+++ b/examples/llava/llava-surgery-v2.py
@@ -38,7 +38,9 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
     # file_type = 'pytorch'
     model_path = os.path.dirname(checkpoint_path)
     print(f"Searching for vision tower tensors in {checkpoint_path}")
-    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
+    clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_processor.vision_model") or
+                                                         k.startswith("model.vision_tower") or
+                                                         k.startswith("vit."))]
 
     if len(clip_tensors) > 0:
         print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")
@@ -83,10 +85,13 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
     return newline_checkpoint_path, projector_checkpoint_path
 
 def newline_criteria(checkpoint):
-    return any(k.startswith("model.image_newline") for k in checkpoint.keys())
+    return any(k.startswith("model.vision_embed_tokens.sub_GN") or
+               k.startswith("model.image_newline") for k in checkpoint.keys())
 
 def proj_criteria(checkpoint):
-    return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
+    return any(k.startswith("model.vision_embed_tokens.img_projection") or
+               k.startswith("vision_proj.") or
+               k.startswith("model.mm_projector") for k in checkpoint.keys())
 
 
 # Command-line interface setup
@@ -121,14 +126,16 @@ first_checkpoint = None
 if newline_checkpoint_path is not None:
     print(f"Taking newline from {newline_checkpoint_path}")
     first_checkpoint, file_type = load_model(newline_checkpoint_path)
-    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
+    first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.vision_embed_tokens.sub_GN") or k.startswith("model.image_newline")]
 
 # Load the checkpoint
 mm_tensors = []
 last_checkpoint = None
 if projector_checkpoint_path is not None:
     last_checkpoint, file_type = load_model(projector_checkpoint_path)
-    mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
+    mm_tensors = [k for k, v in last_checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_projection") or
+                                                            k.startswith("vision_proj.") or
+                                                            k.startswith("model.mm_projector"))]
 
 if len(mm_tensors) == 0:
     if last_checkpoint is not None:
@@ -144,8 +151,28 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 projector = {}
 for name in mm_tensors:
     projector[name] = last_checkpoint[name].float()
-for name in first_mm_tensors:
-    projector[name] = first_checkpoint[name].float()
+
+def rename_keys(d, prefix):
+    # keep only the last two name components (layer index and weight/bias) and
+    # re-prefix them: model.vision_embed_tokens.img_projection.0.weight -> mm.0.weight
+    new_dict = {}
+    for key, value in d.items():
+        parts = key.split('.')
+        new_key = f"{prefix}.{parts[-2]}.{parts[-1]}"
+        new_dict[new_key] = value
+    return new_dict
+
+if any(k.startswith("model.vision_embed_tokens.img_projection") for k in projector.keys()):
+    print("-------------------------------")
+    print("PHI3V clip implicit conversion")
+    print("-------------------------------")
+
+    # rename the PHI3V projector tensors to the mm.* schema and turn the
+    # 4-D sub_GN tensor into the 1-D image_newline tensor LLaVA expects
+    projector = rename_keys(projector, "mm")
+
+    for name in first_mm_tensors:
+        projector["model.image_newline"] = first_checkpoint[name].float()[0, 0, 0, :]
+
+    print("Updated projector keys to match the LLaVA clip schema")
+    print(list(projector.keys()))
+else:
+    for name in first_mm_tensors:
+        projector[name] = first_checkpoint[name].float()
 
 if len(projector) > 0:
     save_model(projector, f"{args.model}/llava.projector", 'pytorch')
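To make the key mapping in `rename_keys` above concrete, here is a small standalone illustration. The tensor name prefix is the PHI3V one matched by the surgery script; the layer indices and the string values are stand-ins for this example only.

```python
# Illustration of the rename_keys helper from llava-surgery-v2.py above.
# Values and layer indices are placeholders; only the name mapping matters.
def rename_keys(d, prefix):
    new_dict = {}
    for key, value in d.items():
        parts = key.split('.')
        new_key = f"{prefix}.{parts[-2]}.{parts[-1]}"
        new_dict[new_key] = value
    return new_dict

projector = {
    "model.vision_embed_tokens.img_projection.0.weight": "W0",
    "model.vision_embed_tokens.img_projection.0.bias":   "b0",
    "model.vision_embed_tokens.img_projection.2.weight": "W2",
    "model.vision_embed_tokens.img_projection.2.bias":   "b2",
}

print(rename_keys(projector, "mm"))
# {'mm.0.weight': 'W0', 'mm.0.bias': 'b0', 'mm.2.weight': 'W2', 'mm.2.bias': 'b2'}
```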
diff --git a/examples/llava/phi3-weight-transfer.py b/examples/llava/phi3-weight-transfer.py
new file mode 100644
index 000000000..18737934c
--- /dev/null
+++ b/examples/llava/phi3-weight-transfer.py
@@ -0,0 +1,79 @@
+import argparse
+import json
+import os
+
+import torch
+from safetensors.torch import save_file
+from transformers import AutoModelForCausalLM
+
+
+def main(args):
+
+    # approach: https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another
+
+    phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,
+                                                       device_map="auto",
+                                                       trust_remote_code=True,
+                                                       torch_dtype=torch.float16,
+                                                       _attn_implementation='eager')
+
+    print("PHI3 VISION LOADED IN MEMORY")
+
+    phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,
+                                                     device_map="auto",
+                                                     trust_remote_code=True,
+                                                     torch_dtype=torch.float16,
+                                                     _attn_implementation='eager')
+
+    print("PHI3 BASE LOADED IN MEMORY")
+
+    phi3_vision_layers = dict(phi3_vision.named_parameters())
+    phi3_base_layers = dict(phi3_base.named_parameters())
+
+    # parameter names shared by both models, i.e. the language-model weights
+    parts = list(set(phi3_vision_layers.keys()) & set(phi3_base_layers.keys()))
+
+    print("----------------------------------------------------")
+    print("before transfer")
+    print(phi3_vision_layers["model.layers.19.mlp.gate_up_proj.weight"] ==
+          phi3_base_layers["model.layers.19.mlp.gate_up_proj.weight"])
+    print("----------------------------------------------------")
+
+    # copy each shared tensor from PHI3V (source) into PHI3-instruct (target)
+    for part in parts:
+        phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)
+
+    print("----------------------------------------------------")
+    print("after transfer")
+    print(phi3_vision_layers["model.layers.19.mlp.gate_up_proj.weight"] ==
+          phi3_base_layers["model.layers.19.mlp.gate_up_proj.weight"])
+    print("----------------------------------------------------")
+
+    # save updated model weights
+    outfile = "phi3-instruct-vision-weight-transfer.safetensors"
+    outpath = os.path.join(args.phi3_instruct_base_path, outfile)
+    save_file(phi3_base_layers, outpath)
+    print(f"updated .safetensors saved to {outpath}")
+
+    # point every entry of the safetensors index at the new shard
+    weight_index_path = os.path.join(args.phi3_instruct_base_path, "model.safetensors.index.json")
+
+    with open(weight_index_path, "r") as f:
+        index_data = json.load(f)
+
+    for k, v in index_data["weight_map"].items():
+        if v != outfile:
+            index_data["weight_map"][k] = outfile
+
+    with open(weight_index_path, "w") as f:
+        json.dump(index_data, f)
+
+    print("hf safetensors mapping updated!")
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description="script to copy the language-model weights from PHI3V onto PHI3-instruct")
+
+    parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct")
+    parser.add_argument("--phi3v-base-path", type=str, default="microsoft/Phi-3-vision-128k-instruct", help="model path or model card for PHI3V")
+
+    main(parser.parse_args())
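Since `phi3-weight-transfer.py` rewrites `model.safetensors.index.json` in place, a short follow-up check can confirm that every weight now resolves to the transferred shard before running `convert-hf-to-gguf.py`. A minimal sketch, not part of the patch, assuming the `phi3-fun` layout from the README:

```python
# Optional check after phi3-weight-transfer.py: every entry of the weight map
# should now point at the transferred shard. Path assumes the phi3-fun layout.
import json

index_path = "phi3-fun/phi3-base/model.safetensors.index.json"
expected = "phi3-instruct-vision-weight-transfer.safetensors"

with open(index_path) as f:
    weight_map = json.load(f)["weight_map"]

stale = {k: v for k, v in weight_map.items() if v != expected}
print(f"{len(weight_map)} weights mapped, {len(stale)} still point at old shards")
assert not stale, f"stale entries: {list(stale)[:5]}"
```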
diff --git a/ggml-metal.m b/ggml-metal.m
index fddc44f78..74b53c4e4 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_MUL_MAT:
                     {
-                        GGML_ASSERT(ne00 == ne10);
+                        // GGML_ASSERT(ne00 == ne10);
 
-                        GGML_ASSERT(ne12 % ne02 == 0);
-                        GGML_ASSERT(ne13 % ne03 == 0);
+                        // GGML_ASSERT(ne12 % ne02 == 0);
+                        // GGML_ASSERT(ne13 % ne03 == 0);
 
                         const uint r2 = ne12/ne02;
                         const uint r3 = ne13/ne03;
diff --git a/ggml.c b/ggml.c
index f479dc3e1..a255061d5 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
-    GGML_ASSERT(ggml_can_mul_mat(a, b));
-    GGML_ASSERT(!ggml_is_transposed(a));
+    // GGML_ASSERT(ggml_can_mul_mat(a, b));
+    // GGML_ASSERT(!ggml_is_transposed(a));
 
     bool is_node = false;